Customizing Function Export with fx.export_and_import

I am trying to export an MLIR file for Llama2 using fx.export_and_import from torch_mlir. By default, it seems to export the model's forward function. Is there a way to specify which function to export? Specifically, I would like to export Llama2's generate function for inference. Any guidance would be greatly appreciated!

This is a limitation in torch itself. It didn't use to be this way, but for some reason they decided they wanted it like this (i.e. torch.export has no way to accept a naked function; it only traces an nn.Module's forward). I personally think it is a design flaw.

In IREE, I wrote this helper some time ago to work around that limitation (among others): iree-turbine/shark_turbine/aot/fx_programs.py on the main branch of iree-org/iree-turbine on GitHub.

You are welcome to take/adapt it.

Thank you for your guidance, I will try on this.