[RFC] Sparse tensor support in torch-mlir

Sparse tensors can offer qualitative performance differences to some algorithms, including many in machine learning. MLIR has growing support for them, and there are at least a couple of competing versions of sparse support for PyTorch. I’d like to pitch a path towards sparse tensor support in torch-mlir.

Sparse tensors at the PyTorch level

PyTorch itself provides a beta API for sparse tensors which supports a handful of linear algebra operations. Currently, that API is locked to two data layouts: COO (for arbitrary-dimensional sparse tensors) and CSR (for 2-dimensional ones). Internally, these are partly implemented via the layout attribute and partly via subclassing the tensor implementation (COO, CSR). I’ll reiterate that this is a beta API and subject to change.
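To make the two layouts concrete, here is a pure-Python sketch (not the torch API) that stores one small matrix both ways. The helpers `to_coo` and `coo_to_csr` are hypothetical; only the component names mirror what PyTorch exposes (`indices()`/`values()` for COO, `crow_indices()`/`col_indices()`/`values()` for CSR).

```python
from typing import List, Tuple

Matrix = List[List[int]]

# A small 3x4 matrix with three nonzeros.
dense: Matrix = [
    [0, 2, 0, 0],
    [1, 0, 0, 3],
    [0, 0, 0, 0],
]


def to_coo(m: Matrix) -> Tuple[List[List[int]], List[int]]:
    """Return (indices, values), where indices is 2 x nnz: [row indices, col indices]."""
    rows, cols, vals = [], [], []
    for r, row in enumerate(m):
        for c, x in enumerate(row):
            if x != 0:
                rows.append(r)
                cols.append(c)
                vals.append(x)
    return [rows, cols], vals


def coo_to_csr(indices: List[List[int]], values: List[int], num_rows: int):
    """Compress row-sorted COO row indices into per-row offsets (crow_indices)."""
    rows, cols = indices
    crow = [0] * (num_rows + 1)
    for r in rows:
        crow[r + 1] += 1          # count nonzeros per row
    for r in range(num_rows):
        crow[r + 1] += crow[r]    # prefix sum turns counts into offsets
    return crow, list(cols), list(values)


indices, values = to_coo(dense)
crow_indices, col_indices, csr_values = coo_to_csr(indices, values, len(dense))
print(indices, values)                        # [[0, 1, 1], [1, 0, 3]] [2, 1, 3]
print(crow_indices, col_indices, csr_values)  # [0, 1, 3, 3] [1, 0, 3] [2, 1, 3]
```

The conversion assumes row-sorted COO input, which the row-major scan in `to_coo` guarantees.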

Alternatively, sparse tensors can be implemented at the user level, for instance by holding dense native tensors for indices and values in a user-defined class and then supplying custom operators for sparse math. torch_sparse is one such example. While not necessarily that relevant, I’ll point out that this was also more or less the approach that the COMET team took in their research, except at the MLIR level.
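As a rough illustration of the user-level approach, here is a hypothetical class that keeps indices and values in plain arrays and supplies its own sparse-times-dense operator. It is not modeled on torch_sparse's actual API; it only shows the shape of the idea.

```python
from dataclasses import dataclass
from typing import List, Tuple

Matrix = List[List[float]]


@dataclass
class UserCOO:
    """Hypothetical user-level sparse matrix backed by plain index/value arrays."""
    rows: List[int]
    cols: List[int]
    values: List[float]
    shape: Tuple[int, int]

    def matmul_dense(self, other: Matrix) -> Matrix:
        """Custom sparse x dense operator that only visits stored nonzeros."""
        n_rows, n_inner = self.shape
        if len(other) != n_inner:
            raise ValueError("inner dimensions must match")
        n_cols = len(other[0])
        out = [[0.0] * n_cols for _ in range(n_rows)]
        for r, k, v in zip(self.rows, self.cols, self.values):
            for j in range(n_cols):
                out[r][j] += v * other[k][j]
        return out


a = UserCOO(rows=[0, 1], cols=[1, 0], values=[2.0, 3.0], shape=(2, 2))  # [[0, 2], [3, 0]]
b = [[1.0, 2.0], [3.0, 4.0]]
print(a.matmul_dense(b))  # [[6.0, 8.0], [3.0, 6.0]]
```

To the framework, everything here is just ordinary dense arrays plus a custom op, which is why this route needs no changes to the tensor type itself.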

Finally, it’s worth noting in passing that there are sometimes direct Python interfaces to sparse vendor libraries, but these are a bit outside of scope here, mostly because they sidestep everything we’re building (and thus neither share the same concerns nor receive many of the benefits of the compiler work in the MLIR ecosystem).

Sparse tensors at the MLIR level

The sparse_tensor dialect is MLIR’s answer to first-class support for sparse operations. It’s important to highlight that the dialect itself doesn’t actually provide a distinct sparse tensor type: sparse tensors are ordinary MLIR tensor types carrying a sparse encoding attribute. The dialect also provides conversion and utility operations, and a lot of work also exists in related dialects (e.g., linalg) to support sparse-attributed tensors properly.

Bridging the gap

I’d like to propose opening up sparse tensor support for torch-mlir via extending the existing torch.tensor type with an attribute analogous to that of the sparse_tensor dialect. Concretely, this would mean AnyTorchTensorType would be extended with an optional encoding attribute (better names happily accepted). Normal tensors would be able to omit it, leaving all existing behavior unchanged. For conversion to the torch dialect, the importer would need some work to expand ivalue translation to include sparse tensors, and I believe there will need to be some work to extend tensor literal construction. Lowering to linalg+tensors would involve attaching an appropriate sparse tensor attribute and then constructing tensors with it. Because of MLIR’s existing transparent support, there shouldn’t be a lot of changes to math operations. There may need to be explicit conversion calls inserted. Once sparse layouts are supported, we can expand the set of torch-mlir-supported ops to include sparse-specific ones using the normal path. For backends that do not support sparse tensors, we would probably want to simply reject their existence, either via a legalization pass or just by throwing an exception during lowering.
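A toy model of the proposal, assuming a hypothetical `TorchTensorType` record in place of the real C++ type and a made-up `#CSR` encoding string: the encoding field defaults to `None`, so dense types are constructed and compared exactly as before, and a legalization check rejects encoded tensors for backends that have not opted in.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class TorchTensorType:
    """Hypothetical stand-in for the torch dialect tensor type."""
    sizes: Tuple[int, ...]
    dtype: str
    # New optional field: None means an ordinary dense tensor, so existing
    # dense types are constructed and compared exactly as before.
    encoding: Optional[str] = None


class UnsupportedEncodingError(Exception):
    """Raised when an encoded tensor reaches a backend without sparse support."""


def legalize_for_backend(types: List[TorchTensorType], backend_supports_sparse: bool) -> None:
    """Reject encoded tensors for backends that have not opted in to sparse support."""
    if backend_supports_sparse:
        return
    for t in types:
        if t.encoding is not None:
            raise UnsupportedEncodingError(f"backend cannot lower {t}")


dense = TorchTensorType((2, 2), "f32")
csr = TorchTensorType((2, 2), "f32", encoding="#CSR")  # made-up encoding name

legalize_for_backend([dense], backend_supports_sparse=False)  # dense: no change in behavior
try:
    legalize_for_backend([dense, csr], backend_supports_sparse=False)
except UnsupportedEncodingError as err:
    print(err)
```

The design point is that only code paths which see a non-`None` encoding ever behave differently.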

I have some pieces of this approach under development in a personal branch. Because the proposed approach is (or should be) transparent to existing dense tensors, I’d prefer that some of this work happens in the main branch rather than a fork. This would allow (a) refactoring of some existing code that makes overly strict assumptions or relies on dense-only tensor pieces, and (b) clear, continuous evidence that the additional encoding attributes are indeed transparent to existing dense code.

Alternatives: implied layout via analysis

One other mechanism we could conceivably implement is just relying on conversion routines and dataflow analysis to reconstruct the implied layout of a tensor. This has the advantage of a smaller footprint in the PyTorch-torch dialect conversion surface, since no changes to the tensor type or TorchScript translation would be needed. On the downside, we’d need a new analysis pass, and it would need to be integrated into each dialect lowering conversion.

I think the biggest downside to this approach is simply that it needlessly tosses away information only to recreate it later. PyTorch’s approach to sparse tensors is through attributed types. MLIR’s approach to sparse tensors is through attributed types. It seems unwieldy and labor-intensive to have torch-mlir’s approach be to chuck the attributes out and attempt to recreate them from scratch.

Alternatives: parallel datatype

We could always create an entirely separate type for sparse tensors. This would give even more isolation from the existing code, but at the cost of an enormous amount of duplicated work. For instance, sparse tensors (and other “tensor-likes” discussed later) need all the same shape inference machinery that normal tensors do. The data layout doesn’t change the mathematical object. This would mean that whenever a change or fix goes in for dense tensors, a second parallel fix would need to be applied to the sparse type. Oh, and all operations would now need to be explicitly multiply-defined on both types. This approach seems infeasible to me, but I’m willing to be corrected.

Open questions

Exact vs. analogous attribute use

The sparse_tensor dialect defines its own semantics through its encoding attribute (#sparse_tensor.encoding), which is attached to the builtin MLIR tensor type. PyTorch has a slightly different set of semantics (and may change further). This isn’t a show-stopper, but it does raise the question of whether we should directly use the sparse encoding attribute at the torch-mlir level. I’ve not seen any concrete reason not to, but having our own, even if it just directly wraps the underlying one, feels like the defensive approach. That way, if something changes later, we don’t have to back out a bunch of assumptions about identity. The downside is obviously some additional code to write and support.

Interaction with other layout options

PyTorch has support for regular sparse tensors via strided layout. MLIR’s sparse_tensor dialect doesn’t, but there’s been conversation on it. There’s no reason the proposed attribute couldn’t be extended to handle it, but it would require some thought as to what it would be lowered to.

A discussion on “tensor-likes”

This is a bit off-topic, but since it hits on some of the topics in this proposal, I thought I’d bring it up here anyways.

PyTorch tensors are not particularly extensible. There is no current or planned support for subclassing and inheritance in TorchScript, and tensors are given special treatment and handling in much of PyTorch. For some vendors and users, it’s desirable to create an object that behaves a lot like a tensor but, for the reasons above, isn’t a tensor. I think torch-mlir is in a unique position to enable that.

PyTorch provides a mechanism to extend its functionality via custom classes. These end up being represented as opaque managed pointers, which we have some support for in torch-mlir. Nothing prevents us from generating real tensors from these objects. That would open up a whole new path for developers to achieve tensor pseudo-inheritance by crafting an opaque custom class and providing a torch translation.

For my personal agenda, the use case is fairly obvious: providing a tensor-like that supports arbitrary TACO-style sparse encoding would expose the full power of the MLIR sparse tensor infrastructure to PyTorch users (instead of just the PyTorch beta COO and CSR layouts).

But this might also provide something like type extensibility for torch-mlir for platform-specific tensor representations. For instance, some hardware provides elaborate compression and storage layouts, which would now have a less-terrifying path to support.

Hi Bob,

I personally have not been convinced that the “add a tensor ‘encoding’” just transparently works at the levels of abstraction that Torch-MLIR is responsible for. There are lots of transformations that take an op and create multiple intermediate SSA values, each needing a type. My current understanding is that there is no general way to know what is the right sparsity encoding for intermediate values created by such transformations.

For example, when decomposing a Softmax, we need to compute the types for the intermediate values created by the decomposition. I don’t see any way that we can generally adjust that code to “transparently” work for any tensor encoding.

For that reason, I would suggest we explore the “parallel world with a custom data type and ops” – this is similar in spirit to how the QPyTorch support is trending – we let users bring their custom ops and data types, and have Torch-MLIR be extensible enough to support that (we definitely need to improve this UX – it’s on my todo list, and this is a valuable use case).

PyTorch’s experimental support for this is a limited set of supported ops and encoding combinations – it’s not really Torch-MLIR’s responsibility to “innovate” in the core concepts and abstractions for sparse tensors. If upstream is building a specific thing, we should not aim to build a generic thing ahead of their efforts.

These are conceptually different though. For MLIR, the encoding is basically a key that triggers a particular codegen algorithm. For PyTorch, because of eager execution, tensors have to carry this information so that eager kernels can know what to do. Though superficially similar, the reasons are very different, and I would not use this coincidence as a guide for what everything in between those two abstraction layers should look like.

Not sure if this is what you want, but there actually is a way to subclass Tensor: albanD/subclass_zoo on GitHub.

I appreciate the feedback, Sean.

I personally have not been convinced that the “add a tensor ‘encoding’” just transparently works at the levels of abstraction that Torch-MLIR is responsible for. … My current understanding is that there is no general way to know what is the right sparsity encoding for intermediate values created by such transformations. … I don’t see any way that we can generally adjust that code to “transparently” work for any tensor encoding.

Let me rephrase a bit: adding a tensor encoding should work transparently for dense tensors. I.e., when adding an encoding attribute, we wouldn’t need to go through and immediately rewrite everything in torch-mlir just because the attribute exists. (I’m lying slightly—we would need to add a mechanism that makes it illegal to pass non-null-encoded tensors to components that have not explicitly registered support—but we only need to do it once and it’s orders of magnitude easier than a breaking change to the torch.tensor type.)

I agree with you that it’s not straightforward to have a general rule for transforming sparse encoding attributes. But this is true regardless of the approach we take. In the example above, I’m not sure having a parallel sparse type would improve anything. All it would do is force us to reimplement that same op for a separate type and then go solve the same problem anyways.

Let’s take your softmax example. With an encoded tensor, there are two possible default cases here: (1) prohibit it using the mechanism I mentioned above, which basically just means that passing an encoded tensor to any op not explicitly marked for encoded tensor support would be illegal; or (2) allow ops to fall back to a default encoding. I don’t think there’s a ruleset for (2) that works in all cases, which makes me squeamish about trying it. However, there’s definitely a small set of general rules that encompass a huge number of ops. This is very similar to shape inference. And that leads us back to (1): for each op that you want to support under encoded tensors, you need a bit of extra information to decide how the encodings are transformed. Now let’s assume that we have a custom type instead. First, we can no longer use the implementation above at all, so we’ll need a new one that behaves exactly the same. In addition, we’ll still need the extra mechanism for manipulating the encoding, which is what we were going to do in the non-custom case anyway. The end result is that both approaches need the encoding transfer mechanism, but the custom type also needs a copy of every existing op we want to support.
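Option (1) plus per-op rules can be sketched as a registry of encoding transfer functions, analogous to per-op shape transfer functions. Everything here is hypothetical (the registry, the rule functions, and the string encodings); the op names are only illustrative aten overloads. Dense inputs bypass the registry entirely, and an encoded operand reaching an op without a rule is an error rather than a silent guess.

```python
from typing import Callable, Dict, List, Optional

Encoding = Optional[str]  # None means dense
TransferFn = Callable[[List[Encoding]], Encoding]

# Hypothetical registry: op name -> rule computing the result encoding,
# much like a per-op shape transfer function.
ENCODING_RULES: Dict[str, TransferFn] = {}


def register_encoding_rule(op_name: str) -> Callable[[TransferFn], TransferFn]:
    def wrap(fn: TransferFn) -> TransferFn:
        ENCODING_RULES[op_name] = fn
        return fn
    return wrap


@register_encoding_rule("aten.mul.Tensor")
def _mul_rule(encodings: List[Encoding]) -> Encoding:
    # A zero in a sparse operand forces a zero in the product, so keep
    # the first sparse operand's encoding.
    return next((e for e in encodings if e is not None), None)


@register_encoding_rule("aten.tanh")
def _tanh_rule(encodings: List[Encoding]) -> Encoding:
    # tanh(0) == 0, so the input's sparsity carries over.
    return encodings[0]


def infer_result_encoding(op_name: str, encodings: List[Encoding]) -> Encoding:
    """Option (1): encoded operands are only legal for ops with a registered rule."""
    if all(e is None for e in encodings):
        return None  # dense path: untouched, no rule needed
    rule = ENCODING_RULES.get(op_name)
    if rule is None:
        raise NotImplementedError(f"{op_name} has no encoding rule; encoded operands are illegal")
    return rule(encodings)


print(infer_result_encoding("aten.softmax.int", [None]))        # dense softmax still works
print(infer_result_encoding("aten.mul.Tensor", ["#CSR", None]))  # sparse * dense stays sparse
```

Under this scheme, a sparse softmax fails loudly until someone registers a rule for it, while the dense softmax decomposition never changes.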

For that reason, I would suggest we explore the “parallel world with a custom data type and ops” – this is similar in spirit to how the QPyTorch support is trending…

As I understand it, QPyTorch is a bit unique since at the PyTorch level, it’s entirely emulated, so the torch tensor type can be reused which avoids having to re-support ops. The custom data type that QPyTorch needs would not be a replacement for the torch tensor type. I don’t think the same approach works with tensor-like objects.

we let users bring their custom ops and data types, and have Torch-MLIR be extensible enough to support that … it’s not really Torch-MLIR’s responsibility to “innovate” in the core concepts and abstractions for sparse tensors.

I hear what you’re saying. I think my view here is that requiring a bespoke type is a non-starter w.r.t. extensibility. For custom ops, the surface area that needs to be touched is small—you only need to handle what you’re putting in. For custom types, the surface area is proportional to the number of supported ops and it creates a maintenance headache because you have to duplicate anything that is fixed or changed in those ops. The attribute approach I’m pitching here is an alternative where the surface area is still proportional to the number of ops, but without the overlap and duplication—you only need to add what’s different. In this case, just the transitive effects on encoding. I’m not taking a stance yet on whether this is a general approach to type extensibility, but I think it makes sense in this case. (This is definitely a tangent into that UX domain you’re talking about).

These are conceptually different though…

I don’t necessarily disagree, but I’ll point out that there’s another similarity here that I feel is important: the correctness of the op semantics isn’t changed by layout (insert a few caveats here…). This separation of function means that layering encoding semantics on top of the existing type is an attractive starting point. I’m not convinced it’s entirely coincidental that both are implemented this way.

(subclassing)

Yep, the __torch_dispatch__ method is viable for subclassing at the PyTorch level, but it’s incompatible with TorchScript. Inheritance is explicitly unsupported (ref). You can sort of do it for nn.Modules, but not for tensors, at least as far as I understand it.

Thanks for pushing on this – this is an important feature and I want to get to the bottom of what we can provide to users that need it.

To move the discussion forward, I think it would be concretely useful to present some IR dumps of the scripted/traced IR of a few representative programs using sparsity, and discuss what we want to transform them into for the backend layer.

I want us to think about sparse support from a customer-centric perspective. We know that users are already using custom libraries or a restricted set of operators to write their sparse PyTorch programs, so I think for us the MVP is to meet them where they are and show value on the real programs there. I really don’t want us to be in the game of “innovating” and designing a bespoke Torch-MLIR sparse representation that aims for a level of generality beyond what upstream PyTorch supports. (take the example of !torch.LinearParams for the quantization case - not a great representation, but we import it and are letting users lower it – we aren’t in the business of “innovating” in the quantization design space). Improvements to generality should probably be pushed upstream or built downstream from Torch-MLIR.

To give some context on my objections, making sure that the compiler is stable and predictable is higher on my list of priorities than code duplication. My current understanding is that the audit trail is just too long for making sure that the compiler doesn’t drop the encoding or randomly crash or do the wrong thing in the presence of the encoding. I really don’t know what the abstractions would look like that prevent the encoding from sneaking into places that it shouldn’t, or making sure that it can be handled by the places it gets to. Any prior art on this would be useful. I.e. codebases with large numbers of patterns and transformations that preserve sparsity “encoding” attributes throughout.

Thanks for the RFC!

I am thrilled to see many more projects emerging that treat sparsity as a property! Also, this is a timely discussion, since my team is currently exploring something very similar, namely sparse tensor support in JAX using MLIR as the backend for the actual “sparsification”.

When the MLIR sparse compiler project started, we had a lot of discussions on whether sparse tensor types should be a new, completely separate type, or an encoding on the existing (dense) tensor types. Both approaches have obvious advantages and obvious drawbacks. We decided on the latter, and I can tell from experience that this turned out to be a very good decision. Once we had the sparse compiler path working from Linalg on dense/sparse tensor types, adding a new front-end, such as JAX, was as simple as putting the sparse encoding on the tensor types in the generated IR and relying on the already implemented lowering and rewriting from higher IR to Linalg while preserving these types. In most cases, almost no further code modifications were necessary. Had we gone the path of adding a completely new type, all these “dense” passes would have had to be duplicated just to deal with new sparse types.

Also note that although sparse compilation is in general moving from specific DSLs like tensor index expressions towards more general array languages (see e.g. this paper), we should still differentiate between what I call “closed world” sparse compilation, where only a restricted set of operations is allowed to be sparsified when operand types are marked sparse, and the much more ambitious “open world” sparse compilation, where sparse types have truly become first-class citizens. Our current exploration focuses on the “closed world” approach, where a function can only be “sparsified” for a restricted set of primitives (though considerably more than the standard + and * usually considered). This more restricted approach avoids a lot of the problems Sean is alluding to above, since the compiler can guarantee that it knows what to do for all operations when operands become sparse.

I think there are a few differences for Torch-MLIR as compared to JAX – technically speaking, JAX is ahead of PyTorch w.r.t. a lean, mean, orthogonalized set of primitives that lower through a very direct path. In Torch-MLIR we have a much more complex set of transformations we have to do to get things into the form that backends want. I expect that as PyTorch evolves, we will see an emerging frontend story with fewer layers and complexities, where sparsity can be treated much more consistently.

To give an example of my concerns with the “audit trail” problem, consider this MHLO transformation picked somewhat at random: group_reduction_dimensions.cc (tensorflow/mlir-hlo on GitHub, master branch).
JAX lowers through MHLO, and I’m pretty sure that sparsity through the ‘encoding’ on the tensor type requires special handling in this pass, but I don’t see any special handling. We have a LOT of transformations of that kind in Torch-MLIR. I’m curious if there has been any effort to audit all the different MHLO and linalg transformations to ensure that the sparse tensor encoding is respected (or compilation safely terminates).

I suppose my “closed world” view also extends beyond restricting the set of “sparsifiable” primitives to restricting the paths taken through the compiler code: start with a rather small set of compiler settings, then gradually add more and more optimization passes while continuously performing proper audits of potentially new paths. In addition, since the IR of many operations is strongly typed, we have an extra safeguard: an encoding cannot easily be dropped without anyone noticing. For places where this safety does not apply, the compiler automatically inserts conversions between dense and sparse formats. The latter guarantees correct semantics, but may perform very poorly by materializing sparse tensors into their dense counterparts, something that ultimately should be discovered by proper performance analysis.
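The automatic conversion insertion described above can be sketched on a toy IR. The `Op` record and the `my.*` op names are hypothetical; `sparse_tensor.convert` is the name of the real MLIR op, but here it is only a string in a list. Each operand whose current encoding differs from what its consumer expects gets an explicit convert placed in front of the consumer.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Op:
    """Hypothetical toy IR op: one result, per-operand expected encodings."""
    name: str
    operands: List[str]
    result: str
    expects: List[Optional[str]]  # required encoding of each operand (None = dense)
    result_encoding: Optional[str] = None


def insert_conversions(ops: List[Op], arg_encodings: Dict[str, Optional[str]]) -> List[Op]:
    """Insert a convert op wherever a value's encoding differs from what its user expects."""
    encoding_of = dict(arg_encodings)
    rewritten: List[Op] = []
    counter = 0
    for op in ops:
        assert len(op.operands) == len(op.expects), "one expected encoding per operand"
        new_operands = []
        for value, wanted in zip(op.operands, op.expects):
            if encoding_of[value] != wanted:
                converted = f"%cvt{counter}"
                counter += 1
                rewritten.append(Op("sparse_tensor.convert", [value], converted,
                                    [encoding_of[value]], wanted))
                encoding_of[converted] = wanted
                value = converted
            new_operands.append(value)
        rewritten.append(Op(op.name, new_operands, op.result, op.expects, op.result_encoding))
        encoding_of[op.result] = op.result_encoding
    return rewritten


ops = [
    Op("my.dense_only_op", ["%a"], "%b", [None]),
    Op("my.csr_op", ["%b"], "%c", ["#CSR"], "#CSR"),
]
for op in insert_conversions(ops, {"%a": "#CSR"}):
    print(op.name, op.operands, "->", op.result)
```

Both inserted conversions keep the program correct, but the first one densifies `%a`, which is exactly the kind of hidden cost that performance analysis has to catch.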

Lastly, allow me a slightly controversial statement, that certainly does not apply to compiler design in general. But since introducing the concept that sparsity is a property, not a tedious implementation detail, is still a relatively new concept for MLIR, we want to avoid perfection becoming the enemy of progress, and sometimes discover things “as we go”. Of course, we can afford this luxury only because sparse compilation is not in production yet, but a research project with a very enthusiastic user base that is willing to help out when stuff breaks. But hopefully, we will reach a point soon where our design has sufficiently matured to perform all the exhaustive audits you would like to see happen before going into full production.

I currently see two distinct use cases for “sparsity as a property of a tensor”:

  1. For ABI-visible tensors, it specifies the calling convention that the runtime will use for accepting/returning the tensor. This has mandatory-for-correctness implications on the compiler and runtime.

  2. For non-ABI-visible tensors, it functions as something like an “optimization hint” in typical languages (like an “inline hint” or the register keyword in C, but implying a much more complex set of transformations). At some low-enough layer in the stack, presumably those “hints” become the final decisions of the compiler though.

Just as with inline hints in traditional compilers, all user-specified sparsity in case 2 is subject to being overridden by internal compiler heuristics or ignored for other reasons. For example, consider this IR:

%0 = my_tensor_dialect.add %arg0, %arg1
: tensor<?x?xf32, #CSR>, tensor<?x?xf32, #CSR> -> tensor<?x?xf32, #CSR>
%1 = my_tensor_dialect.tanh %0
: tensor<?x?xf32, #CSR> -> tensor<?x?xf32, #CSC>
// No other uses of %0
// %1 is used by an op that is much more efficient in #CSR

Obviously, we will want to fuse these two operations together, which implies somehow “ignoring” the encoding on %0. Similarly, if %1 is consumed by a later (unfused) operation for which #CSR is a better encoding, then we would ideally not even bother converting to #CSC for %1 when %arg0 and %arg1 are already in #CSR.

Anyhow, from a user’s perspective, it is extremely tedious in any real codebase to have to annotate the sparsity encoding on every single tensor in the program, so there will always be some degree to which the compiler needs to decide on sparsity encodings internally. Imagine the tedium if users were responsible for inline decisions or register allocation for every value in the program.


[edit: this may not be a big deal with the upstream constructors, but in Torch-MLIR we have constructors like BaseTensorType::getWithSizesAndDtype that presumably would propagate encoding]

For me, the biggest blocker on the “sparsity as a property” approach is that, unlike inline hints or the register keyword, where dropping them or accidentally setting or not setting them doesn’t affect program validity, setting the wrong sparsity ‘encoding’ on a tensor can result in a user-visible compiler crash or miscompile. Last I checked, type verifiers in MLIR run inside an assert, so a failing verifier will either crash the compiler or, in no-asserts builds, produce miscompiles or even “garbage” compiler results. For example, setting a 2D encoding on a 4D tensor could create out-of-bounds accesses in a pass that tries to read the 4th dimension of the encoding, causing the compiler itself to crash or worse.

To deal with this in the type-based approach, we could have the constructors run verification and “ignore” invalid encodings. But I still don’t see much advantage to that over using a special op to carry the user hints. If the lower levels of the compiler stack require the encoding on the type, we could add it as a transformation before passing off to those layers.
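A minimal sketch of such a checked constructor, assuming a toy encoding that lists one level type per dimension (in the spirit of the sparse_tensor dialect's per-dimension level types, but not its actual API). An encoding whose rank does not match the tensor is dropped with a warning, so the 2D-encoding-on-a-4D-tensor case never reaches a later pass.

```python
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

LEVEL_TYPES = ("dense", "compressed")


@dataclass(frozen=True)
class Encoding:
    """Toy sparse encoding: one level type per dimension."""
    level_types: Tuple[str, ...]


@dataclass(frozen=True)
class TensorType:
    shape: Tuple[int, ...]
    encoding: Optional[Encoding] = None


def make_tensor_type(shape: Tuple[int, ...], encoding: Optional[Encoding] = None) -> TensorType:
    """Hypothetical checked constructor: verify eagerly instead of relying on an
    assert-only verifier, and drop (ignore) an encoding that does not fit."""
    if encoding is not None:
        valid = (len(encoding.level_types) == len(shape)
                 and all(lt in LEVEL_TYPES for lt in encoding.level_types))
        if not valid:
            warnings.warn(f"ignoring {encoding} on a rank-{len(shape)} tensor")
            encoding = None
    return TensorType(tuple(shape), encoding)


CSR = Encoding(("dense", "compressed"))
print(make_tensor_type((4, 4), CSR))        # keeps the encoding
print(make_tensor_type((2, 3, 4, 5), CSR))  # 2D encoding on a 4D tensor: dropped
```

Silently dropping is itself a policy choice; a stricter variant would raise a user-facing error instead of warning.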


I think it would be quite useful to see IR dumps from real programs written in PyTorch using sparse computations today to see what kind of abstractions we are dealing with from the frontend here. @aartbik do you have links to sparse JAX examples / IR we could stare at too?

I just did a small example, and it looks like the encoding isn’t even propagated into the IR as a property on the tensor type:

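import torch

# COO indices given as (row, col) pairs; transposed to the 2 x nnz layout below.
indices = torch.tensor([[0, 1], [1, 0], [1, 1]])
values = torch.tensor([3, 4, 5], dtype=torch.float32)


def f(i, v):
    # Build a 2x2 sparse COO matrix and multiply it by itself.
    m = torch.sparse_coo_tensor(i, v, (2, 2))
    return torch.mm(m, m)


# Print the TorchScript graphs produced by scripting and by tracing.
print(torch.jit.script(f).graph)
print(torch.jit.trace(f, (indices.t(), values)).graph)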

results in

# The scripted graph
graph(%i.1 : Tensor,
      %v.1 : Tensor):
  %12 : NoneType = prim::Constant()
  %4 : int = prim::Constant[value=2]()
  %5 : int = prim::Constant[value=3]()
  %11 : int[] = prim::ListConstruct(%4, %5)
  %m.1 : Tensor = aten::sparse_coo_tensor(%i.1, %v.1, %11, %12, %12, %12, %12)
  %19 : Tensor = aten::mm(%m.1, %m.1)
  return (%19)

# The traced graph
graph(%i.1 : Long(2, 3, strides=[1, 2], requires_grad=0, device=cpu),
      %v.1 : Float(3, strides=[1], requires_grad=0, device=cpu)):
.......
  %m : FloatTensor(requires_grad=0, device=cpu) = aten::sparse_coo_tensor(%i, %v, %16, %17, %18, %19, %20)
  %22 : FloatTensor(requires_grad=0, device=cpu) = aten::mm(%m, %m)
.......

Of course! I can sketch some of the progress we are making.

Consider the following JAX definition of Sampled Dense-Dense Matrix Multiplication (SDDMM), a primitive that occurs often in ML.

  def foo(s, x, y):
      return s * (x @ y)

The interesting case occurs when “s” is a sparse sampling matrix and “x” and “y” are dense matrices; let’s keep them all 10x10 for simplicity:

  result = foo(s_sparse, x_dense, y_dense)

With the MLIR sparse compiler as backend for JAX, we get the following IR at runtime:

 func.func public @foo(%arg0: tensor<10x10xf64, #SparseMatrix>,
                       %arg1: tensor<10x10xf64>,
                       %arg2: tensor<10x10xf64>)
                           -> tensor<10x10xf64, #SparseMatrix> {
    %0 = mhlo.dot_general(%arg1, %arg2)
       : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
    %1 = mhlo.eltwise_multiply(%arg0, %0)
       : (tensor<10x10xf64, #SparseMatrix>, tensor<10x10xf64>)
       -> tensor<10x10xf64, #SparseMatrix>
    return %1 : tensor<10x10xf64, #SparseMatrix>
  }

Note the well-known performance trap in the resulting IR, observed in Prof. Kjolstad’s PhD thesis. If we were to first compute the dense matrix multiplication and then perform the sampling, we would get the wrong asymptotic complexity. The two kernels always need to be fused to ensure that only the dot products that contribute to the final result are computed!
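The asymptotic gap is easy to see by counting multiplies in a small pure-Python sketch. The sampling matrix `s` is stored as a hypothetical coordinate-to-value map; `sddmm_unfused` computes all of `x @ y` first, while `sddmm_fused` computes a dot product only where `s` has a nonzero.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]
Sampling = Dict[Tuple[int, int], float]  # (row, col) -> nonzero value of s


def sddmm_unfused(s: Sampling, x: Matrix, y: Matrix) -> Tuple[Sampling, int]:
    """Compute all of x @ y, then sample it: O(M*N*K) multiplies."""
    m, k, n = len(x), len(y), len(y[0])
    mults = 0
    prod = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for p in range(k):
                prod[i][j] += x[i][p] * y[p][j]
                mults += 1
    out = {(i, j): v * prod[i][j] for (i, j), v in s.items()}
    return out, mults + len(s)


def sddmm_fused(s: Sampling, x: Matrix, y: Matrix) -> Tuple[Sampling, int]:
    """Compute a dot product only where s is nonzero: O(nnz(s)*K) multiplies."""
    k = len(y)
    mults = 0
    out: Sampling = {}
    for (i, j), v in s.items():
        acc = 0.0
        for p in range(k):
            acc += x[i][p] * y[p][j]
            mults += 1
        out[(i, j)] = v * acc
        mults += 1
    return out, mults


n = 10
x = [[1.0] * n for _ in range(n)]
y = [[1.0] * n for _ in range(n)]
s = {(0, 0): 1.0, (3, 7): 2.0, (9, 9): 0.5}
ref, unfused_mults = sddmm_unfused(s, x, y)
res, fused_mults = sddmm_fused(s, x, y)
print(unfused_mults, fused_mults)  # 1003 33
```

Both produce the same result, but with 3 nonzeros in a 10x10 sampling matrix the unfused version performs 1003 multiplies against 33 for the fused one; the gap widens with size, since the former scales with M*N*K and the latter with nnz(s)*K.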

Luckily, MLIR does not fall into this trap. The mhlo dialect above is progressively lowered and optimized into the following Linalg IR:

 func.func public @foo(%arg0: tensor<10x10xf64, #SparseMatrix>,
                       %arg1: tensor<10x10xf64>,
                       %arg2: tensor<10x10xf64>) 
                           -> tensor<10x10xf64, #SparseMatrix> {
    %0 = bufferization.alloc_tensor()  ...
    %1 = linalg.generic  ins(%arg1, %arg2, %arg0 : ...) outs(%0 :...) {
    ^bb0(%arg3: f64, %arg4: f64, %arg5: f64, %arg6: f64):
      %2 = arith.mulf %arg3, %arg4 : f64
      %3 = arith.mulf %arg5, %2 : f64
      %4 = arith.addf %arg6, %3 : f64
      linalg.yield %4 : f64
    } -> tensor<10x10xf64, #SparseMatrix>
    return %1 : tensor<10x10xf64, #SparseMatrix>
  }

which is subsequently lowered to the proper sparse code that only performs the fused kernel for nonzero elements in the sparse sampling matrix, as explained here (and in previous postings on this forum). This composition of optimizations is really what sets sparse compilation apart from composing available library calls, as is done traditionally.

Thanks Aart.

If I understand correctly, in the JAX case, you specify that some or all of the inputs to a function are sparse. So what calculates the #SparseMatrix encoding on %1 in this op?

    %1 = mhlo.eltwise_multiply(%arg0, %0)
       : (tensor<10x10xf64, #SparseMatrix>, tensor<10x10xf64>)
       -> tensor<10x10xf64, #SparseMatrix>

Is this propagation done by JAX itself? What if %arg1 is sparse instead – how does JAX know what is the right sparse encoding for %0? And what are the allowable sparse encodings here?

The sparse types are constructed by the (sparse) JAX to MLIR-MHLO “bridge” that builds the MLIR IR for a JAX expression. For now, this bridge uses some heuristics that seem “reasonable”, such as preserving sparsity over conjunctive element-wise operations, zero-preserving math operations, and a few more. Also, right now, we only support the “all-compressed, row-wise” TACO-flavored storage scheme, which avoids the need for inserting conversions. Once we have a fully functional system up and running, we want to explore real-life examples and improve our rule set (we have some ideas on how to do that already, but nothing worked out fully yet; at that point, we will have the advantage of being able to rapidly evaluate new ideas on a running system). This research-driven approach works well for my team, but perhaps it is less applicable to the direction you want to set for the torch-mlir project. In any case, let’s keep on discussing while both our teams are making progress. This is fun!
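Heuristics of that kind might look roughly like the following; this is a hypothetical rule table, not the bridge's actual implementation. Multiplication is conjunctive (the product is zero wherever either operand is zero), zero-preserving unary functions satisfy f(0) = 0, and everything else falls back to dense. Under these rules `1 + x` is dense even when `x` is sparse.

```python
from typing import Optional

SPARSE = "#SparseMatrix"  # single supported format, as in the IR above
ZERO_PRESERVING_UNARY = {"neg", "abs", "sin", "tanh", "sqrt"}  # all satisfy f(0) == 0


def result_encoding(op: str, *operands: Optional[str]) -> Optional[str]:
    """Hypothetical sparsity propagation rules; None means dense."""
    is_sparse = [e == SPARSE for e in operands]
    if op in ZERO_PRESERVING_UNARY:
        return operands[0]
    if op == "multiply":
        # Conjunctive: the product is zero wherever either operand is zero.
        return SPARSE if any(is_sparse) else None
    if op == "add":
        # Disjunctive: the sum is only guaranteed zero where both operands are zero.
        return SPARSE if all(is_sparse) else None
    # Anything else (e.g. exp, since exp(0) == 1) conservatively becomes dense.
    return None


print(result_encoding("multiply", SPARSE, None))  # s * (x @ y) stays sparse
print(result_encoding("add", None, SPARSE))       # 1 + x: dense
```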

Indeed! Your team’s research-driven approach here is quite beneficial for bootstrapping new frontend ideas in this space. I’m really interested to see your findings!

For now, I think Torch-MLIR’s approach is twofold:

  1. Identify the “seed” sparsity information which is the user’s intent, and have it survive to backends. E.g. if the user programming model is just to specify that a weight matrix is sparse COO (as opposed to a “sparse matmul op”), then we need to find a way to have that information survive into IR and reach backends. Pragmatically, we will delegate the responsibility of using this “seed” information to backends. For example, perhaps there can be shared infra at the linalg level to take a sparsity specification for a subset of Values and propagate it to all Values (possibly with autotuning; I think linalg is a very good place to do this as compared to the frontend).
  2. Work with upstream PyTorch to generalize and enhance the programming model for sparsity and participate in the sparse PyTorch ecosystem. In particular, I want our agenda in Torch-MLIR to be grounded in “real programs” that users are finding of real value today. Practically, this means finding useful sparse programs of at least 20-100 ops, since that size at least is necessary to exhibit the kinds of usability/programming-at-scale concerns that I mentioned above.
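The shared-infra idea in item 1 could be prototyped as a forward propagation over the dataflow graph, starting from the user's seeds. This is a hypothetical sketch: the graph representation, the `rule` callback, and the `#COO` string are assumptions, every function argument must appear in `seeds` (with `None` meaning dense), and the graph is assumed acyclic.

```python
from collections import deque
from typing import Callable, Dict, List, Optional, Tuple

Encoding = Optional[str]  # None means dense
# Toy acyclic dataflow graph: result value -> (op name, operand values).
Graph = Dict[str, Tuple[str, List[str]]]
Rule = Callable[[str, List[Encoding]], Encoding]


def propagate(graph: Graph, seeds: Dict[str, Encoding], rule: Rule) -> Dict[str, Encoding]:
    """Forward-propagate seed encodings to every value, visiting each op once
    all of its operands have an encoding."""
    enc: Dict[str, Encoding] = dict(seeds)
    users: Dict[str, List[str]] = {}
    for result, (_, operands) in graph.items():
        for v in operands:
            users.setdefault(v, []).append(result)
    ready = deque(seeds)
    while ready:
        v = ready.popleft()
        for r in users.get(v, []):
            op, operands = graph[r]
            if r not in enc and all(o in enc for o in operands):
                enc[r] = rule(op, [enc[o] for o in operands])
                ready.append(r)
    return enc


def simple_rule(op: str, encodings: List[Encoding]) -> Encoding:
    # mul keeps a sparse operand's zeros and relu(0) == 0; everything else becomes dense.
    if op in ("mul", "relu"):
        return next((e for e in encodings if e is not None), None)
    return None


graph: Graph = {
    "%h": ("mm", ["%x", "%w"]),
    "%m": ("mul", ["%h", "%mask"]),
    "%o": ("relu", ["%m"]),
}
seeds = {"%x": None, "%w": None, "%mask": "#COO"}  # user marks only %mask as sparse
print(propagate(graph, seeds, simple_rule))
```

Keeping the rule as a callback leaves room for swapping in a smarter (for example, autotuned) policy without touching the traversal.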

Just wanted to update this thread with some spontaneous lunchtime discussions with Aart the other day. One of Aart’s observations regarding sparse frontends was that the sparsity property was easy to accidentally drop without significant programmer care (for example, doing “1 + x” turns all zeros into nonzeros). The literature has also trended towards more specific sparse frontend annotations and “closed” programming models.

In this light, it is unclear whether the PyTorch approach, a patchwork of kernels with a patchwork of supported sparsity layouts, is a robust programming model for sparse. Bob, what are your thoughts?

Thanks for the patience in letting me get back to you. Been a busy few weeks.

I appreciate the discussion here, and I’ll try to hit some of the highlights you both have brought up.

When the MLIR sparse compiler project started, we had a lot of discussions on whether sparse tensor types should be a new, completely separate type, or an encoding on the existing (dense) tensor types…

Those discussions informed my pitch for sparse support in torch-mlir. There are differences, to be sure, but a lot of the same points still apply (Aart, Chris’ last point).

… we should differentiate between, what I call, “closed world” sparse compilation, where only a restricted set of operations is allowed to be sparsified when operand types are marked sparse and the much more ambitious “open world” sparse compilation, where sparse types truly have become first class citizens…

I think we’re all on the same page here, and I think it’s doubly true for PyTorch integration. PyTorch is very much a closed implementation right now, veering much closer to the ‘separate type’ approach than to an open world (cf. the internal APIs for storage access on sparse vs. dense tensors, which are entirely different).

(audit trail discussion)

I agree with you, Sean, but I think we also have to consider the flipside: let’s assume we implemented a separate type with separate semantics. For every transformation like this, instead of having “special handling”, we’d have to re-implement it completely. That’s a lot more work, a lot more maintenance, and honestly, a lot more chance for things to go wrong. I’m already having nightmares about PRs that fix something in one transform while the ‘sparse’ version gets overlooked, so that support diverges and nasty bugs lurk.

In an annotation world, as long as we reject annotations on transforms or operations that are not explicitly tagged as supporting them, then the worst that happens is the sparse path lacks functionality. The dense path carries on with life as usual. Moreover, we have clear error conditions to trace back in order to expand sparse support as necessary.
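The “reject unless explicitly tagged” policy can be sketched as a simple verifier check. This assumes a hypothetical allowlist `SPARSE_SUPPORTED`, represents a dense operand as `None` and a sparse one by its encoding name, and is not an existing torch-mlir or MLIR API.

```python
from typing import List, Optional

# Hypothetical allowlist: ops a human has explicitly tagged as sparse-capable.
SPARSE_SUPPORTED = {"matmul", "mul", "add"}


class UnsupportedSparseError(Exception):
    """Raised when a sparse-annotated operand reaches an untagged op."""


def check_op(op_name: str, operand_encodings: List[Optional[str]]) -> None:
    """Reject sparse annotations on ops not tagged as supporting them.

    `None` means a dense operand; any other value is a sparse encoding.
    Purely dense ops are never affected, so the dense path is unchanged.
    """
    has_sparse = any(enc is not None for enc in operand_encodings)
    if has_sparse and op_name not in SPARSE_SUPPORTED:
        raise UnsupportedSparseError(
            f"op '{op_name}' is not tagged as supporting sparse operands "
            f"(got encodings {operand_encodings})"
        )


check_op("softmax", [None])        # dense path: passes
check_op("matmul", ["CSR", None])  # tagged op: passes
try:
    check_op("softmax", ["COO"])   # untagged op with a sparse operand
except UnsupportedSparseError as err:
    print(err)                     # a clear, traceable error instead of silent densification
```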

I completely share your concern about annotations getting dropped, which is why I believe explicit sparse-support tagging is viable. Sure, it means that we’ll have to go through the effort of manually enabling support for a lot of things, but it means all of those choices are informed and have a human approval behind them.

(ABI-visible tensors vs. hints)

I don’t believe “hints” should be the default or the starting point. There are a lot of applications where sparsity must be enforced as a precondition for executing them in any kind of reasonable timeframe, and for those apps, hints are a non-starter. I’ll draw an analogy to auto-parallelizing compilers on vector processors. Even when the pragmas or compiler directives there were “hints”, there was a clearly defined process for knowing that they worked, and the developer-compiler-error feedback loop was sufficient to guarantee behavior before runtime. For PyTorch, I don’t see a good feedback loop that can work that way. There are way too many layers in between, and the general use case is intended to remove a lot of that.

While you mention that it’s ‘tedious to annotate everything’ (and I agree), the solution at the PyTorch level (and at the level of libraries which support sparsity in conjunction with PyTorch) is to have well-defined semantics for how those annotated types are transformed when put through ops (e.g., this).

I agree that silently dropping attributes is not a path we want to go down.

(JAX example)

SDDMM is a fantastic example, as it is close to one of the driving use cases that originally brought me down this path: neural message passing in GNNs. Avoiding the dense-first operation is non-negotiable for computing this. Meaning: the implementations for GNN libraries (PyG/DGL) must be able to ensure this does not happen.
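The difference between a “dense-first” and a sampled SDDMM for `s * (x @ y)`, with `s` sparse, can be shown in plain Python. This is a minimal sketch assuming `s` is stored as a dict from `(row, col)` to value and `x`, `y` are nested lists; the function names are hypothetical. Both return the same entries, but only the first materializes the full dense product.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]
Coo = Dict[Tuple[int, int], float]  # (row, col) -> value


def sddmm_dense_first(s: Coo, x: Matrix, y: Matrix) -> Coo:
    """Naive: compute the full dense x @ y, then mask by s. O(n*m*k) work."""
    n, k, m = len(x), len(y), len(y[0])
    xy = [[sum(x[i][p] * y[p][j] for p in range(k)) for j in range(m)]
          for i in range(n)]
    return {(i, j): v * xy[i][j] for (i, j), v in s.items()}


def sddmm_sampled(s: Coo, x: Matrix, y: Matrix) -> Coo:
    """Sampled: dot products only at s's non-zero positions. O(nnz(s)*k) work."""
    k = len(y)
    return {(i, j): v * sum(x[i][p] * y[p][j] for p in range(k))
            for (i, j), v in s.items()}


s = {(0, 1): 2.0, (1, 0): 1.0}
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
print(sddmm_sampled(s, x, y))  # {(0, 1): 44.0, (1, 0): 43.0}
```

In a GNN, `s` is the graph's adjacency, so the dense-first version scales with the square of the node count while the sampled one scales with the edge count.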

As for the equivalent torch-mlir bridge, I’m a bit leery of leaving too much to heuristics out of the gate. The penalty for getting it wrong is often “you may as well not even run it”. That said, because PyTorch has a much more limited implementation of sparse tensors, I believe the starting point is much more amenable to a strict interpretation.

Ultimately, I believe we need something more. For instance, there are a number of generalized SpMM-like reductions in GNNs which really should be fused and implemented as linalg.generic ops underneath. For MLIR to support PyTorch-implemented GNNs natively, we need enough flexibility to do the exact lowering Aart describes with something partially user-defined in the position of mhlo.dot_general.
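A generalized SpMM-like reduction, as used for neighbor aggregation in GNN message passing, differs from plain SpMM only in its reduction operator. Below is a minimal sketch over a CSR adjacency (`row_ptr`, `col_idx`, `vals`) with the reduction passed in as a function; the names are hypothetical, and in the lowering described above this loop nest is roughly what would become a fused `linalg.generic`.

```python
import math
import operator
from typing import Callable, List


def generalized_spmm(row_ptr: List[int], col_idx: List[int], vals: List[float],
                     x: List[List[float]],
                     reduce: Callable[[float, float], float],
                     init: float) -> List[List[float]]:
    """out[i][f] = reduce over neighbors j of row i of vals[e] * x[j][f].

    reduce=operator.add, init=0.0 gives ordinary SpMM (sum aggregation);
    reduce=max, init=-inf gives max aggregation. Rows with no neighbors
    keep the init value.
    """
    n_rows = len(row_ptr) - 1
    n_feat = len(x[0])
    out = []
    for i in range(n_rows):
        acc = [init] * n_feat
        for e in range(row_ptr[i], row_ptr[i + 1]):  # edges of row i
            j = col_idx[e]
            for f in range(n_feat):
                acc[f] = reduce(acc[f], vals[e] * x[j][f])
        out.append(acc)
    return out


# 3-node graph: node 0 <- {1, 2}, node 1 <- {0}, node 2 has no neighbors.
row_ptr, col_idx, vals = [0, 2, 3, 3], [1, 2, 0], [1.0, 1.0, 1.0]
x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
print(generalized_spmm(row_ptr, col_idx, vals, x, operator.add, 0.0))
print(generalized_spmm(row_ptr, col_idx, vals, x, max, -math.inf))
```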

(torch-mlir approach)

  1. Identify the “seed” sparsity information…
  2. Work with upstream PyTorch to generalize and enhance…

I generally agree with 1., although I think that there is a more restricted starting point which follows the current PyTorch sparse implementation more closely. My hope is that we can start small and narrow, then expand as we are comfortable with the sparse pieces not breaking anything else.

All your points on 2. are spot-on. One of the points of friction today is that there is a bit of a chicken-and-egg problem: sparse support is immature, so utilization of sparse features is limited. Instead, many PyTorch users implement those routines by hand in C++ and plug them in. This is, in the long run, suboptimal. A true sparse compiler can do better, and I’m of the opinion that MLIR’s approach to it makes it the best candidate for taking over that responsibility.

(1+x example)

Totally agree, but this is not just a compiler problem (for the moment, we’ll use “1+x” as a stand-in for the general problem of accidentally densifying tensors; in practice, there are tricks for this specific example). My feeling on this is that the first goal should be to enable code that doesn’t fall into this trap, and then try to expand support to make it easier to avoid it. PyTorch itself does not have a particularly clear story for this yet, as I understand it, so it’s a bit further down the road, at least in the torch-mlir realm.
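One possible trick for the specific “1+x” case is to let a sparse tensor carry a fill value for its implicit entries instead of always assuming zero (some array libraries, such as pydata/sparse, support a non-zero fill value). The class below is a hypothetical minimal sketch of that idea, not any library's API: adding a scalar shifts the stored entries and the fill value, so nothing new has to be stored.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FillSparseVector:
    """Hypothetical sparse vector whose implicit entries equal `fill`, not always 0."""
    size: int
    data: Dict[int, float] = field(default_factory=dict)
    fill: float = 0.0

    def add_scalar(self, c: float) -> "FillSparseVector":
        # c + x shifts both the stored entries and the implicit fill value,
        # so the number of stored entries does not change.
        shifted = {i: v + c for i, v in self.data.items()}
        return FillSparseVector(self.size, shifted, self.fill + c)

    def to_dense(self) -> List[float]:
        return [self.data.get(i, self.fill) for i in range(self.size)]


x = FillSparseVector(5, {1: 3.0})
y = x.add_scalar(1.0)           # "1 + x"
print(len(y.data), y.fill)      # still one stored entry; fill is now 1.0
print(y.to_dense())
```

The catch is that downstream kernels must then handle a non-zero fill value, which is exactly why this only helps in specific cases.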

Thanks Bob, that clarifies a lot. I think the tl;dr I’m getting from this discussion is that there is a lot of work to do outside Torch-MLIR on how sparse is surfaced to the user and what IR Torch-MLIR even sees at its input. Once we have some clarity there we will have specific concrete requests for what Torch-MLIR needs to do to enable things.

I agree that having all the transformations support both sparse and dense “mostly transparently” would be a better end state than a massive duplicated code path. But I think there is room for a duplicated code path in the beginning to prove out the value-add and user scenarios, and then demonstrate empirically the pain of duplication and the comparative cost of unifying the two paths.

I think it is honestly still a big open question whether that dream of everything magically supporting sparse is even a desirable/feasible/realistic end state – the vibe I’m getting is that there are problems further up the stack w.r.t. programming model, closed-world-ness, etc. that could relegate sparsity to a somewhat more specialized role with more specific dedicated flows that compose with what we already have.

To me this points to programming model questions further up the stack. I don’t think Torch-MLIR should be in the business of providing that kind of “guaranteed sparsity decisions” feedback loop UX, and once that exists at a higher level, how we encode it in Torch-MLIR is a relatively minor technical decision. I’d really like us to focus in this thread on “the user/frontend has made these decisions around sparsity, here is the IR that they give us, and here is the IR we want the backend to receive”, and then focus on the design problem of making the pieces connect. That’s something concrete we can move towards.

Additionally, one of the benefits of Aart’s code is the ability to autotune the choice of sparsity schemes – this is directly at odds with “guaranteed sparsity layouts” that are surfaced to the user. I think as an ecosystem we are still untangling the programming model and use cases here.
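Autotuning a sparsity scheme means letting the system, not the user, pick the layout. Real autotuning would time generated kernels; the minimal sketch below substitutes a simple storage-size cost model to show the shape of the decision. The function names and the three candidate formats are assumptions for illustration only.

```python
from typing import Sequence


def storage_cost(fmt: str, n_rows: int, n_cols: int, nnz: int) -> int:
    """Number of stored scalars (indices plus values) for a 2-D matrix."""
    if fmt == "dense":
        return n_rows * n_cols
    if fmt == "coo":
        return 3 * nnz                 # row index, column index, value per entry
    if fmt == "csr":
        return (n_rows + 1) + 2 * nnz  # row pointers, then column index and value per entry
    raise ValueError(f"unknown format {fmt!r}")


def pick_format(n_rows: int, n_cols: int, nnz: int,
                candidates: Sequence[str] = ("dense", "coo", "csr")) -> str:
    """Choose the cheapest candidate under the cost model (first one wins ties)."""
    return min(candidates, key=lambda f: storage_cost(f, n_rows, n_cols, nnz))


print(pick_format(1000, 1000, 5_000))    # very sparse: csr
print(pick_format(1000, 1000, 900_000))  # mostly full: dense
print(pick_format(4, 4, 2))              # tiny matrix: coo
```

The chosen layout depends on the data, which is precisely what a user-visible “guaranteed layout” would rule out.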

PT is eager, though, so this can only be done with custom ops anyway. And once we get into “guaranteed fusion” sorts of discussions, I think we are definitely wanting to start with a more focused incremental path with a focus on programming model.

That gives me an idea – we could use Torch-MLIR here as a building block to help define custom ops at the PyTorch level. E.g. something like this:

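import torch
import custom_sparse_op_gen  # hypothetical: the op-generator package proposed here

# Generate a sparse kernel for this computation and register it as a custom op.
@custom_sparse_op_gen.generate("my_namespace.sddmm", sparsity_spec=...)
def sddmm(s, x, y):
    # SDDMM: dense x @ y, sampled at the non-zero positions of sparse s
    return s * (x @ y)

def my_gnn(t: torch.Tensor):  # other arguments elided
    ...
    # call sddmm
    ...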

custom_sparse_op_gen.generate would have the following responsibilities:

  1. do sparse codegen according to sparsity spec
  2. register a custom PT op my_namespace.sddmm
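The two responsibilities can be mocked in a few lines to show the shape of the decorator. This is a hypothetical pure-Python stand-in: codegen is stubbed out (the Python body is kept as the reference implementation), registration is a dict insert rather than a real PyTorch custom-op registration, and the `{"s": "csr"}` spec format is invented for illustration.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry standing in for PyTorch's custom-op registration.
OP_REGISTRY: Dict[str, Tuple[Callable, Any]] = {}


def generate(qualified_name: str, sparsity_spec: Any = None):
    """Mock of custom_sparse_op_gen.generate.

    1. Sparse codegen: stubbed out; a real version would lower the body
       through Torch-MLIR to linalg and compile it according to sparsity_spec.
    2. Registration: records the op under its qualified name.
    """
    def decorator(fn: Callable) -> Callable:
        if qualified_name in OP_REGISTRY:
            raise ValueError(f"op {qualified_name!r} is already registered")
        OP_REGISTRY[qualified_name] = (fn, sparsity_spec)
        return fn
    return decorator


@generate("my_namespace.sddmm", sparsity_spec={"s": "csr"})
def sddmm(s, x, y):
    return s * (x @ y)


print(sorted(OP_REGISTRY))  # ['my_namespace.sddmm']
```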

We could use Torch-MLIR as a building block here, to lower the computation to linalg, and then at the linalg level use sparsity_spec’s annotations on the inputs to propagate the sparsity encoding throughout the computation (linalg seems like a good level of abstraction to do this). You could build up the “guaranteed sparsity” UX in that linalg-level sparsity propagation layer.

With a suitable custom op story in Torch-MLIR, we could even compose this into larger flows (e.g. if you need whole-program capture to feed this to an ASIC or deploy on mobile or something).

I really don’t think that diving headfirst into “the whole compiler mostly transparently supports sparsity throughout” is the immediate first direction we want to head down. I really think that almost all of the work here is either to be done above or below Torch-MLIR, with Torch-MLIR merely acting as an intermediary for specific requests to propagate certain information to backends. The snippet above could actually be done even with an unmodified Torch-MLIR.