I am trying to export the MLIR file of Llama2 using fx.export_and_import in torch_mlir. By default, it seems to export the model’s forward function. Is there a way to specify which function to export, specifically the generate function of Llama2 for inference? Any guidance would be greatly appreciated!
This is a limitation in torch itself. It didn’t use to be so, but for some reason, they decided they wanted it this way (i.e. torch.export has no way to pass a naked function). I personally think it is a design flaw.
In IREE, I wrote this helper some time ago to work around that limitation (among others): iree-turbine/shark_turbine/aot/fx_programs.py at main · iree-org/iree-turbine · GitHub
You are welcome to take/adapt it.
Thank you for your guidance, I will try this.