[RFC] Proposal for a high-level ML dialect in MLIR

Thanks for the detailed writeup @_sean_silva! Couple of questions below:

I could be missing some nuance (CC @herhut @matthias-springer) but I assume this means when lowering from Torch-MLIR we’ll need to make a “best guess” for where the outputs of every operation should go. For operations like scatter or gemm_with_bias_add this output tensor has a natural candidate but simpler operations like reductions or cwise operations don’t – they will require an analysis conceptually similar to bufferization to make this determination. Is my understanding correct?

Is this because you want to store graphs long term or because you don’t want to have to continually adapt Torch-MLIR?

@_sean_silva Thanks for those excellent points. It is very useful to know these design points from the perspective of Torch-MLIR. I, personally, strongly align with several of these.

This is one I would like to debate more on. While it is clear that the translation from signedness in types to signedness in ops should happen at some point in the stack, it is not clear if doing it immediately after the frontend is the right place. At the level of ops that we are proposing here, it seems more prudent to have signedness in the types.

+1 Strongly agree with this. But, even if we strive for this, there might still be some frontend ops that may not map to any of the ops in this dialect. The goal should be to minimize this set of frontend ops as much as possible.

Thanks, Sean, for such a thorough recounting of learnings from the torch side. When we started that, I was hoping that it would be successful in staking out a contrasting design space.

One of the things I keep hearing is the desire for people to be doing higher level algebraic or dataflow-level transformations. The lower we go in the stack, the harder such transformations become (because they have had certain concepts lowered into forms that are comparatively hard to match and make “large” changes to). But conversely, because the funnel is tighter at the low level, if it can be done there, it can save having to do it on N frontends. I see that tension play out basically daily.

Linalg isn’t great for this: all of the explicitness you mention in its design makes it hard on this axis. I wouldn’t say I’ve really “enjoyed” using MHLO for such tasks. The TF dialects get used for this extensively. And on the Jax side, the highest impact transformations are actually happening at the JAXPR level (imo). When I look at the Torch dialect, I see it as sharing the same potential, but given how young it is, I’m not sure how much is being done here (i.e. to my knowledge, no one has written “optimizers” or distribution transforms at the torch dialect level yet).

Are there thoughts on where these tasks live in the layer stack of dialects? Is this something that we should expect a reduction dialect like MHLO/TOSA to be good at, or should we just be driving these workloads to the framework level dialects and accept that we live in a fragmented world and there will be duplication? And that the ability for the framework level to service its users in this way does play in to “how good” it is at doing its job?

I ask because if we factor those requirements out of the common dialects we are discussing, then what is left is purely lowering and integration – and that is a differently constrained problem that may be easier to solve.

3 Likes

@stellaraccident, this is a good summary of our thought process as well. We’re looking at it from a transformation point of view, and linalg really gets too low-level to be efficient in the search for patterns that would be easier, and often trivial, in a higher-level pass.

Not that linalg is bad for transforms; it isn’t. It’s great for many things. But there are two main things that get in the way:

  1. Linalg isn’t complete (by design): A single high-level op can be lowered into a complex set of linalg + something patterns, which not only makes understanding the semantics much harder than just looking at the ops, but is also potentially destructive, because we’d have to know what that something is (which can differ across front-ends), and now you have a much larger problem space.
  2. Other passes operate on that something (and potentially linalg): Between lowering and our transformation passes (which may not be contiguous), there can be a number of other passes and cleanups that can destroy the canonical forms, making it harder (or impossible) for us to find them.

All of that is well known and is one of the core reasons for MLIR’s design, and for that reason I believe we need a higher-level dialect to operate on. We can (and probably will) have multiple passes on that high-level dialect, and on linalg, etc., so this isn’t a replacement, it’s an addition.

This fragmentation is extremely counterproductive.

If you want to focus on your front-end and have it compete with other front-ends, you’ll have to write some optimisation passes of your own, because it’s your own dialect. But other front-ends, doing basically the same things, will have to do the same, because their dialects are slightly different.

Furthermore, teams looking at just optimisation passes will have to implement different passes for different dialects (and the different forms they lower to linalg+something), which means most people just pick a front-end / back-end pair and optimise for that.

The whole idea of MLIR’s power to allow those teams working together vanishes. We’re stuck in a rut again.

To me, having a high-level dialect that all front-ends lower to is fundamental for MLIR to become the go-to platform for ML/HPC.

Moreover, the forms that they lower need to be canonical (or made canonical by a pass), and every further lowering should canonicalise before passing it down to another tool, so that we can all work with a smaller set of problems and can do our jobs more efficiently.

Of course that’s easier said than done. We’ve been trying to do that in LLVM IR for a long time and it’s NP-hard. But each front-end lowering to their own dialects is going in the wrong direction.

That said, front-ends can have their own dialects, as long as they convert to a canonical high-level dialect at the end, before handing it in to the next tool / optimising pipeline. So, if we come up with a generic high-level dialect, it doesn’t mean all front-ends have to change to it, just that they need to provide a lowering framework to it.

[Note: for the record, I’m not looking for de-jure canonical forms + dialects, just de-facto ones. Front-ends can do what they want, but if they want pass X to optimise it, they need to lower to canonical form Y]

2 Likes

I totally agree with @rengolin’s point.

It is not easy but this should be the ambition and the shared effort required to try to create a common pre-competitive asset.

Destination passing style, in its most trivial and unoptimized form, will just create a new destination for every high-level operation. So you do not need to be clever or do any kind of analysis (but of course you can, and your decisions will be carried forward). What we gain is a uniform way to express shape computations of results (as @_sean_silva also mentioned). It is also helpful when doing tiling, as it allows us to express access patterns to the overall result by tiled computations while still using a tensor-based representation.
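For concreteness, here is a minimal sketch of that trivial form in MLIR assembly, assuming the usual upstream ops (`tensor.dim`, `tensor.empty` – formerly `linalg.init_tensor` – and `linalg.matmul`); the point is only that the destination and its shape computation become explicit IR:

```mlir
// Trivial destination passing style: a fresh destination is materialized
// for every operation, and the result's shape computation is explicit.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%m = tensor.dim %lhs, %c0 : tensor<?x?xf32>
%n = tensor.dim %rhs, %c1 : tensor<?x?xf32>
%dest = tensor.empty(%m, %n) : tensor<?x?xf32>
%res = linalg.matmul
         ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
         outs(%dest : tensor<?x?xf32>) -> tensor<?x?xf32>
```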

Once in destination passing style (and accompanied by an interface), we can then go about optimizing things at the tensor level independent of operations themselves.

Destination passing style itself is an old idea; I found this paper a nice read in the context of array programming, as it describes the allocation + shape functions aspect well.

(Just so you understand my approach in asking, I don’t actually have a strong opinion on this, but I was putting a strong position out there so that we enunciate the need, as you are doing. Canonical high level forms are really hard, and may not be possible, so I want us to be clear on the need/desire here)

This is also a concern we have been discussing. Having high-level operations in dialects at the HLO level of abstraction makes it easier to identify patterns e.g. to decide whether we want to map them to a library. On the other hand, linalg provides a nice uniform interface to do cost analysis on, as it encodes access patterns directly (for memory cost) and has a body that we could use to reason about compute cost. If we want to do the same at the HLO level, we would need to encode the same information, likely via interfaces for access pattern and compute cost.
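To illustrate why linalg is convenient here, a rough sketch of an element-wise add as a `linalg.generic`: the indexing maps and iterator types expose the access pattern (memory cost), and the region body exposes the computation (compute cost). An op at the HLO level of abstraction would need interfaces to surface the same two pieces of information.

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
// Access pattern is explicit in the indexing maps / iterator types;
// compute is explicit in the body.
%sum = linalg.generic
    {indexing_maps = [#map, #map, #map],
     iterator_types = ["parallel", "parallel"]}
    ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%init : tensor<?x?xf32>) {
  ^bb0(%x: f32, %y: f32, %out: f32):
    %0 = arith.addf %x, %y : f32
    linalg.yield %0 : f32
} -> tensor<?x?xf32>
```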

My current thinking is that we will need those interfaces anyway, for the operations that we cannot map to linalg, and that we’ll want to make the fusion/library decisions at the HLO level before we lower to linalg.

A shared high-level dialect certainly would simplify things a lot and would provide a common starting ground. It still needs to remain an extensible system, though. Staying with my example of fusion: While I believe we could build the infrastructure for an extensible fusion system (using cost models etc.) based solely on interfaces upstream, it will need to be tailored by different compiler implementations and backends, as at least the cost models will differ.

1 Like

Absolutely, I get it. I didn’t mean to propose a perfectly canonical high-level dialect, as I agree, it’s probably impossible from any practical standpoint. My points on it being NP-hard and reaching de-facto agreements were an allusion to that.

What I do mean is that not even trying to re-use the common parts is the wrong direction. This is what I think such a new high-level dialect can be: at the very least a (useful) intersection of existing dialects, but hopefully, a dialect that other dialects can be lowered to (minus special stuff, that goes straight to linalg+scf+std or code-gen).

Otherwise, we’ll continue to see frameworks like TF, XLA, PyTorch, IREE, PlaidML all being special purpose, company-specific, fragmenting the available tools, and making a mess for researchers and developers trying to just run some models efficiently on their hardware.

1 Like

And continuing down the path of asking controversial questions, the two approaches to this are what I call a reduction and a union opset. Examples of the former are MHLO/TOSA, and of the latter ONNX. I think we could also include the framework dialects themselves in the latter.

Which way do we want to go? What about beyond ops? With all of the existing dialects, with the possible exception of torch, we are representing a subset from a capability perspective. I’m less concerned about op deltas, which can always be filled with more ops or an extension mechanism. But if we’re heading down this path, we need to know when we’re ok limiting the expressivity of the platform features.

On the torch side, by lowering to a simple ssa-value based form, we are erasing mutability, but more importantly, we are leaving out possibly the most important feature of its runtime representation: the fact that all of its tensors are strided and that, to a first approximation, PyTorch basically is an engine for turning as much as it can into a metadata operation on those data structures. Throwing that out has certain implications, and we should be ok with them (chief among them is that it becomes very difficult for the resulting compiler to efficiently use any parts of the torch ecosystem or kernels – it has to be a complete island). Currently, if I had to design such a system, I would need to add the “delegate bits” at the torch level and constrain the compiler stack to just certain code generation activities within that framework. That might be ok, but it is a limitation.

Since Jax basically is MHLO, it suffers less from this problem, but you do see some of the gaps emerge around distribution and state management. Ditto TF but also with more “leakiness”.

We’ve got to abstract over something, but we are going to lose things in the process, or drive certain parts of the resulting design to the frameworks. It seems to me that all of the existing proposals are abstracting over a lot of ground on the torch side, when actually the torch execution model is the generalization.

1 Like

That’s a very good point. My personal take on things as vague and high level as these is to be as pragmatic as possible from the beginning without losing sight of the overall goal, which seems to be exactly what you’re proposing, too.

I do not have a good answer, unfortunately, but if I had to guess, I’d say there are two paths to this:

  1. We pick one side and limit the other side’s ability to perform, which is what you described above and which I interpret as “not the best idea”, because it limits what you can do in MLIR compared to what you could do in the original framework.
  2. We compose. Each front-end does what it can on its own dialects and only lowers to the upstream dialect when it has run out of things to do. This is the MLIR way, but with a twist: you know you’re losing information, so it will be impossible to do certain things after that.

The second point was our approach with Verona. The highest level representation had information we needed to do Verona-specific stuff that wasn’t possible in any other dialect, and once we lowered to others, we were literally throwing away any hope of performing those same transformations again. It wasn’t just hard, it was probably impossible.

That is a trade-off, similar to the first point, but time-wise rather than space-wise, which I think allows us to milk the framework a bit more. So this more generic high-level dialect could be less expressive than Torch or even MHLO, designed for mid-to-low-level optimisation passes, not high-to-mid-level ones, which are still better done on their original dialects.

This is one way to do it, not the only way and probably not the best one. More likely it will be part of a combination, even. But I think we should explore a more tiered approach rather than a catch-all approach.

1 Like

IMO the former does not imply the latter.

For instance, an ML compiler could still “fuse” (grouping may be a better term for this) a convolution and a relu and map that to a pre-fused cuDNN kernel. Doing this on an orthogonal op set (TCP / MHLO etc.) means we don’t need a large set of pre-fused convolution ops (that largely mirror cuDNN) at the Torch / TensorFlow level.
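Very roughly, and with entirely made-up `tcp.*` op names and attributes (this is only meant to illustrate “grouping” as opposed to carrying pre-fused ops; shapes are illustrative):

```mlir
// Orthogonal ops as lowered from the front-end (hypothetical names):
%conv = "tcp.conv2d"(%input, %filter)
    : (tensor<1x64x56x56xf32>, tensor<64x64x3x3xf32>) -> tensor<1x64x54x54xf32>
%relu = "tcp.clamp_min"(%conv) {min = 0.0 : f32}
    : (tensor<1x64x54x54xf32>) -> tensor<1x64x54x54xf32>

// After a grouping pass: the pair is outlined/tagged so a backend can map
// it to a pre-fused library kernel instead of re-fusing it in codegen.
%grouped = "tcp.group"(%input, %filter)
    {library_call = "cudnnConvolutionBiasActivationForward"}
    : (tensor<1x64x56x56xf32>, tensor<64x64x3x3xf32>) -> tensor<1x64x54x54xf32>
```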

cuDNN is one thing. I think what I was referring to was more interop with the native PyTorch kernels (including those that are user defined). And I was using that somewhat as a strawman for other integration laden things in that vein (although there are real use cases for this precise thing that keep coming up).

(And for the record again, I’m not concluding anything, just poking to try to understand beyond what level certain parts of the problem become misaligned)

1 Like

<wild tangent>
Let me expand on this, as an unrelated tangent, outside of the main discussion, just as a data point, not as a discussion point…

Things like alias analysis or vectorisation in LLVM look at bare LLVM IR and try to find patterns, map those patterns into compiler structures, and then look at those (TBAA and SCEV annotations, VPlan, Polly). This is hard to do, but there is enough info to grasp some things.

If we lower loops in a canonical form, it’s easier to vectorise than if we need to shuffle basic blocks around, hoist declarations, change induction ranges, etc. If we pass restrict annotations, it’s easier to do alias analysis.

So what I mean by time-wise trade-offs is that we do what we can early on, but when we lower, we try to lower into the most canonical form possible, so that, even while destroying precise information (dialect ops), we keep imprecise information (canonical forms, annotations), which may make it possible (but not certain, due to cleanups and transforms) to do the same thing again, later on, on a less expressive representation (dialect).

This is necessary because, even with a high-level, precise representation, we don’t have the right shapes. For example, inlining exposes many opportunities for other optimisations that we just don’t have before inlining, but it also destroys shapes, annotations, etc.

The approach I propose above (similar to what I proposed for Verona) is to do just that:

  • Progressively lower to less expressive dialects only after doing what we can with the information we have, fully knowing we cannot do it all and giving up the idea that we can do it all.
  • When lowering, try hard to keep canonical forms and annotations, so that a late pass can still do again what you did originally, but with less information. This can be the same code, but after a local raise of the (now low-level) IR, if that works.
  • When doing cleanups, transforms, etc., try hard to keep canonical forms and annotations, just like we do with LLVM IR.

Perhaps this is my brain after 14 years of working with LLVM. Perhaps this is how it can be done with such a diverse set of front-mid-low ends. But I’m hoping someone has a better idea…
</wild tangent>

Another wild tangent, but then I need to sign off for the day and let both the timezones and my day job catch up.

What if we based this new thing on the torch type system and op interfaces but practiced “ethical non-monogamy” with respect to ops? Type systems and interfaces are what have the highest mis-abstraction cost. We should try to get the op sets to something approaching canonical, at least for some 80% case, but beyond that we need freedom more than we need uniformity.

1 Like

I really like the “MLIR way” description, so let’s try looking at this from that perspective a bit more. IMO, it is not really a twist that lowering = losing information; this is the main reason why we originally insisted so much on progressive lowering, which is basically a more conscious approach to discarding information. That being said, converting between abstractions (e.g., between a framework-specific and an MLIR-generic opset) does not have to be a lowering. This ties in with @stellaraccident’s comment on a reduction opset (lowering) vs a union opset (horizontal, reversible conversion), and maybe there is some middle ground where only some information is deliberately lost.

A “more MLIR” way could be a mix of dialects: there is a “reduction”-style generic opset in MLIR that frameworks use alongside framework-specific ops for the things that cannot be represented otherwise. This can include wrapper ops that bridge abstractions, of the kind we have to connect tensors and buffers, but this comes at a cost for the framework.

An “even more MLIR” way is to consider the things frameworks are willing to do on their high-level representation and, if they are common, try implementing them on interfaces instead. This could side-step the question of the common dialect, but may end up being harder to design. Specifically, mapping to library calls would need some sort of marker that an op or a combination thereof is equivalent to a library call, with potentially as many interface methods as the library has functions.

Both “more MLIR” ways come at a cost for the frameworks compared to rather happily living in their own sandbox, and the time scale at which the benefits from using those will pay off is not very clear.

2 Likes

The other cost is the added complexity (potentially combinatorial) for front-end-agnostic optimising pipelines absorbing those modules: they need to know which ops from different dialects are equivalent, and how those ops affect ops from other dialects used as operands.

We may be able to use traits and interfaces to common those things up, but we may end up with silly traits like ElementWiseMulLikeOp and ElementWiseAddLikeOp so that I can merge into an ElementWiseMLALikeOp, from tosa.add(mhlo.mul). It’s a silly example, but you get the idea.
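To make the combinatorial point concrete, a front-end-agnostic pipeline would have to recognize mixes like the following (generic op syntax, shapes purely illustrative), and every additional dialect multiplies the combinations to match:

```mlir
// The multiply comes from one front-end's lowering, the add from another's:
%mul = "mhlo.multiply"(%a, %b) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%mla = "tosa.add"(%mul, %c)    : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
```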

I can see this. Designing interfaces upfront is hard, I would be inclined to just try and define them to see how soon (if at all) we will actually run into the problem.

For this example specifically, there shouldn’t be that many fused ops we would need to support, e.g., to target libraries, for the exact same reason we would otherwise need a lot of patterns: combinatorial cost. We can also think about this problem as generalizing InstCombine to work across dialects. We can try generalizing traits, e.g. an ElementWiseOp interface that returns the core operation as arith.something. We can also decide that this part is not worth generalizing, and each framework can just keep doing it separately because it’s easier, but then we will know for sure.

3 Likes

I am relatively new to MLIR, but I see modules containing ops from more than one dialect in the same module/function.
So, why a new dialect? Why not just pick and choose whichever ops are best suited from the existing array of dialects and mix them as needed, and only add new ones if none already exist?
Over time and with experience, the best ones to use will be used the most.
I know that there is a lot of focus on inference models at the moment, but there will soon be ways to train systems with a single picture instead of thousands, and for that, the ops needed are somewhat different from the existing ones.
I guess what I am trying to say is that things keep changing at the moment, so whatever high-level ML dialect we come up with will need updating and changing for some time to come before it stabilizes, so we need considerable flexibility in the ops we use.

1 Like

I think one important point that is worth considering is that it has to be somebody’s job to decide on the hardware efficient lowering of a bunch of different ops. This is something we can definitively pronounce does not belong in the frontends.

I gave the list above: sort, fft, topk, scatter-with-repeated-indices, qr decomposition, cumsum, embedding bag, “things with >4 dimensions”, “things with data-dependent dimension-sizes” (nonzero, unique), quantized-softmax-without-messing-up-the-final-fusion, etc.

Each of these has multiple different ways it could be lowered, and a LOT of ingenuity and domain expertise goes into it. E.g. for FFT: do you do DIT or DIF? do you have a big enough batch dimension you can use for most parallelism? at what size should you fall back on a DFT on this hardware? Even something simple like “exp minus 1” can require precision considerations etc. (it might be okay to unfuse sometimes in return for requiring backends to support fewer ops, or perf, or whatever).

The skillset required for making those design judgments based on target information and domain expertise is a very distinct island that I think is somewhat orthogonal to various other parts of the community. It sits above codegen and backends like IREE (to name one off the top of my head). And it sits below frontends like Torch-MLIR.

We’ve been trying to shove more of this into the codegen space (e.g. linalg_ext.fft) but in its current form it feels too low-level when you consider the vastness and specificity of different decisions that need to be made. Maybe some day these technologies will be generalized enough to subsume all this, but we need a place TODAY for doing those things which we can then gradually subsume. Actually this could accelerate the generalization of linalg/etc. precisely by showing all the transformations that are needed in practice. Writing these transformations to achieve a certain performance/functional goal is a different skillset from staring at multiple transformations and distilling a set of principles that allow unifying them.

I wonder if we can stake out some space in the “ML middle-end” based on that definition: input: broad coverage of ML operators; output: efficient decompositions/lowerings taking some degree of target hardware/low-level backend details into consideration. The input would not be a single orthogonalized form (but would have enough stability guarantees for frontends – probably with a stability-guaranteed “union” dialect that initially is a light layer of indirection to internal dialects, but will gradually get further from them). The output would not be a single canonical form either. But it is the responsibility of this layer to “impedance match” all the frontend operators to the different backend requirements. And we would build composable infra to perform a variety of transformations that incorporate various notions of hardware efficiency which are absent at the frontend level, and non-recoverable at the level of detail of codegen/linalg. As a rough guess here, the input would be non-destination-passing-style and the output would be destination-passing-style.

Another thing that I would want to be handled at this layer is converting complex numbers to real numbers and emulating data types that are not available on a target backend (e.g. a backend doesn’t support f64 – emulate it, or have a great place to make a policy decision about truncating it). This is one of the biggest pain points when lowering from Torch-MLIR to IREE for example, and begs for a layer to handle it.
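As a small sketch of the “truncate” policy option (the `tcp.add` op and the demotion pass are hypothetical; `arith.truncf`/`arith.extf` operate element-wise on tensors), such a layer could rewrite the types and insert conversions at the boundaries:

```mlir
// Before (front-end emits f64; backend has no f64 support):
//   %sum = "tcp.add"(%a, %b) : (tensor<8xf64>, tensor<8xf64>) -> tensor<8xf64>

// After a hypothetical f64-to-f32 demotion pass:
%a32 = arith.truncf %a : tensor<8xf64> to tensor<8xf32>
%b32 = arith.truncf %b : tensor<8xf64> to tensor<8xf32>
%s32 = "tcp.add"(%a32, %b32) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
%sum = arith.extf %s32 : tensor<8xf32> to tensor<8xf64>
```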


So, the Torch dialect is not a good place for this in its current form because the work to get a nicely decomposed and “orthogonal” op set hasn’t happened yet. I mean, Torch-MLIR does do decompositions, but it’s a “get things working” sort of thing and we haven’t (and don’t intend to) do the principled work to orthogonalize and layer that decomposed op set. Frontend-toward efforts like PrimTorch or backend-towards efforts like this thread should be responsible for that.

I think that JAX has opened up a bunch of interesting questions regarding the programming model and interaction with the compiler, which traditional compilers (except in lisp/etc.) don’t bring in. Things like grad/vmap/etc. are really user-controlled compiler transformations, and so naturally live closer to the frontend (or require very careful layering). I don’t have full visibility into this, but I have some intuition that distribution (data-parallel and pipeline) are similar things that intersect with the user programming model, and so the layering has to be carefully considered to really understand what precisely is the information that the frontend has to tunnel down to which layer.

Also, practically speaking, not all frontends necessarily have MLIR as a dependency. Maybe they should, but there are valid reasons not to.

We definitely need a reduction somewhere, because many transformations, including grad/vmap but also SPMD/etc., essentially require writing O(#ops) transfer functions. Because grad is more framework/frontend aligned, I think there is some argument to move the reduction op set towards the frontend (and PyTorch is in fact moving this way; JAX is already there). This then leaves a pretty large space in the stack to put all the other transformations. It would be useful if an expert could explain how JAX separates responsibility for this (grad, vmap, SPMD, pipeline-parallel distribution, data-parallel distribution) between itself and XLA.

@jekbradbury perhaps?

The direction that Torch-MLIR is collaborating with PyTorch on has the framework give us graphs with value semantics for those islands where the framework wants to commit to a more comprehensive compilation stack (TorchDynamo is an example where such “what goes in the graph the backend compiler sees” decisions happen). I don’t see a super great compilation story possible with strided tensors – it is too constraining for any significant compilation work to happen; at best (and even then I would ignore it and recover later) it is more of an “inline hint”.

3 Likes