[RFC] Sparse tensor support in torch-mlir

After a few PRs, this latest PR shows how to import a PyTorch model with potentially sparse arguments into torch-mlir as an FX traced graph. It essentially uses Stella’s importer, but with a wrapper that converts sparse arguments to dense tensors, builds the traced graph, and then puts the sparsity annotation back. This is of course not the desired importer, but it can be used for testing until this FX feature request is resolved.

At present, we can take something like

class MatMulNet(torch.nn.Module):

    def __init__(self):
        super(MatMulNet, self).__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)

m = export_and_import(MatMulNet(), A_coo, B_dense)

and when invoked with a sparse x, convert this into the following SpMM representation

#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
module {
  func.func @main(%arg0: !torch.vtensor<[64,64],f32,#sparse>, 
                  %arg1: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
    %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[64,64],f32,#sparse>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
    return %0 : !torch.vtensor<[64,64],f32>
  }
}

which can now also be further lowered to linalg:

linalg.matmul ins(... : tensor<64x64xf32, #coo>,
                        tensor<64x64xf32>)
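
To make the COO encoding above a bit more concrete, here is a minimal pure-Python sketch of what an SpMM over such a tensor computes. Roughly, `compressed(nonunique)` on d0 means row coordinates are stored and may repeat, and `singleton` on d1 means each stored entry carries exactly one column coordinate. The sketch simplifies this to three parallel lists (row, column, value) and leaves out the positions array that the real storage scheme also keeps. The helper names `coo_from_dense` and `coo_spmm` are hypothetical and are not part of torch-mlir or the MLIR sparsifier; this is only meant to illustrate the arithmetic, not the generated code.

```python
# Illustrative sketch only: COO SpMM in plain Python.
# coo_from_dense and coo_spmm are hypothetical helper names.

def coo_from_dense(a):
    """Collect the nonzeros of a dense matrix as (rows, cols, vals), row-major."""
    rows, cols, vals = [], [], []
    for i, row in enumerate(a):
        for j, v in enumerate(row):
            if v != 0:
                rows.append(i)   # row coordinate; repeats when a row has several nonzeros
                cols.append(j)   # exactly one column coordinate per stored entry
                vals.append(v)
    return rows, cols, vals

def coo_spmm(rows, cols, vals, b, num_rows):
    """Compute C = A @ B, where A is given in COO form and B is dense."""
    n = len(b[0])
    c = [[0.0] * n for _ in range(num_rows)]
    # Only stored entries of A are visited; zeros are skipped entirely.
    for i, k, v in zip(rows, cols, vals):
        for j in range(n):
            c[i][j] += v * b[k][j]
    return c

a = [[1.0, 0.0, 0.0],
     [0.0, 0.0, 2.0],
     [0.0, 3.0, 4.0]]
b = [[1.0, 2.0],
     [3.0, 4.0],
     [5.0, 6.0]]

rows, cols, vals = coo_from_dense(a)
c = coo_spmm(rows, cols, vals, b, len(a))
print(rows, cols, vals)
print(c)
```

The outer loop runs once per stored nonzero rather than once per element of A, which is where the work is saved compared with the dense matmul.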

I am super excited about this progress!
