Hey folks, in Torch-MLIR we have lowerings into Linalg-on-Tensors, TOSA, and MHLO, so I thought that maybe I could provide some perspective here. Sorry for the long message.
Torch-MLIR’s 3 backends
Linalg-on-Tensors
Linalg-on-Tensors (which is really linalg itself + tensor, arith, math, etc.) was the first backend we added, and it is still the only one that seems to sit at a principled and complete “abstraction layer” in the presence of dynamic shapes. The things I think really contribute to its solidity and coherence are:
- It takes a hard stance on requiring ranked tensors, which allows all shape stuff to be managed as variadic index operands (which plays super well with CSE and other optimizations to prove relationships between shapes).
- It consistently permits any dimension size to be dynamic.
- It takes a “hard undefined behavior” stance on shape mismatches, making it clear whose responsibility it is to actually do shape checking if a safe programming model is desired.
- “dynamic size 1 broadcasting” is blanket disallowed (as I said here this is actually not a big problem in practice for us)
- The consistent outs handling (“destination-passing style”) makes the whole system highly coherent, with “no surprises”.
- It plays well with arith and other dialects and doesn’t “force everything to be a tensor”, leading to a natural division of responsibility that avoids losing information, such as the fact that an integer is a “host side” scalar integer. (A rough sketch of what this layer looks like is shown right after this list.)
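To make the above concrete, here is a hand-written sketch (not code from Torch-MLIR; exact op and attribute spellings drift across MLIR versions) of an elementwise add at this layer, with ranked-but-dynamic shapes carried as index values and the result written into an explicit outs tensor:

```mlir
func.func @add(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // Shapes are ordinary `index` SSA values, so CSE/canonicalization can
  // prove relationships between them.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %lhs, %c1 : tensor<?x?xf32>
  // Destination-passing style: the result is computed into an init tensor.
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %result = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%init : tensor<?x?xf32>) {
  ^bb0(%l: f32, %r: f32, %out: f32):
    // The "payload" is ordinary arith on scalars.
    %sum = arith.addf %l, %r : f32
    linalg.yield %sum : f32
  } -> tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}
```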
Note that dynamic shapes fundamentally change the layering of ML compilers – they take what was roughly one abstraction layer and reveal subtleties needed for proper compiler layering, such as the modeling of reshapes needing to be more “symbolic” like tensor.collapse_shape vs. just “an op that takes the new shape”, or needing to reason about what happens when shapes mismatch (and at what point in the compiler you ensure those are either guarded or “yolo, assume they are not going to happen”). Both TOSA and MHLO have various levels of incompleteness/unprincipledness here, with MHLO functionally supporting more dynamic shape stuff than TOSA.
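For example (purely illustrative, using the upstream tensor dialect), the difference between a “symbolic” reshape and one that just takes the new shape as a runtime value:

```mlir
// The grouping of dimensions is static even though the sizes are dynamic,
// so the compiler can still reason about the reshape symbolically.
%collapsed = tensor.collapse_shape %t [[0, 1], [2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>

// Here the target shape is just an opaque runtime value, which is much
// harder to analyze.
%reshaped = tensor.reshape %t(%shape)
    : (tensor<?x?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
```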
TOSA
TOSA was the second backend we added, and remains preferred by many users (especially “hardware” or “hardware-adjacent” folks):
- It is tied to a spec with a really clear “ISA-like” expository style that resonates with a lot of folks
- The coarse-grained named-op approach is a good match for the many compilers that are designed that way
- It has really good support for quantization / integer data types.
- It has clear versioning/stability guarantees on the op semantics.
- It is extremely solid with static shapes (and many of its users only care about static shapes, so that’s fine)
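For a flavor of what that looks like, here is a hand-written sketch of a TOSA op (in MLIR’s generic op syntax, since the custom assembly format has changed across versions): a coarse-grained named op over statically shaped tensors with no payload region.

```mlir
// An "ISA-like" named op: no payload region, fully static shapes, and
// integer data types of the kind quantized models use.
%0 = "tosa.add"(%a, %b)
    : (tensor<1x64x64x8xi32>, tensor<1x64x64x8xi32>) -> tensor<1x64x64x8xi32>
```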
MHLO
MHLO is the third backend we added, and it offers a reasonable blend of the other two:
- It takes a coarse-grained named-op approach.
- It has a pretty clear spec for most of the ops (with a bit of mental translation and hoping that MHLO is the same as HLO): Operation Semantics | XLA | TensorFlow
- It functionally supports dynamic shapes (though not as coherently and consistently as Linalg-on-Tensors, and the dynamic shape support falls outside the wonderful HLO docs above).
- It appears to be pretty tied to HLO (which is highly mature) so most of the op surface area doesn’t change too much
- It has a different set of principles than TOSA, which tend to make it more expressive at the cost of a larger abstraction gap from hardware. For example, TOSA limits (for highly considered reasons) the number of dimensions that certain operators can handle to 1D-4D, when from a purely algebraic perspective there isn’t a good reason not to be more general. As was brought up in this thread, the handling of more general forms of reduction and scatter also fits nicely into MHLO, while TOSA’s principles tend to bias it away from that.
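For comparison, a sketch of how MHLO’s dynamic shape support tends to look: the dynamic behavior goes through explicit “dynamic” op variants whose output shape is an operand (again in generic op syntax; the attribute and type details here are from memory and may differ slightly between versions):

```mlir
// The output shape is a runtime tensor operand rather than part of the
// result type, which is how MHLO expresses dynamically shaped broadcasts.
%bcast = "mhlo.dynamic_broadcast_in_dim"(%operand, %out_shape)
    {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
    : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
```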
Major design axes
Based on this experience, I would suggest that the following are the major axes:
- Dynamic shapes support:
  - Unranked?
  - Known rank, but with arbitrarily dynamic dimension sizes? (This requires putting thought into how to layer the “dynamic size 1 broadcasting” situation and the semantics in case of shape mismatches.)
  - Static shapes only?
- DType support: do you use signless integer types or non-signless? TOSA and MHLO use non-signless integer types, and both are super sketchily (even buggily) implemented in this regard, with various conversions assuming “signless == signed” and such. And what about type promotion? Or dynamic dtype? (See the sketch right after this list.)
- Do you use a “combinator”/“payload-carrying” approach like linalg? That has a lot of representational-compactness benefits, but it’s TBD whether it makes sense at all abstraction levels. And then you have to define what goes in the payload (arith?). This also ties into how much and what kind of “named ops” you want.
- Do you use destination-passing style (“outs”)?
- Stability guarantees / spec / stable binary format / project integration story.
- Scope of ops: do you only allow certain kinds of structured ops like linalg? Or certain “sufficiently simple to efficiently map to hardware” ops like TOSA?
- Do you “force everything to be a tensor”, even that i1 that you’re branching on?
- Do you have your own tensor type, or do you reuse the builtin tensor type?
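As a small illustration of the dtype axis (this is just the existing arith dialect, not a proposal): with signless integer types, signedness lives in the ops rather than the types, via explicit extend/truncate operations.

```mlir
// The i8 value itself is signless; the choice of extension op says how to
// interpret it.
%widened_signed   = arith.extsi %x : i8 to i32
%widened_unsigned = arith.extui %x : i8 to i32
%narrowed         = arith.trunci %widened_signed : i32 to i8
```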
What would work really well for Torch-MLIR:
If you ask me for my personal opinion, here is what I want to target from Torch-MLIR:
- (Strong opinion): Known ranks, arbitrary dynamic dimension sizes. Dynamic size-1 broadcasting and shape mismatches are UB.
- Rationale: Torch-MLIR already needs to handle all the shape errors itself anyway (the exact error semantics and error messages are very delicately related to the frontend op semantics). We actually need to at least prove the ranks of ops to properly infer element types (sad story, but reality). Also, the direction I’m seeing on the frontend side across the ecosystem is heading towards directly producing known-rank, dynamically-dimension-sized code. I really struggle to think of a frontend problem that would require unranked calculations but not dynamic dtypes (for example, to do codegen for an isolated, general “elementwise add” op in PyTorch you would need to multiversion across both shapes and dtypes). And I really don’t think we want dynamic dtypes.
- (Strong opinion): Signless with explicit “extend”/“trunc” ops
- Rationale: Same as arith and LLVM. Lowering from frontends is a good place to move signedness out of the types (where it tends to live in the actual user code) and into the ops.
- (Weak opinion): Use the combinator approach with an appropriate, orthogonal “scalar opset” (it could be in the same dialect).
- Rationale: The combinator approach has proven to not really be a big problem in the linalg backend, and even allows some things to be done super nicely. E.g. createLinalgPayloadCalculationForElementwiseOp really cleanly handles unary, binary, ternary ops including ones with random additional “scalar” modifiers and all the dtype conversion/promotion semantics.
- (Unsure) on destination-passing style (“outs”): My gut is that it’s superfluous at the Torch-MLIR backend level of abstraction, but there’s a certain elegance and consistency to it that I really like. Note: in the dynamic-shape case we usually need to reify the output shape anyway (to do the error checks), and that is exactly what you need to materialize the outs init tensor, so in general the init tensors aren’t too much of a burden. But the frontend direction seems to be leaning towards frontends somehow ensuring that we (Torch-MLIR) won’t need to emit error checks ourselves.
- (Moderately strong opinion) on the stability / project-integration story: This is something that I don’t think any of the three backends does “awesome” right now. What I want as a frontend is a handful of files I can drop into my project (FooOps.td, FooOps.cpp, etc.) which I can update “at my leisure” (say every ~3 months) and which allow me to produce the Foo stable binary artifact. I think all three backends can be evolved towards this from different angles (even linalg). However, it’s not clear whether this is at odds with other requirements, like enabling certain transformations to be written on the IR.
- (Moderately strong opinion): It should “just work” for all the ops that the frontend naturally has, no matter how painful it is to support, even if this leads to various internal lowerings/layerings further down the stack (ideally this can be made pretty composable/progressive so not ultimately that painful).
- Rationale: I don’t want it to be “Torch-MLIR’s job” to decide how to classify, implement, orthogonalize, or layer the efficient, hardware-specific lowering of sort, fft, topk, scatter-with-repeated-indices, qr decomposition, cumsum, embedding bag, “things with >4 dimensions”, “things with data-dependent dimension-sizes” (nonzero, unique), quantized-softmax-without-messing-up-the-final-fusion, etc. There needs to be a really easy path for us to lower all of these ops into something that doesn’t lose information, even if that means there is “more work to do” lower in the stack where more target information is available.
- (Moderately strong opinion): Don’t “force everything to be a tensor”. You already need “true scalars” for my strong opinion above on the handling of dynamic shapes. In Torch-MLIR, it would be information loss to lower !torch.bool to tensor<i1>, since it would require “raising” later to recognize that it is a “host scalar” rather than a “device scalar”. (See the sketch at the end of this post.)
- (Weak opinion): Have your own tensor type.
- Rationale: We’ve done this in Torch-MLIR (link) and it has paid huge dividends for our core computational data type to be something we fully control (it’s not that much code either). I’ve seen similar stories on the IREE side as well. MLIR’s type conversion infra is a known quantity these days (see my talk), so there’s no reason not to. In places where this hasn’t been done, like Linalg, TOSA, and MHLO, I’ve seen it cause problems, though it seems “survivable” – e.g. the “abstract tensor” situation in linalg, or the off-label use of signless types in TOSA and MHLO. And I’m sure there are numerous latent compiler crashes from feeding !builtin.tensor<!my.random_type> to various passes – why even allow it?
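To illustrate the “don’t force everything to be a tensor” point above (another hand-written sketch, not code from Torch-MLIR): a “host side” scalar can drive control flow directly, whereas a tensor<i1> has to be extracted first and later “raised” back to recover the same information.

```mlir
// Host scalar: branch on it directly.
%cond = arith.cmpi slt, %i, %n : index
scf.if %cond {
  // ... host-side control flow, no tensors involved ...
}

// Tensor-ized boolean: an extra extract is needed before branching, and a
// later pass has to recognize this pattern to "raise" it back to a host
// scalar.
%cond2 = tensor.extract %pred[] : tensor<i1>
scf.if %cond2 {
  // ...
}
```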