[RFC] Sparse tensor support in torch-mlir

There’s really no magic here: the torch dialect has no handling for sparse tensors. That would need to be added, and the importers taught to preserve it. The API you used above is what we call the [old] TorchScript importer and is probably a lost cause. When I wrote the FxImporter, I noted that the tracing data structures did claim to represent some sparsity metadata, but without samples or a concrete need, it remained a TODO. I don’t know how deep that rabbit hole goes on the torch side, either.

References:

Edit: it would appear I misremembered. The TensorMetadata struct has optional quantization info but no sparsity. PyTorch may be sparse-blind for tracing, and that would need to be worked out first. There are other ways to do it, but we should confirm that first.
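
For anyone who wants to double-check what tracing metadata is (or is not) recorded, the quick probe below is a hedged illustration; the exact field set may differ across PyTorch versions.

from torch.fx.passes.shape_prop import TensorMetadata

# Print the per-node tensor metadata fields captured during FX shape propagation.
print(TensorMetadata._fields)
# Expect fields along the lines of: shape, dtype, requires_grad, stride,
# memory_format, is_quantized, qparams (quantization info, but no layout/sparsity).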


It looks like these folks ran into many of these issues and produced something serviceable: [2304.07613] STen: Productive and Efficient Sparsity in PyTorch

I’m not an expert on this, but it does address my baseline assumption that PyTorch’s representation for sparsity is not general and is focused on a few hero ops. It also confirms that FX is more amenable.

I just skimmed it, but their approach of using explicit sparsifiers, if materialized as ops, maps in a similar way to how we think about quantization at this level.

In any case, FX is more composable, and if you invent something at the torch level, you can likely get it propagated as metadata. The link above shows that we are just consulting known metadata structs and using them to construct types.
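
As a hedged illustration only (not the actual FxImporter code), the kind of metadata consultation involved looks roughly like the sketch below; the describe_node_type helper is made up for the example, and node.meta["val"] is the fake-tensor value that export/dynamo attaches to graph nodes.

import torch

def describe_node_type(node: torch.fx.Node) -> str:
    # Export/dynamo record a fake tensor under the "val" meta key; an importer
    # reads it to build the result type for the node.
    val = node.meta.get("val", None)
    if val is None:
        return "unknown"
    # Any sparsity information would have to ride along here as well,
    # e.g. via val.layout or a dedicated meta entry.
    return f"shape={tuple(val.shape)}, dtype={val.dtype}, layout={val.layout}"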


Thanks for all your additional input, Stella, much appreciated.

In the style of the early days of TORCH-MLIR, I made a quick (and very simplified) sketch of what I think needs to happen. In this picture, the parts in “blue” are already there, and the parts in “red” still need to be done.

For example, torch.sparse, although in beta, is already there, and is thus shown in “blue”. Similarly, once we reach MHLO or Linalg with sparse tensor types in MLIR, we have a fully functional pipeline that runs on CPU, with some GPU acceleration being added. As for the parts in “red”, there are no sparse torch tensors yet, nor is there sparse support in TOSA or StableHLO (although, based on our experience adding sparse tensor types to MHLO, we have a StableHLO + Sparse RFC for that; this is why I still show MHLO in this picture, even though I realize it is being phased out).

Note that one very nice result of having torch.sparse as part of PyTorch is that it avoids the audit objection Sean mentioned above, since the extension already takes care of that. All torch.sparse tensor types are a subset of the sparse tensor types supported by MLIR (for example, batch dimensions map to “dense” levels, the actual CSR dimensions map to “dense”/“compressed”, followed by “dense” levels again for the dense subtensors), so hopefully that mapping is smooth.
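
As a small, hedged illustration of that mapping for plain (unbatched) CSR, the components of a torch.sparse_csr tensor line up directly with the MLIR “dense”/“compressed” level types:

import torch

a = torch.tensor([[1., 0.], [0., 2.]]).to_sparse_csr()
print(a.crow_indices())  # positions for the compressed column level (row pointers)
print(a.col_indices())   # coordinates for the compressed column level
print(a.values())        # the stored nonzero values
# In MLIR sparse_tensor terms this corresponds roughly to:
#   map = (d0, d1) -> (d0 : dense, d1 : compressed)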

As for all the “red” parts above, in the next few weeks, we will explore adding sparsity to torch tensors and lowering this to e.g. Linalg + sparse tensor types with a simple reference compiler to run some end-to-end examples.


You’re welcome. I’m a very non-visual person, so I’ll need to see it in code. I’d focus on making sure you can get what you want at the FX level, and the rest will follow. That is also where most of the tools are for mapping PyTorch concepts to IR generally. I’d recommend looking at the STen work, even if just to understand how this part of the stack is put together and can be used.

I can say with some experience that until you have a good working knowledge of how the FX and dynamo stacks work, you’re going to have a tendency to solve things with more difficulty, at a suboptimal level. We’re on the far side of the long-term roadmap now on that front, and that approach is basically where new work is being done.

This stuff is also much more amenable to prototyping now than it used to be – relatively small amounts of code can show a full concept. I don’t remember the last time I actually had to write PyTorch C++ code to get stuff like this done (actually, that’s a lie – I do remember but just don’t want to 🙂).


Yeah, this sketch was merely to organize my own thoughts on this a bit. Thanks for steering me to FX. We will start some prototyping in our own fork and then touch base again when we have questions (which I am sure will come up very soon ;-)). As always, I appreciate your guidance on this.


Just to give a brief update on this.

First, if we somehow get the sparsity metadata propagated to the torch tensors, the actual lowering to e.g. linalg is relatively straightforward and can probably be done with minor code changes. For example, if we get “sparsity” on the torch.vtensor below:

module attributes {torch.debug_module_name = "SparseNet"} {
  func.func @forward(%arg0: !torch.vtensor<[4,4],f32,IS_CSR>) -> !torch.vtensor<[],f32> {
    %none = torch.constant.none
    %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[4,4],f32,IS_CSR>, !torch.none -> !torch.vtensor<[],f32>
    return %0 : !torch.vtensor<[],f32>
  }
}

then we can lower this to the following linalg op by just adding the proper sparsity encoding while lowering (literally just a few lines of code here and there):

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> ()>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
module attributes {torch.debug_module_name = "SparseNet"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @forward(%arg0: tensor<4x4xf32, #sparse>) -> tensor<f32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<f32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<f32>) -> tensor<f32>
    %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "reduction"]} ins(%arg0 : tensor<4x4xf32, #sparse>) outs(%1 : tensor<f32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
    } -> tensor<f32>
    return %2 : tensor<f32>
  }
}

Second, Stella is right that it seems PyTorch may have to trace sparsity a bit further than is currently the case. Running the FxImporter chokes when computing the required size (as shape × strides).

torch._dynamo.exc.InternalTorchDynamoError: Sparse CSR tensors do not have strides

I also made an initial posting in the PyTorch forum to get the discussion going there as well.

Yeah, that’s about what I thought. You may want to directly @ a couple of the core folks with a TL;DR if that doesn’t get any attention in a few days.

I don’t want to get your hopes up too much: in my experience, PyTorch’s compiler infra is quite a ways from being cognizant of all of the “amazing” things that can be done with eager tensor types, and there is a definite split between some of the “vintage” usage and more modern thought.

Also, in my experience, the road often leads through new custom kernel/custom tensor type pairings done at the Python layer rather than in the very “dispatchy” way of the prior generation of solutions. Such things are very cheap in PyTorch and typically rather well supported by the tools. I can picture how to do it, but I don’t know how to tell you to intersect that with the existing sparse extension.
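
For concreteness, a minimal sketch of that “custom kernel at the Python layer” pattern is shown below; it assumes a recent PyTorch with torch.library.custom_op, and the mylib::sparse_mm name is made up for illustration rather than being part of any existing sparse extension.

import torch

@torch.library.custom_op("mylib::sparse_mm", mutates_args=())
def sparse_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Eager implementation; a real kernel could dispatch on x.layout here.
    return torch.matmul(x, w)

@sparse_mm.register_fake
def _(x, w):
    # Shape/dtype propagation so the op traces cleanly through FX/dynamo.
    return torch.empty(x.shape[0], w.shape[1], dtype=x.dtype, device=x.device)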


Also, a bit of advice: the PyTorch team is pretty swamped with last-mile support for the diaspora of things needed for full FX and dynamo support. If you want to get to the top of the queue, I’d include a small code snippet using upstream export APIs, the exception, and some pseudocode of what you expect the FX IR to be if this were to work. Small repros with stock setups are worth ten thousand words.


Very timely advice! I just posted my question with a more concrete example that uses torch.export.export for building the traced graph (which was the very first “pure upstream” code I could find in your FX importer from torch-mlir).
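
The gist of that example is sketched below; this is a hedged reconstruction rather than the exact snippet, and the precise failure point and message may vary across PyTorch versions.

import torch

class BikNet(torch.nn.Module):
    def forward(self, x):
        return x.sum()

a = torch.eye(64).to_sparse_csr()
# Exporting (which traces through dynamo) ends up asking for strides,
# which sparse CSR tensors do not have.
prog = torch.export.export(BikNet(), (a,))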

I’d include the relevant part of the stack trace in the post. I expect that exception is being thrown by the sparse tensor type itself when something is asking it for strides. But it isn’t clear how deep the rabbit hole goes without more analysis on what is making that call and how specialized it is to normal strided tensors.

If I were you, I would chase it a bit further into the core system and try to get it to either “it seems like if this one thing were a little smarter it would work, but I don’t know how to do this generally” or “I hacked that call site but then it fails [at the next place]”. Basically, show that you’re not just doing a drive-by but are digging in and either hitting something or getting lost and needing an evac.

A minor update: I migrated the PyTorch discussion thread to the PyTorch dev forum (since the user forum was not the right place). In the meantime, I got a very hacky prototype working that adds layout fields to various FX traced graph components. Next steps are to make use of those in torch-mlir and get a pipeline going with the MLIR sparsifier, just to see if this is indeed the right path forward.

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(
              self,
              l_x_: "f32[64, 64]:torch.sparse_csr"):   # ADDED!
            # File: biknet.py:27, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)
           
Graph signature: ExportGraphSignature(
  input_specs=[
      InputSpec(
           kind=<InputKind.USER_INPUT: 1>,
           arg=TensorArgument(name='l_x_'),
           target=None,
           layout=torch.sparse_csr)       # ADDED!
  ],
  output_specs=[
     OutputSpec(
         kind=<OutputKind.USER_OUTPUT: 1>,
         arg=TensorArgument(name='sum_1'),
         target=None)
 ])

After some fun hacking on three different repositories at once, I have a prototype of end-to-end sparsity propagation working (actually running the code is of course still a “detail” to work out ;-)).

To summarize, various changes will be needed in the PyTorch repo to propagate sparsity from something like

class BikNet(torch.nn.Module):
  def __init__(self):
    super(BikNet, self).__init__()
    return
  def forward(self, x):
    return x.sum()
...
biknet(sparse_csr_input)

into the FX traced graph with some sparsity information:

def forward(self, l_x_: "f32[64, 64]:torch.sparse_csr"):
  # File: biknet.py:27, code: return x.sum()
  sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
  return (sum_1,)

Then various changes will be needed in the TORCH-MLIR repo to further propagate this into torch vtensor types.

module {
  func.func @main(%arg0: !torch.vtensor<[64,64],f32,SPARSE>) -> !torch.vtensor<[],f32> {
    %none = torch.constant.none
    %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[64,64],f32,SPARSE>, !torch.none -> !torch.vtensor<[],f32>
    return %0 : !torch.vtensor<[],f32>
  }
}

and then lower this to an MLIR dialect like StableHLO with sparse tensor types. In this case, I hacked up the lowering to linalg (since the StableHLO proposal has not been accepted yet).

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> ()>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main(%arg0: tensor<64x64xf32, #sparse>) -> tensor<f32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<f32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<f32>) -> tensor<f32>
    %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "reduction"]} ins(%arg0 : tensor<64x64xf32, #sparse>) outs(%1 : tensor<f32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
    } -> tensor<f32>
    return %2 : tensor<f32>
  }
}

After which the MLIR sparsifier pipeline can proceed with business as usual.

    %2 = scf.for %arg4 = %c0 to %c64 step %c1 iter_args(%arg5 = %1) -> (f32) {
      %3 = memref.load %arg0[%arg4] : memref<?xindex>
      %4 = arith.addi %arg4, %c1 : index
      %5 = memref.load %arg0[%4] : memref<?xindex>
      %6 = scf.for %arg6 = %3 to %5 step %c1 iter_args(%arg7 = %arg5) -> (f32) {
        %7 = memref.load %arg2[%arg6] : memref<?xf32>
        %8 = arith.addf %7, %arg7 : f32
        scf.yield %8 : f32
      }
      scf.yield %6 : f32
    }

So, overall, there are still tons of details to work out, but I am pretty confident this approach has a very viable path forward!


Thanks to MLIR’s very flexible design, putting an encoding on the torch vtensor types is trivial, and due to the use of proper interfaces, without leaking ANY details of sparsity into the torch dialect, we are already able to round-trip sparsity as well as verify the consistency of the encoding.

// valid
#CSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>
func.func @main(%arg0: !torch.vtensor<[64,64],f32,#CSR>)
                    -> !torch.vtensor<[64,64],f32,#CSR> {
  return %arg0 : !torch.vtensor<[64,64],f32,#CSR>
}

// invalid
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

func.func @main(%arg0: !torch.vtensor<[64,64],f32,#SV>)
                    -> !torch.vtensor<[64,64],f32,#SV> {
  return %arg0 : !torch.vtensor<[64,64],f32,#SV>
}

YIELDS:
torch.mlir:7:38: error: dimension-rank mismatch between encoding and tensor shape: 1 != 2
func.func @main(%arg0: !torch.vtensor<[64,64],f32,#SV>) -> !torch.vtensor<[64,64],f32,#SV>

After a few PRs, this latest PR actually shows how to import a PyTorch model with potentially sparse arguments as an FX traced graph into torch-mlir. It essentially uses Stella’s importer, with a wrapper that converts sparse arguments to dense tensors, builds the traced graph, and then puts the sparsity annotation back. This is of course not the desired importer, but it can be used for testing until this FX feature request is resolved.
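
The wrapper idea is sketched below under some simplifying assumptions: the sparse_export name is illustrative (not necessarily what the PR uses), the "sparsity" metadata key is made up, and the model is assumed to have no parameters or buffers so that placeholders map one-to-one onto the user inputs.

import torch

def sparse_export(m: torch.nn.Module, args: tuple) -> torch.export.ExportedProgram:
    # (1) Remember which arguments are sparse, and in which layout.
    layouts = [a.layout if a.layout != torch.strided else None for a in args]
    # (2) Densify, so the stock export machinery can trace the graph.
    dense_args = tuple(a.to_dense() if l is not None else a
                       for a, l in zip(args, layouts))
    prog = torch.export.export(m, dense_args)
    # (3) Put the sparse layouts back as annotations on the placeholder nodes,
    #     for the importer to turn into sparse !torch.vtensor types.
    placeholders = [n for n in prog.graph.nodes if n.op == "placeholder"]
    for node, layout in zip(placeholders, layouts):
        if layout is not None:
            node.meta["sparsity"] = layout
    return prog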

At present, we can take something like

class MatMulNet(torch.nn.Module):

    def __init__(self):
        super(MatMulNet, self).__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)

m = export_and_import(MatMulNet(), A_coo, B_dense)

and when invoked with a sparse x, convert this into the following SpMM representation

#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
module {
  func.func @main(%arg0: !torch.vtensor<[64,64],f32,#sparse>, 
                  %arg1: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
    %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[64,64],f32,#sparse>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
    return %0 : !torch.vtensor<[64,64],f32>
  }
}

which can now also be further lowered to linalg:

linalg.matmul ins(... : tensor<64x64xf32, #coo>,
                        tensor<64x64xf32>)

I am super excited about this progress!


Excited to see you spreading the sparse love to other ecosystems!


Minor update: we now have sufficient machinery in torch-mlir to run a simple PyTorch model “end-to-end” with sparse tensors as input. Take, for example, the following code:

class MatMulNet(torch.nn.Module):

    def __init__(self):
        super(MatMulNet, self).__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)

Then we get the same results when running with the normal PyTorch engine vs. torch-mlir execution (operating on the underlying numpy arrays):

    net = MatMulNet()
    a = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 2, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 3],
                      [0, 0, 0, 0, 0, 0, 0, 4],
                      [0, 0, 0, 0, 0, 0, 0, 5]],dtype=torch.float32)
    sparse_input = a.to_sparse_csr()
    res0 = net(a, a)
    res1 = net(sparse_input, a)
    res2 = sparse_jit(net, sparse_input, a)   # uses TORCH-MLIR +sparse

all yield the following numpy data

[[ 1.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  4.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0. 15.]
 [ 0.  0.  0.  0.  0.  0.  0. 20.]
 [ 0.  0.  0.  0.  0.  0.  0. 25.]]
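
A simple hedged sanity check over the three results (it assumes, as described above, that sparse_jit hands back a plain numpy array while the eager paths return torch tensors):

import numpy as np
import torch

def as_numpy(t):
    t = torch.as_tensor(t)
    # Densify first in case an eager result came back in a sparse layout.
    return (t.to_dense() if t.layout != torch.strided else t).numpy()

assert np.allclose(as_numpy(res0), as_numpy(res1))
assert np.allclose(as_numpy(res0), as_numpy(res2))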