[RFC] Proposal for a high-level ML dialect in MLIR

Yes it is, particularly in mid to low level IRs in our experience - for custom accelerators this is usually something to do with post-bufferization bespoke memory placement and movement, custom/heterogeneous scheduling or a combination of the two. There’s typically no immediate impetus to make that public.

+1. In a production setting the immediate goal is to get the necessary skeletal pieces together quickly in a modular form. My own view is that good constructs would reveal themselves as multiple teams creatively play with the blocks, and the modularity/reuse potential of MLIR dialects then enables others to iterate towards those best designs too.

1 Like

Sorry about that. We’ve been having a lot of these discussions in torch-mlir contexts and I forgot to redefine my terms when switching audiences:

  • CHLO : Historically “Client HLO” and consisting of named ops that reduce to the more core MHLO ops (but can also be lowered individually at the higher level if the platform prefers).
  • MHLO : A mapping of the core/historic XLA op set to MLIR.

In parallel, there is an effort underway to pull the ops/types out of these dialects into a new StableHLO project, as an exercise to produce a more canonical operation set (CHLO/MHLO were grown over the past decade, and we would like to reset a bit before extracting and using them as a common interface layer). Historically, CHLO/MHLO have been part of the XLA compilers, which makes it hard to reason about turning them into a more neutral asset. The new work is attempting to correct that situation and also focus the opset more explicitly on the framework intersection point.

My main interest in all of this is that, at some point, something at this level gets distilled as a community-governed asset and connected with high fidelity to the infeed frameworks and outfeed compilers. If existing assets like MHLO/CHLO (via the StableHLO project) are useful towards that end, then my primary goal in mentioning it was to make sure that people understood that Google would be willing to align on this and (likely) contribute substantial engineering engagement. I am sensitive to both the baggage and bias, however, and trying to stay on the side of “informing” vs “advocating”.

3 Likes

Thanks for the definitions and the links, I’ll check them all and compare to what we have internally.

Absolutely agree. We already have enough of those in other projects to be able to, at the very least, draw sufficient inspiration to have a common dialect. And I believe such a dialect, as the OP suggests, would be a big win for the community.

As @sanjoyd said, contributions from the existing projects are very welcome, as long as they’re reviewed and contributed incrementally.

I did interpret it as “informing” (and as an offer, not a push) originally, so I didn’t mean my reply as “telling off” or anything. Sorry if it seemed that way.

But I also missed a lot of the context, and was trying to restart from the point where I dropped.

I noticed that when some of the core developers start interacting, it turns into a rapid-fire exchange that makes outsiders lose interest (I’m guilty of that in other parts of LLVM), so I’m just raising awareness that this happens and asking that we keep it generic so that others can participate.

On the practical side, I think it would be nice if @raghavanr could compile a list of ops that would fit their requirements and which of the existing dialects (TOSA, MHLO, TCP) would satisfy those conditions. We can all then look at that, see which ones satisfy our own conditions and what we have to contribute, and then we’ll be in a better place to decide what the dialect will look like and where it’ll come from.

2 Likes

Potentially also, I know that @burmako has been doing pre-work on StableHLO in order to document and organize the view from the Google side. Since we all seem to be about to do a “how did you spell gather?” level discussion, I assume that more docs/analysis will help. Perhaps we can publish some of our docs.

1 Like

The ops in MHLO should satisfy our requirements to begin with. If StableHLO is community owned, we should be able to evolve it if there are any missing pieces. Also, @burmako already mentioned there are plans to support all the stuff we need (dynamic shapes, quantization, sparsity) in StableHLO. If StableHLO’s charter is different, we can propose a design sketch in the Aug 18 meeting.

Here’s a categorization of MHLO/CHLO ops that is based on this pre-work. It doesn’t aim to be super precise (we found that corner cases are fairly tricky to categorize), but hopefully it’ll be useful for this discussion. This categorization is inspired by @_sean_silva’s prior work.

1) MHLO ops which are currently used by a union of JAX, PyTorch and TF integrations and look fairly reasonable. This doesn’t mean that these ops should be set in stone (e.g. @jpienaar mentioned that gather/scatter are notorious for their complexity, and @sjarus talked about TOSA’s experience with broadcast), but overall these ops seem like a good starting point for discussion.

Control flow: case, if, while.
Data movement: broadcast_in_dim, concatenate, gather, pad, reshape, reverse, scatter, slice, sort, transpose.
Elementwise: abs, add, and, atan2, bitcast_convert, cbrt, ceil, clamp, compare, complex, convert, cosine, count_leading_zeros, divide, exponential, exponential_minus_one, floor, imag, is_finite, log, log_plus_one, logistic, map, maximum, minimum, multiply, negate, not, or, popcnt, power, real, reduce_precision, remainder, round_nearest_afz, round_nearest_even, rsqrt, select, shift_left, shift_right_arithmetic, shift_right_logical, sign, sine, sqrt, subtract, tanh, xor.
Miscellaneous: constant, iota, optimization_barrier, return, rng_bit_generator, rng.
Reduction: convolution, dot_general, reduce, reduce_window, select_and_scatter.

2) MHLO ops which are currently used by a union of JAX, PyTorch and TF integrations and would especially benefit from design work. E.g. how (and whether) should multiple approaches to modelling distributed execution coexist? Would it make sense to merge dynamic ops into regular ops? What’s the right approach to extensibility, and should ops like cholesky, fft and triangular_solve use the extensibility mechanism instead of being standalone ops? Etc.

Distribution: after_all, all_gather, all_reduce, all_to_all, collective_permute, infeed, outfeed, recv, reduce_scatter, replica_id, send.
Dynamism: compute_reshape_shape, cstr_reshapable, dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, dynamic_slice, dynamic_update_slice, get_dimension_size, real_dynamic_slice, set_dimension_size.
Extensibility: custom_call, get_tuple_element.
Miscellaneous: cholesky, fft, triangular_solve.
Quantization: uniform_dequantize, uniform_quantize, <no other quantization-specific ops at the moment, instead using existing ops with quantized element types from the Quant dialect>.
Sparsity: <no sparsity-specific ops at the moment, instead using existing ops with sparse tensor types from Sparse Tensors in MLIR>

3) MHLO ops which don’t seem to be core to an ML framework / ML compiler interface - either because they can be reasonably decomposed into other ops or because they are private to XLA (i.e. aren’t created by existing producers and are only created inside XLA itself).

Decomposable: batch_norm_grad, batch_norm_inference, batch_norm_training, broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, trace, unary_einsum.
Private to XLA: add_dependency, bitcast, copy, domain, fusion, partition_id, tuple, xla.rng_get_and_update_state.

4) Finally, there are CHLO ops which are somewhat of a miscellaneous collection at the moment, but there is some meaningful categorization.

Data movement: top_k.
Dynamism: constant_like, dynamic_reshape, minimum_broadcast_shapes, rank_specialization_cluster, rank_specialization_cluster_yield.
Elementwise: acosh, acos, asinh, asin, atanh, atan, bessel_i1e, conj, cosh, digamma, erfc, erf, is_inf, is_neg_inf, is_pos_inf, lgamma, next_after, polygamma, sinh, tan, zeta.
Implicit broadcasting: broadcast_add, broadcast_and, broadcast_atan2, broadcast_compare, broadcast_complex, broadcast_div, broadcast_max, broadcast_min, broadcast_mul, broadcast_next_after, broadcast_or, broadcast_polygamma, broadcast_pow, broadcast_rem, broadcast_select, broadcast_shift_left, broadcast_shift_right_arithmetic, broadcast_shift_right_logical, broadcast_sub, broadcast_xor, broadcast_zeta.

5 Likes

Yes, you are completely right, particularly “words matter a lot”. I didn’t realize the sensitivity here and didn’t mean to undermine all the hard work being put into improving things in tree. I apologize in a very heartfelt way if I offended anyone!

Also for the record, I wasn’t (and am not) opposed to having this content in the LLVM monorepo, I just think that it being mixed into core MLIR causes confusion.

-Chris

2 Likes

It is also worth pointing out that for dynamic shapes, tensor.from_elements is needed to materialize tensor<Nxindex> values representing tensor shapes.
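
For concreteness, a minimal sketch (my own, not from the post) for a rank-2 input: tensor.dim queries the dynamic sizes and tensor.from_elements packs them into a tensor<2xindex>.

```mlir
// Materialize the shape of a dynamically shaped tensor as a tensor<2xindex>.
func.func @shape_of(%arg0: tensor<?x?xf32>) -> tensor<2xindex> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Query the dynamic dimension sizes as index values.
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  // Pack them into a 1-D index tensor representing the shape.
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
  return %shape : tensor<2xindex>
}
```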

Having an implementation like @herhut suggests makes it easy to play and to create alternative flows.

Not only outsiders: anyone who wasn’t there in the first two hours of the conversation…

4 Likes

Thanks, Chris. I appreciate you saying that.

(We should really continue the discussion on how to split things up so that we have more room for active efforts without overlapping and adding taxes to each other – and see if that can be driven to an actionable next step. I will admit that while super important to figure out, it isn’t exactly a topic I wake up on Monday morning and decide I have the energy to put into driving consensus on. That is part of why of late I’ve been leaning more into the incubator or domain specific repos direction. It isn’t for a lack of desire to fix the current organization: mostly it is just not seeing how to actually carry something out there given all of the dynamics)

1 Like

Hey folks, in Torch-MLIR we have lowerings into Linalg-on-Tensors, TOSA, and MHLO, so I thought that maybe I could provide some perspective here. Sorry for the long message.

Torch-MLIR’s 3 backends

Linalg-on-Tensors

Linalg-on-Tensors (which is really linalg itself + tensor, arith, math, etc.) was the first backend we added, and it is still the only one that seems to have a principled and complete “abstraction layer” that it sits at in the presence of dynamic shapes. The things I think really contribute to the solidity and coherence are:

  • It takes a hard stance on requiring ranked tensors, which allows all shape stuff to be managed as variadic index operands (which plays super well with CSE and other optimizations to prove relationships between shapes).
  • It consistently permits any dimension size to be dynamic.
  • It takes a “hard undefined behavior” stance on shape mismatches, making it clear whose responsibility it is to actually do shape checking if a safe programming model is desired.
  • “Dynamic size 1 broadcasting” is blanket disallowed (as I said here, this is actually not a big problem in practice for us).
  • The consistent outs handling (“destination passing style”) makes the whole system highly coherent and has “no surprises”.
  • It plays well with “arith” and other dialects and doesn’t “force everything to be a tensor”, leading to a natural division of responsibility that avoids losing information, such as the fact that an integer is a “host side” scalar integer.

Note that dynamic shapes fundamentally change the layering of ML compilers – they take what was roughly one abstraction layer and reveal subtleties needed for proper compiler layering, such as the modeling of reshapes needing to be more “symbolic” like tensor.collapse_shape vs just “an op that takes the new shape”, or needing to reason about what happens when shapes mismatch (and at what point in the compiler you ensure those are either guarded or “yolo assume they are not going to happen”). Both TOSA and MHLO have various levels of incompleteness/unprincipledness here, with MHLO functionally supporting more dynamic shape stuff than TOSA.
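
As a rough illustration of those properties (my own sketch, not from the post; exact op spellings vary a bit across MLIR versions), here is an elementwise add on ranked, dynamically sized tensors, with the shape carried as index SSA values and the result written into an explicit outs destination:

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @add(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Shape handling is plain index arithmetic on SSA values.
  %d0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %lhs, %c1 : tensor<?x?xf32>
  // Destination-passing style: the "outs" tensor is created from those sizes
  // (tensor.empty here; older MLIR versions spelled this linalg.init_tensor).
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %add = linalg.generic
      {indexing_maps = [#map, #map, #map],
       iterator_types = ["parallel", "parallel"]}
      ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%init : tensor<?x?xf32>) {
  ^bb0(%l: f32, %r: f32, %out: f32):
    // The payload uses arith rather than tensor-level ops.
    %sum = arith.addf %l, %r : f32
    linalg.yield %sum : f32
  } -> tensor<?x?xf32>
  return %add : tensor<?x?xf32>
}
```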

TOSA

TOSA was the second backend we added, and remains preferred by many users (especially “hardware” or “hardware-adjacent” folks):

  • It is tied to a spec with a really clear “ISA-like” expository style that resonates with a lot of folks
  • The coarse-grained named-op approach is a good match for the many compilers that are designed that way
  • It has really good support for quantization / integer data types.
  • It has clear versioning/stability guarantees on the op semantics.
  • It is extremely solid with static shapes (and many of its users only care about static shapes, so that’s fine)

MHLO

MHLO is the third backend we added, and it offers a reasonable blend of the other two:

  • It is a coarse-grained named-op approach.
  • It has a pretty clear spec for most of the ops (with a bit of mental translation and hoping that MHLO is the same as HLO): Operation Semantics | XLA | TensorFlow
  • It functionally supports dynamic shapes (though not as coherently and consistently as Linalg-on-Tensors, and the dynamic shape support falls outside the wonderful HLO docs above).
  • It appears to be pretty tied to HLO (which is highly mature), so most of the op surface area doesn’t change too much.
  • It has a different set of principles than TOSA which tend to make it more expressive at the cost of having a larger abstraction gap from hardware. For example, TOSA limits (for highly considered reasons) the number of dimensions that certain operators can handle to 1D-4D, when from a purely algebraic perspective there isn’t a good reason to not be more general. As was brought up in this thread, the handling of more general forms of reduction and scatter also falls into MHLO nicely, while TOSA’s principles tend to bias it away from that.

Major design axes

Based on this experience, I would suggest that the following are the major axes:

  1. Dynamic shapes support

    • Unranked
    • Known rank, but with arbitrarily dynamic dimension sizes (and putting thought into how to layer the “dynamic size 1 broadcasting” situation and the semantics in case of shape mismatches)
    • Static shapes only
  2. DType support: do you use signless integer types or non-signless? TOSA and MHLO use non-signless integer types, and both are super sketchily (even buggily) implemented in this regard with various conversions assuming “signless == signed” and such. And what about type promotion? Or dynamic dtype?

  3. do you use a “combinator”/“payload-carrying” approach like linalg? That has a lot of representational compactness benefits, but TBD if it makes sense at all abstraction levels. And then you have to define what goes in the payload (arith?). This also ties into how much and what kind of “named ops” you want.

  4. Do you use Destination Passing Style (“outs”)?

  5. Stability guarantees / spec / stable binary format / project integration story.

  6. Scope of ops: do you only allow certain kinds of structured ops like linalg? or certain “sufficiently simple to efficiently map to hardware” ops like TOSA?

  7. Do you “force everything to be a tensor”, even that i1 that you’re branching on?

  8. Do you have your own tensor type? Or do you reuse the builtin tensor type?

What would work really well for Torch-MLIR:

If you want to ask me for my personal opinion, here is what I want to target from Torch-MLIR:

  1. (Strong opinion): Known ranks, arbitrary dynamic dimension sizes. Dynamic size-1 broadcasting and shape mismatches are UB.
    • Rationale: Torch-MLIR already needs to handle all the shape errors itself anyway (the exact error semantics and error messages are very delicately related to the frontend op semantics). We actually need to at least prove the rank of ops to properly infer element types (sad story, but reality). Also, the direction I’m seeing on the frontend side across the ecosystem seems to be heading towards directly producing known-rank, dynamic-dimension-sized code. I really struggle to think of a frontend problem that would require unranked calculations but not dynamic dtypes (for example, doing codegen for an isolated, general “elementwise add” op in PyTorch you would need to multiversion across both shapes and dtypes). And I really don’t think we want dynamic dtypes.
  2. (Strong opinion): Signless with explicit “extend”/“trunc” ops
    • Rationale: Same as “arith” and LLVM. Lowering from frontends is a good place to resolve this out from the types (where it tends to start in the actual user code) into the ops.
  3. (Weak opinion). Use the combinator approach with an appropriate orthogonal “scalar opset” (could be in same dialect)
    • Rationale: The combinator approach has proven to not really be a big problem in the linalg backend, and even allows some things to be done super nicely. E.g. createLinalgPayloadCalculationForElementwiseOp really cleanly handles unary, binary, ternary ops including ones with random additional “scalar” modifiers and all the dtype conversion/promotion semantics.
  4. (unsure). My gut is that it’s superfluous at the Torch-MLIR backend level of abstraction, but there’s a certain elegance and consistency to it that I really like. Note: We usually need to reify the output shape (which is what you need to materialize the outs init tensor) in the case of dynamic shapes anyway (to do the error checks), so in general the init tensors aren’t too much of a burden. But frontend direction seems to be leaning towards frontends somehow ensuring that we (Torch-MLIR) won’t need to emit error checks ourselves :person_shrugging:
  5. (Moderately strong opinion): This is something that I don’t think any of the three backends do “awesome” right now. What I want as a frontend is a handful of files I can drop into my project (FooOps.td, FooOps.cpp, etc.) which I can update “at my leisure” (say ~3 months) and which allow me to produce the Foo stable binary artifact. I think all three backends can be evolved towards this from different angles (even linalg). However, it’s not clear if this is at odds with any other requirements like enabling certain transformations to be written on the IR.
  6. (Moderately strong opinion): It should “just work” for all the ops that the frontend naturally has, no matter how painful it is to support, even if this leads to various internal lowerings/layerings further down the stack (ideally this can be made pretty composable/progressive so not ultimately that painful).
    • Rationale: I don’t want it to be “Torch-MLIR’s job” to decide how to classify, implement, orthogonalize, or layer the efficient, hardware-specific lowering of sort, fft, topk, scatter-with-repeated-indices, qr decomposition, cumsum, embedding bag, “things with >4 dimensions”, “things with data-dependent dimension-sizes” (nonzero, unique), quantized-softmax-without-messing-up-the-final-fusion, etc. There needs to be a really easy path for us to lower all of these ops into something that doesn’t lose information, even if that means there is “more work to do” lower in the stack where more target information is available.
  7. (Moderately strong opinion): Don’t “force everything to be a tensor”. You already need “true scalars” for my strong opinion above on the handling of dynamic shapes. In Torch-MLIR, it would be information loss to lower !torch.bool to tensor<i1>, since it would require “raising” later to recognize that it is a “host scalar” rather than a “device scalar” (see the sketch after this list).
  8. (weak opinion) Have your own tensor type.
    • Rationale: We’ve done this in Torch-MLIR (link) and it paid huge dividends for our core computational data type to be something we fully control (it’s not that much code either). I’ve seen similar stories on the IREE side as well. MLIR’s type conversion infra is a known quantity these days (see my talk) so there’s no reason not to. And in places where this hasn’t been done like Linalg, TOSA, and MHLO I’ve seen it cause problems, though it seems “survivable”. E.g. the “abstract tensor” situation in linalg or off-label use of signless types in TOSA and MHLO. And I’m sure there are numerous latent compiler crashes from feeding !builtin.tensor<!my.random_type> to various passes – why even allow it?
14 Likes

Thanks for the detailed writeup @_sean_silva! Couple of questions below:

I could be missing some nuance (CC @herhut @matthias-springer) but I assume this means when lowering from Torch-MLIR we’ll need to make a “best guess” for where the outputs of every operation should go. For operations like scatter or gemm_with_bias_add this output tensor has a natural candidate but simpler operations like reductions or cwise operations don’t – they will require an analysis conceptually similar to bufferization to make this determination. Is my understanding correct?

Is this because you want to store graphs long term or because you don’t want to have to continually adapt Torch-MLIR?

@_sean_silva Thanks for those excellent points. It is very useful to know these design points from the perspective of Torch-MLIR. I, personally, strongly align with several of these.

This is one I would like to debate more on. While it is clear that the translation from signedness in types to signedness in ops should happen at some point in the stack, it is not clear if doing it immediately after the frontend is the right place. At the level of ops that we are proposing here, it seems more prudent to have signedness in the types.
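
To make the two options concrete, here is a hypothetical sketch (the hl.widen op is invented purely for illustration; only the arith ops are real): with signedness in the types, a single op name covers both cases and the types disambiguate, whereas with signless types the op itself carries the signedness.

```mlir
// Signedness in the types (hypothetical "hl" dialect, shown only in comments):
//   %s = "hl.widen"(%x) : (tensor<4xsi8>) -> tensor<4xsi32>
//   %u = "hl.widen"(%y) : (tensor<4xui8>) -> tensor<4xui32>

// Signedness in the ops (arith/LLVM style): i8 is signless, and the choice of
// extension op decides how to interpret it.
func.func @widen(%x: tensor<4xi8>) -> (tensor<4xi32>, tensor<4xi32>) {
  %signed   = arith.extsi %x : tensor<4xi8> to tensor<4xi32>
  %unsigned = arith.extui %x : tensor<4xi8> to tensor<4xi32>
  return %signed, %unsigned : tensor<4xi32>, tensor<4xi32>
}
```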

+1 Strongly agree with this. But even if we strive for this, there might still be some frontend ops that don’t map to any of the ops in this dialect. The goal should be to minimize this set of frontend ops as much as possible.

Thanks, Sean, for such a thorough recounting of learnings from the torch side. When we started that, I was hoping that it would be successful in staking out a contrasting design space.

One of the things I keep hearing is the desire for people to be doing higher level algebraic or dataflow-level transformations. The lower we go in the stack, the harder such transformations become (because they have had certain concepts lowered into forms that are comparatively hard to match and make “large” changes to). But conversely, because the funnel is tighter at the low level, if it can be done there, it can save having to do it on N frontends. I see that tension play out basically daily.

Linalg isn’t great for this: all of the explicitness you mention in its design makes it hard on this axis. I wouldn’t say I’ve really “enjoyed” using MHLO for such tasks. The TF dialects get used for this extensively. And on the Jax side, the highest impact transformations are actually happening at the JAXPR level (imo). When I look at the Torch dialect, I see it as sharing the same potential, but given how young it is, I’m not sure how much is being done here (i.e. to my knowledge, no one has written “optimizers” or distribution transforms at the torch dialect level yet).

Are there thoughts on where these tasks live in the layer stack of dialects? Is this something that we should expect a reduction dialect like MHLO/TOSA to be good at, or should we just be driving these workloads to the framework level dialects and accept that we live in a fragmented world and there will be duplication? And that the ability for the framework level to service its users in this way does play in to “how good” it is at doing its job?

I ask because if we factor those requirements out of the common dialects we are discussing, then what is left is purely lowering and integration – and that is a differently constrained problem that may be easier to solve.

3 Likes

@stellaraccident, this is a good summary of our thought process as well. We’re looking at it from a transformation point of view, and linalg really gets too low level to be efficient in the search for patterns that would be easier, and often trivial, to find in a higher-level pass.

Not that linalg is bad for transforms, it isn’t. It’s great for many things. But there are two main things that get in the way:

  1. Linalg isn’t complete (by design): A single high-level op can be lowered into a complex set of linalg + something patterns, which makes understanding the semantics much harder than just looking at the ops, and is also potentially destructive, because we’d have to know what that something is (which can be different across front-ends), and now you have a much larger problem space.
  2. Other passes operate on that something (and potentially linalg): Between lowering and our transformation passes (which may not be contiguous), there can be a number of other passes and cleanups that can destroy the canonical forms, making it harder (or impossible) for us to find them.

All that is well known and one of the core reasons for MLIR’s design, and for that reason I believe we need a higher-level dialect to operate on. We can (and probably will) have multiple passes on that high-level dialect, and on linalg, etc., so this isn’t a replacement, it’s an addition.

This fragmentation is extremely counterproductive.

If you want to focus on your front-end and lower it to compete with other front-ends, you’ll have to write some passes of your own to optimise, because it’s your own dialect. But other front-ends, doing basically the same things, will have to do the same because their dialect is slightly different.

Furthermore, teams looking at just optimisation passes will have to implement different passes for different dialects (and the different forms they lower to linalg+something), which means most people just pick a front-end / back-end pair and optimise for that.

The whole idea of MLIR’s power to allow those teams to work together vanishes. We’re stuck in a rut again.

To me, having a high-level dialect that all front-ends lower to is fundamental for MLIR to become the go-to platform for ML/HPC.

Moreover, the forms that they lower to need to be canonical (or made canonical by a pass), and every further lowering should canonicalise before passing it down to another tool, so that we can all work with a smaller set of problems and do our jobs more efficiently.

Of course, that’s easier said than done. We’ve been trying to do that in LLVM IR for a long time and it’s NP-hard. But each front-end lowering to its own dialect is going in the wrong direction.

That said, front-ends can have their own dialects, as long as they convert to a canonical high-level dialect at the end, before handing it in to the next tool / optimising pipeline. So, if we come up with a generic high-level dialect, it doesn’t mean all front-ends have to change to it, just that they need to provide a lowering framework to it.

[Note: for the record, I’m not looking for de-jure canonical forms + dialects, just de-facto ones. Front-ends can do what they want, but if they want pass X to optimise it, they need to lower to canonical form Y]

2 Likes

I totally agree with @rengolin’s point.

It is not easy, but this should be the ambition and the shared effort required to create a common pre-competitive asset.

Destination passing style, in its most trivial and unoptimized form, will just create a new destination for every high-level operation. So you do not need to be clever or do any kind of analysis (but of course you can, and your decisions will be carried forward). What we gain is a uniform way to express shape computations of results (as @_sean_silva also mentioned). It is also helpful when doing tiling, as it allows us to express access patterns to the overall result by tiled computations while still using a tensor-based representation.
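
As a rough sketch of the tiling point (my own example, not from the post): a tiled computation can carve its destination out of the overall result with tensor.extract_slice, compute into it through its outs operand, and stitch it back with tensor.insert_slice, all while staying on tensors.

```mlir
func.func @tiled_update(%dest: tensor<?x?xf32>,
                        %lhs_tile: tensor<?x?xf32>, %rhs_tile: tensor<?x?xf32>,
                        %i: index, %ts: index, %n: index) -> tensor<?x?xf32> {
  // Carve the tile's destination out of the overall result tensor.
  %tile_init = tensor.extract_slice %dest[%i, 0] [%ts, %n] [1, 1]
      : tensor<?x?xf32> to tensor<?x?xf32>
  // The tiled computation writes into that slice through its outs operand.
  %tile = linalg.matmul
      ins(%lhs_tile, %rhs_tile : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%tile_init : tensor<?x?xf32>) -> tensor<?x?xf32>
  // Stitch the computed tile back into the overall result, still on tensors.
  %updated = tensor.insert_slice %tile into %dest[%i, 0] [%ts, %n] [1, 1]
      : tensor<?x?xf32> into tensor<?x?xf32>
  return %updated : tensor<?x?xf32>
}
```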

Once in destination passing style (and accompanied by an interface), we can then go about optimizing things at the tensor level independent of operations themselves.

Destination passing style itself is an old idea; I found this paper a nice read in the context of array programming that describes the allocation + shape functions aspect well.

(Just so you understand my approach in asking, I don’t actually have a strong opinion on this, but I was putting a strong position out there so that we enunciate the need, as you are doing. Canonical high level forms are really hard, and may not be possible, so I want us to be clear on the need/desire here)

This is also a concern we have been discussing. Having high-level operations in dialects at the HLO level of abstraction makes it easier to identify patterns e.g. to decide whether we want to map them to a library. On the other hand, linalg provides a nice uniform interface to do cost analysis on, as it encodes access patterns directly (for memory cost) and has a body that we could use to reason about compute cost. If we want to do the same at the HLO level, we would need to encode the same information, likely via interfaces for access pattern and compute cost.

My current thinking is that we will need those interfaces anyway, for the operations that we cannot map to linalg, and that we’ll want to make the fusion/library decisions at the HLO level before we lower to linalg.

A shared high-level dialect certainly would simplify things a lot and would provide a common starting ground. It still needs to remain an extensible system, though. Staying with my example of fusion: While I believe we could build the infrastructure for an extensible fusion system (using cost models etc.) based solely on interfaces upstream, it will need to be tailored by different compiler implementations and backends, as at least the cost models will differ.

1 Like

Absolutely, I get it. I didn’t mean to propose a perfectly canonical high-level dialect, as I agree, it’s probably impossible from any practical standpoint. My points on it being NP-hard and reaching de-facto agreements were an allusion to that.

What I do mean is that not even trying to re-use the common parts is the wrong direction. This is what I think such a new high-level dialect can be: at the very least a (useful) intersection of existing dialects, but hopefully, a dialect that other dialects can be lowered to (minus special stuff, that goes straight to linalg+scf+std or code-gen).

Otherwise, we’ll continue to see frameworks like TF, XLA, PyTorch, IREE, PlaidML all being special purpose, company-specific, fragmenting the available tools, and making a mess for researchers and developers trying to just run some models efficiently on their hardware.

1 Like