[RFC] Sparse tensor support in torch-mlir

Thanks for all your additional input, Stella, much appreciated.

In the style of the early days of torch-mlir, I made a quick (and very simplified) sketch of what I think needs to happen. In this picture, the parts in “blue” are already there, and the parts in “red” still need to be done.

For example, torch.sparse, although still in beta, is already there, and is thus shown in “blue”. Similarly, once we reach MHLO or Linalg with sparse tensor types in MLIR, we have a fully functional pipeline that runs on CPU, with some GPU acceleration being added. As for the parts in “red”, there are no sparse torch tensors yet, nor is there sparse tensor support in TOSA or StableHLO (although, based on our experience adding sparse tensor types to MHLO, we have a StableHLO + Sparse RFC for that; this is why I still included MHLO in this picture, even though I realize it is being phased out).

Note that one very nice result of having torch.sparse as part of PyTorch is that it avoids the audit objection Sean mentioned above, since the extension already takes care of that. All torch.sparse tensor types form a subset of the sparse tensor types supported by MLIR (for example, batch dimensions map to “dense” and the intermediate CSR to “dense”/“compressed”, followed by “dense” dimensions again for the subtensors), so hopefully that mapping is smooth.
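
To make that mapping concrete, here is a small PyTorch sketch of a plain and a batched CSR tensor. The `lvlTypes` annotations in the comments are my assumption of how these layouts would map onto the MLIR sparse_tensor encoding, not a finalized design:

```python
import torch

# Plain 2-D CSR: rows are stored densely, columns are compressed.
# Assumed MLIR sparse_tensor mapping: lvlTypes = ["dense", "compressed"].
a = torch.tensor([[0., 1., 0.],
                  [2., 0., 3.]]).to_sparse_csr()
print(a.crow_indices())  # tensor([0, 1, 3])
print(a.col_indices())   # tensor([1, 0, 2])
print(a.values())        # tensor([1., 2., 3.])

# Batched CSR of shape (B, M, N): the leading batch dimension is "dense",
# followed by the "dense"/"compressed" pair for the CSR part (PyTorch
# requires the same nnz in every batch).
# Assumed mapping: lvlTypes = ["dense", "dense", "compressed"].
crow = torch.tensor([[0, 1, 3],
                     [0, 2, 3]])
col = torch.tensor([[1, 0, 2],
                    [0, 2, 1]])
val = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])
b = torch.sparse_csr_tensor(crow, col, val, size=(2, 2, 3))
```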

As for all the “red” parts above, in the next few weeks we will explore adding sparsity to torch tensors and lowering this to, e.g., Linalg with sparse tensor types, using a simple reference compiler to run some end-to-end examples; a sketch of the kind of example we have in mind follows below.
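
For illustration, this is the flavor of end-to-end example we would like to push through the pipeline. The program itself runs eagerly in PyTorch today; everything about how it would lower is, of course, still hypothetical:

```python
import torch

# A sparse-times-dense matmul, expressed with torch.sparse today. The goal
# is to lower such a kernel to Linalg + sparse tensor types and compile it
# with a simple reference compiler, rather than interpreting it eagerly.
a = torch.eye(4).to_sparse_csr()                         # sparse CSR operand
x = torch.arange(16, dtype=torch.float32).reshape(4, 4)  # dense operand
y = a @ x
print(torch.allclose(y, x))  # identity matmul, so y == x: True
```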
