There’s really no magic here: the torch dialect has no handling for sparse tensors. That support would need to be added, and the importers would need to be taught to preserve the sparsity information. The API you used above is what we call the [old] TorchScript importer, and it is probably a lost cause. When I wrote the FxImporter, I noted that the tracing data structures did claim to represent some sparsity metadata, but without samples or a concrete need, it is just a TODO. I don’t know how deep that rabbit hole goes on the torch side, either.
References:
- torch-mlir/python/torch_mlir/extras/fx_importer.py at commit 985e7796a4e4c2b939c4c350047db2473fcdc8f2 (llvm/torch-mlir on GitHub)
- pytorch/torch/fx/passes/shape_prop.py at main (pytorch/pytorch on GitHub)
Edit: it would appear I misremembered. The TensorMetadata struct has optional quantization info but nothing for sparsity. PyTorch may be sparsity-blind when tracing, and that would need to be worked out first. There are other ways to approach this, but that should be confirmed before going further.
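To check the TensorMetadata point against whichever PyTorch version you have installed, you can list the NamedTuple's fields and run ShapeProp over a trivially traced module. This is a rough sketch. The AddOne module and the two helper functions are hypothetical and exist only for illustration. Only symbolic_trace, ShapeProp, and TensorMetadata come from torch.fx.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp, TensorMetadata


class AddOne(torch.nn.Module):
    # Hypothetical toy module, used only to get a small traced graph.
    def forward(self, x):
        return x + 1


def tensor_meta_fields():
    """Return the field names of torch.fx's TensorMetadata NamedTuple."""
    return list(TensorMetadata._fields)


def propagated_metadata(module, example):
    """Trace `module`, run ShapeProp on `example`, and collect each node's tensor_meta."""
    gm = symbolic_trace(module)
    ShapeProp(gm).propagate(example)
    return {node.name: node.meta.get("tensor_meta") for node in gm.graph.nodes}


if __name__ == "__main__":
    fields = tensor_meta_fields()
    print(fields)
    # Fields that look sparsity- or layout-related; the result depends on the installed release.
    print([f for f in fields if "spars" in f or "layout" in f])
    for name, meta in propagated_metadata(AddOne(), torch.randn(2, 3)).items():
        print(name, meta)
```

If the second printed list is empty, the traced metadata in that release carries quantization fields but nothing describing sparsity, which would be consistent with the edit above. Newer releases may differ, so treat the output as a check, not a guarantee.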