This RFC is a proposal to bring the MHLO and LMHLO dialects into the official MLIR repo and continue their development upstream. The mhlo and lmhlo dialects provide a lowering path from the MLIR TensorFlow dialect, and there are already conversions out of these dialects to lower-level dialects in MLIR; in addition, these dialects enable getting into and out of XLA proper. Most of the motivation for having these ML dialects upstream was already spelled out in RAMBLE: How to position new ML dialects in tree, but that discussion became more general because it covered several other dialects. Here are a few things to note regarding mhlo/lmhlo.
The TensorFlow repository uses these dialects: the MLIR TF dialect is converted to mhlo using a few passes (mostly -xla-legalize-tf and -xla-legalize-tf-control-flow). There are also translators that go back and forth between XLA proper and mhlo. Moving these dialects to the MLIR repo would mean that TF depends on them through its existing LLVM/MLIR dependency. This frees TF from non-TF-specific MLIR dialects, moving them to MLIR where the latter's organization, coding conventions, and development style naturally fit. (Creating a third repository, with mhlo and lmhlo as out-of-tree dialects, would be too chaotic and cumbersome given the number of things in play in getting TF to use MLIR: they would all have to align on the LLVM/MLIR version, with mhlo/lmhlo sandwiched between TF (one of its main users) at one end and LLVM/MLIR at the other. I feel this option would be worse than what we have today, giving people too many things to look at and track.)
The current infrastructure in mhlo/lmhlo consists mostly of dialect ops, their canonicalizations, and dialect conversions into and out of them, without any elaborate analyses or transformations. In that sense, the code's footprint is lower than that of dialects with many transforms. There are some transformations that work on or apply to them, e.g., copy removal and bufferization, but nicely, these were at some point moved upstream and made to work with op interfaces (e.g., CopyOpInterface); as such, these dialects transparently benefit from upstream passes.
One concern that came up in RAMBLE: How to position new ML dialects in tree was the impact on upstream maintenance. I'm happy to maintain or help maintain these dialects, including aligning them with the conventions used upstream. This of course doesn't reduce the impact when the core changes, but again recall (2).
In terms of which tools should link in these dialects, I think mlir-opt should continue to link them in the way it does now. Although mlir-opt is already 30 MB in size in a release build, in the future these dialects, along with linalg, tosa, and other things with ops and transforms on tensor types, could be removed from mlir-opt and added to a new binary, mlir-tensor-opt.
Other thoughts, comments, and suggestions are welcome.
I'll add one more point for discussion: for at least the next year or so, we (the TF GPU team at Google) need to keep MHLO and LMHLO coupled with XLA's HloInstruction representation. Will these dialects follow changes to XLA's representation by fiat, or will making changes to the dialects require additional discussion beyond standard code review?
Leaving aside the logistics of how we do this, I'm +1 on this RFC. MHLO has been a de facto lingua franca for the MLIR ML ecosystem for quite some time, and it has good connectivity with other efforts. Speaking as both a Google employee and a member of the community, I would like to see this piece of infrastructure moved to its proper home upstream and subject to an open development process. I could cite numerous examples, but where it is placed now is causing "sandwiching" issues for several projects which would really benefit from the ability to refer to its operations but find it difficult or impossible to navigate the repo/version boundary issues of its present placement.
I feel strongly enough about this that I would argue that if we can't converge on moving MHLO upstream in some fashion, we should probably make active plans to discontinue its use and develop an alternative (possibly a fork). How it is positioned now is quite problematic for the ecosystem.
As has been alluded to in previous discussions, there seems to be a missing “MLIR-based ML Frontend Integrations” top level project, which would be a natural home for some of these things. I don’t think we need to tie this decision to that one, but we should probably start thinking along those lines.
Nit: I would probably include chlo on the list and seek to collapse it into mhlo (for those who don't know, chlo once stood for "client HLO" and contains ops that were previously only realized directly in client libraries). I believe it only contains implicit-broadcast variants of many ops, with conversions to the explicit forms. The reason for the separation is largely historical at this point, and the refactoring could be done either before or after a move.
When I extracted all of this into the mlir-hlo repo, the goal was really to have a path forward to upstream all of it. However, I didn't manage to get any staffing at Google to develop the project (a consequence of Covid and some coincidental internal reorganization; I didn't even get an intern on this!), so the project isn't evolving as it could have.
Given the lack of traction inside Google to act on that aspiration, maybe having it upstream would give us the opportunity to make it a more collaborative effort, to evolve it and build the end-to-end system I'd really like us to have!
There are really 3 dialects in this repository:
chlo, or "client HLO", was intended to match the XLA "Builder" API, which offers entry points that don't map 1:1 to the HLO op set. It seemed more reusable to encode in MLIR dialect conversion patterns the same logic that the HLO client builders encode right now, expanding these "client ops" into the actual IR.
We could merge it into mhlo, but the practical distinction we wanted to draw was that mhlo would be the "canonical" form of the IR on which we'd run optimizations. We didn't want to see chlo → mhlo as just a canonicalization, because the conversion could be made target-specific as well. The separation also more or less ensured that mhlo transforms would never create any chlo operations. (A rough sketch of this expansion follows after the three dialect descriptions below.)
mhlo: this is really a fairly direct transposition of XLA HLO, as a superset: we wanted the freedom to expand on top of XLA HLO, in particular to experiment with dynamic shapes (on targets that support it). Ultimately this is really intended to be the core of the compiler and its transforms, and it should map to linalg named ops as much as possible (like TOSA does).
lmhlo: this dialect was introduced as a temporary step to help migrate XLA backends (HLO + XLA buffer assignment → lhlo); however, it is not intended to have a life of its own beyond being this stepping stone in a transition. So we should not introduce strong dependencies on this dialect, and we should design our compilers assuming that it could go away. Instead, the large majority of what is modeled at this level can already be represented in linalg. Passes performing transformations on LHLO should ideally adhere to strict interfaces that exist on Linalg and are general enough.
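To make the chlo → mhlo expansion mentioned above concrete, here is a hand-written sketch (not code taken from the repo) in MLIR's generic op syntax; the value names and shapes are made up, but the op and attribute names are the ones these dialects use today:

```mlir
// An implicitly broadcasting "client" op ...
%0 = "chlo.broadcast_add"(%lhs, %rhs)
       : (tensor<4x8xf32>, tensor<8xf32>) -> tensor<4x8xf32>

// ... is expanded into an explicit broadcast followed by the plain mhlo op:
%b = "mhlo.broadcast_in_dim"(%rhs) {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
       : (tensor<8xf32>) -> tensor<4x8xf32>
%1 = "mhlo.add"(%lhs, %b)
       : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
```

For dynamic shapes the same expansion also has to materialize shape computations, so it is more involved than this static-shape case.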
@sanjoy_das_google Since the translators between XLA proper and mhlo would live in TF, and in the interest of having those translators continue to work when changes are made to XLA's HloInstruction, my understanding is that you'd want to make the corresponding changes to the mhlo dialect reasonably quickly, or ideally at the same time. In the future (beyond the one year you refer to), could the process be made more MLIR-driven?
I echo Sanjoy's question. This seems like a strange middle ground between a sovereign upstream dialect (like linalg) and an upstream MLIR interface to a spec (like TOSA).
@bondhugula The mechanical aspects of actually making the code changes are not problematic. But I want to make sure we have a path to resolve design differences between XLA and MLIR upstream as & when they arise. For instance, what happens when the XLA HLO for convolution introduces a new attribute that does not make sense for upstream users of MHLO?
Overall I like this proposal, but I share Sanjoy and Sean’s concerns here. Unlike TOSA, the HLO proto doesn’t really have a spec and is dependent on TF’s implementation decisions.
One option here would be forking. If we think the MHLO op set is a good starting point for the gap we're looking to fill, we could fork the existing implementation and allow it to diverge from HLO. Having separate dialects for "representation of the XLA HLO proto in MLIR" and "tensor op set inspired by HLO" is actually something we've discussed a few times. The current MHLO is kind of a weird mishmash, as it contains ops that exist only in the XLA client builder, ops that exist only in the instruction proto, tuples and variadics and variadics of tuples, etc. Instead, we could leave an HLO dialect that directly maps to the proto in TF and let MHLO expand by fully embracing dynamic shapes, dropping tuples in favor of variadics, etc. Like Mehdi said, further effort on this hasn't been able to get prioritization. I think this would avoid the issue of a change to the HLO proto necessitating a change to the MHLO dialect by fiat; it could instead be something that gets discussed in the context of what's best for MLIR as a project.
Of course, this proposal is extra work, and if the TF folks decide that it's not worth it to them to use this separate upstream representation, I think it loses a lot of the value that Uday is looking for. I do think it would at least provide more clarity about what each dialect is trying to represent, whereas it currently feels a bit confused and rudderless. I know that in the referenced ramble thread @clattner expressed some skepticism, based on practical experience, about the oft-discussed "pure translation" dialect idea, so I'm curious about his thoughts here.
There are quite a few cleanups I'd love to see in MHLO that maybe we could tackle as a community effort; IMO it doesn't currently meet the standard set by the other core dialects.
There’s nothing specific about the improvements you listed (like removing tuples, more clearly isolating proto-compatible and non-proto-compatible stuff) that requires being upstream.
In fact, we have been discussing such improvements to MHLO (which was intended to allow them) since its inception. I personally use that as a gauge for the amount of interest there would be in making such changes in a hypothetical upstream MHLO. Simply putting it upstream isn't going to change people's priorities, and time has shown what those priorities are for the folks working directly on MHLO.
Also, given that these cleanups mostly have ramifications in downstream repos (such as the proto conversion code), it's not as if a random MLIR upstream contributor can make them: it would be a TF/XLA contributor in the downstream project, and they are already fully empowered to do that (and it hasn't happened). So I'm also somewhat skeptical that this move would suddenly bring new community engineering effort to those neglected cleanups.
There are a number of such engineering improvements that I haven't seen pick up steam. Off the top of my head (a rough sketch of the second item follows the list):
removing mhlo.copy and replacing it with a memref.copy op,
replacing mhlo.constant with std.constant
replacing the mhlo control flow with scf,
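As a hand-written sketch of the second item (generic MLIR syntax for the mhlo op; the value names are made up), the replacement would look roughly like this:

```mlir
// Today's form: a dialect-specific constant op ...
%cst0 = "mhlo.constant"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// ... versus the equivalent standard-dialect constant:
%cst1 = constant dense<1.000000e+00> : tensor<4xf32>
```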
AFAICT, the primary reasons for not making these changes are either considerations foreign to upstream (like protobuf serialization interop), or simple lack of interest/prioritization.
Agreed, Sean. My reference was more to the fact that if the MLIR community likes MHLO enough to think that it’s a good starting point for a gap in the current core dialects, we could fork from the current MHLO without the baggage of future changes to TF requiring potentially undesirable changes to a core dialect.
For those curious, I just did a pass through the mhlo dialect (not chlo or lmhlo) and broke down the approximate op surface area. Takeaways:
My overall takeaway is that the bulk of the dialect is representable nicely with precisely retained high-level semantics using linalg-on-tensors ops. I don’t think it makes sense to upstream these given such a precise overlap (I heard it is planned to have mhlo → linalg-on-tensors conversions, which addresses the ecosystem interop point of this RFC, and also eliminates the need for lmhlo).
The remaining bulk of the dialect consists of “mini-dialects” that are used as part of XLA’s implementation / modeling of certain transformations, which we probably don’t want to upstream without individual design consideration.
Linalg on tensors ops
elementwise ops / otherwise foldable into a linalg.generic with no reduction dimensions:
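To illustrate what "foldable into a linalg.generic with no reduction dimensions" means here, a hand-written example for an elementwise add on tensors (not part of the original breakdown; %lhs, %rhs, and the init tensor %init are assumed):

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
// Every iterator is "parallel", i.e., there are no reduction dimensions;
// the body just applies the scalar operation.
%0 = linalg.generic
       {indexing_maps = [#map, #map, #map],
        iterator_types = ["parallel", "parallel"]}
       ins(%lhs, %rhs : tensor<4x8xf32>, tensor<4x8xf32>)
       outs(%init : tensor<4x8xf32>) {
  ^bb0(%l: f32, %r: f32, %out: f32):
    %s = addf %l, %r : f32
    linalg.yield %s : f32
} -> tensor<4x8xf32>
```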
Fantastic breakdown, Sean, and it matches my recent experience. IREE has historically used MHLO as its primary input dialect but is in the process of switching to using linalg-on-tensors (+ std + scf, etc.) exclusively. Beyond a few operations that would be nice to have some helpers for (like mhlo.convert), there have been no real issues factoring out and converting the bulk of them. Most of the conversions have been or are being upstreamed into TensorFlow, meaning that any downstream project of TensorFlow + MLIR that consumes linalg-on-tensors (or a transitively reachable dialect) can in essence already import MHLO.
If, as mentioned, MHLO becomes a "translation dialect" for other non-TensorFlow frontends coming into MLIR/LLVM to target, I could see including it upstream as a reasonable move, so that each frontend would not have to replicate MHLO->linalg. However, I also agree with the sentiment here that it would require changes to how HLO is governed, a spec, compatibility provisions, etc. (à la TOSA) in order to be a reliable target for non-TensorFlow users. The separate MHLO repository that's been set up for this is a good step in that direction, but until it is considered the source of truth for TensorFlow itself it isn't too strong a proof point.
That said, having spent enough time staring at MHLO, TOSA, and linalg, my opinion is that forward-looking frontends would be better served going to linalg directly instead of trying to squeeze through any higher-level interface dialect (MHLO, TOSA, etc.). For any current user targeting execution on XLA (and thus already squeezing through HLO), the real question is whether they want to use MLIR and whether they would switch to using the standalone MHLO repo (and then LLVM's version, if accepted). If a project is going to depend on the TensorFlow monorepo anyway, then there's non-trivial overhead in juggling the multiple sources and versions; that's why in IREE we're moving everything HLO-related into our TensorFlow-specific import pipeline (since, until an MHLO repo is the source of truth for TensorFlow, we need to pin tensorflow+mhlo anyway). Feedback from any HLO+MLIR project that does not depend on TF would therefore be useful to gauge interest.
I assume you meant lmhlo.copy (because there isn’t an mhlo.copy) - but this is really a tiny issue in the whole scheme of things.
The mhlo dialect is meant to maintain a 1:1 correspondence with the XLA HLOs, so we don't want to be doing these replacements, AFAIU! Doing them would mean you could no longer convert from the MLIR TF dialect to MHLO (via -xla-legalize-tf), nor import from or export to XLA proper; both of these, in my understanding, are needed going forward, if not forever. For example, mhlo.while is aligned with tf.while rather than scf.while; the latter is more general/powerful in a way that is not needed for TF purposes (or perhaps for ML models in general; in any case, the option to lower directly to scf.while always exists for frameworks). Another data point while on this: mhlo.while uses tuples, which are convenient when lowering from a more user-facing dialect but bad from an IR-manipulation standpoint. Neither scf.while nor linalg supports tuples, nor should they, because they are meant for transformations and tuples only get in the way. The sooner you deabstract tuples after entry into MLIR, the better; with mhlo.while, it is easier to lower frontends that would otherwise have had to deabstract tuples on their way into MLIR, and you get the deabstraction for free when converting mhlo.while to lhlo.while (tuples were removed from lhlo) or to scf.while (a conversion that already exists).
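For readers less familiar with the tuple issue, here is a hand-written sketch in generic MLIR syntax (value names made up) of the packing and unpacking that has to be deabstracted:

```mlir
// mhlo-style code carries the loop state packed into a tuple ...
%t  = "mhlo.tuple"(%iv, %acc)
        : (tensor<i32>, tensor<f32>) -> tuple<tensor<i32>, tensor<f32>>
// ... and unpacks it element by element inside the loop regions:
%e0 = "mhlo.get_tuple_element"(%t) {index = 0 : i32}
        : (tuple<tensor<i32>, tensor<f32>>) -> tensor<i32>
```

scf.while and linalg, by contrast, carry multiple loop-carried values directly as variadic operands and results, so this packing and unpacking has to be removed before reaching them.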
All of this is moot to me because, the way I see it, the point of bringing the mhlo dialects upstream is not to fill a void in upstream's representational power; to start with, it is to have as much of the MLIR-based TF lowering as possible live upstream (as far as the non-TF-specific parts go). Do you think the MLIR TF dialect can, with little work, be lowered directly to the current upstream dialects in MLIR (i.e., straight from the MLIR TF dialect to upstream dialects)? Consider tf.while, which is by all means easier to lower to mhlo.while than straight to scf.while: it's good to have reusable functionality (tuple deabstraction in this case) move upstream. Also, do you see value in being able to export to and import from XLA proper via the MHLO dialect in upstream MLIR?
The way I see it, the mhlo dialect should just directly follow XLA HLO in the immediate term. If an attribute is added there, it's assumed it would make sense for upstream users of MHLO, i.e., MHLO's design is a reflection of XLA and not the other way round in the short term. Overall, unless XLA is going away completely very soon, I only see benefits in adding a small-footprint dialect like MHLO/LMHLO upstream. It also provides an incremental path to migrate the TF-to-MHLO/LMHLO lowering to TOSA or Linalg, and it gives frameworks currently mapping to XLA an easy path into MLIR that does not depend on TF. (I don't have the list of such frameworks here.)
All of these can be fixed; IMO, there are a few core dialects that don't currently meet the standard set by the other core dialects, both w.r.t. design and w.r.t. documentation and code quality. They still continue to evolve upstream and will hopefully get better.
Actually, it is meant to be a superset of what can be expressed with XLA HLOs, and we'd like to keep the ability to go from TF->MHLO->…->XLA HLO; but that does not mean it has to maintain a 1:1 correspondence.
I agree with the rest of the points you’re making here in general and just wanted to point this one out: MHLO is already diverging slightly, and I’d like to keep the door open on this.
Conversions from MHLO to all the dialects/alternatives I listed above either exist or are desired, AFAIK. Like Mehdi said, things are structured the way they are now to keep the path TF->MHLO->…->XLA HLO functional, and that's a highly downstream-specific use case.
And FWIW, converting directly to all the alternatives I mentioned would not be much harder. The code analogous to TF->XLA in npcomp goes directly to the upstream alternatives I mentioned, and it is working very well (we actually started with a dialect analogous to mhlo and then removed almost all of its ops in favor of linalg on tensors). Also, the alternatives are often easier (e.g., this code, which would be quite trivial if it used scf.for).
There may be reasons to keep mhlo.constant and others: MHLO/HLO are intended to support more than just what an op naively describes. For example, XLA has some amazing features around multi-device support, sharding annotations, etc. All of these may require specific attributes to be preserved.
The XLA dynamic padder is another very interesting feature that I'm not sure we can easily support at the moment: it may require a different type than our current tensor (though an attribute on tensor may be enough), or some attributes on the ops themselves (though that limits interaction with the rest of the ecosystem).
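As a purely illustrative sketch of the kind of annotation referred to above (the mhlo.sharding attribute name and its string encoding are my assumption here, based on what the TF/XLA bridge attaches; the exact spelling may differ):

```mlir
// A constant carrying an XLA sharding annotation as an extra attribute;
// a plain std.constant would have no natural place to keep this.
%cst = "mhlo.constant"() {value = dense<0.000000e+00> : tensor<8xf32>,
                          mhlo.sharding = "{replicated}"} : () -> tensor<8xf32>
```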
Yes, those features are pretty cool, and indeed special ops are likely needed. What I’m unsure about is that such changes seem to be part of larger designs (for distributed tensor support or building limited dynamic shape support on top of a static shape compiler), which probably need individual discussion beyond the scope of this upstreaming effort.
I think the ecosystem interop design point is really the main reason to consider this RFC. I would really like to hear more from folks using mhlo in their downstream projects who would benefit from this RFC.
I don't think this considers the point I already made on while-op conversion as an example in post #13: RFC: Bring MHLO and LHLO dialects into MLIR upstream - #13 by bondhugula
Are you able to cover all major areas of ops in the npcomp TF lowering, e.g., lowering tf.while and tf.if, which use tuples? You'd face tuple deabstraction issues, which is one thing that bringing in mhlo would address or provide an incremental path to solve.
From my perspective, the case is made not on the strengths or weaknesses of any individual feature of something like mhlo, but on the connectivity benefits it might bring to the ecosystem and the extent to which someone (most likely Google, given the coupling) is willing to invest in the maintenance, for the benefit of all.
Entrypoints are important, and while the mlir-ml ecosystem may eventually have discrete investments in this area on its own, given where we are in the evolution, I think it makes perfect sense to borrow and lean on those that already exist and have some mass. I would like to see better pride of placement for such assets in the codebase.
Given this discussion, I’m still +1 based on these criteria but would like to hear more about the actual connectivity this enables and who is investing in the maintenance to get to a solid decision point for myself.