RFC: remove arith/math ops on tensors

From my perspective, this has been in place for years, and I am not feeling a burning need to “fix” it at the expense of key, high-profile uses without a good way forward. So I don’t consider this an emergency… but it would be good to arrive at a workable plan.

I do think that we should discourage new uses of this form while the way it is done is in dispute.

2 Likes

Thanks for the example! I agree this is convenient (and this is why things were done this way originally), and this is why I was referring to a tradeoff here: how many patterns benefit from this (and could be templated if they had to match two ops?) vs how many transformations have to check which subset of the op they are actually processing? How many patterns and transformations are actually broken because they don’t handle tensors correctly?
That reminds me also that we get regular requests going in the other direction: for arith.add and others to accept “any types” so that users downstream can reuse arith with their own types (tensors or others)!
(which gets into a rabbit hole leading to whether “add” should be an op interface; a terrible idea, I’m not going there).

I know that when the std dialect was split, we explored the idea of defining “container” types that can be iterated and a dialect that works element-wise on such containers, but it’s hard to make something both widely applicable and well defined enough to be usable in practice (basically, users get into “why doesn’t it just do what I want?”, which is somewhat your case with respect to linalg or tosa…).

Okay, so there is at least a use case for this feature, good to know. My understanding was that the ex-core team was unanimously considering this a design mistake.

For context, just updating the ODS to drop support for tensors in the FloatLike and SignlessIntegerLike type constraints breaks 26 upstream tests. More than half of those are not load-bearing and only check that nothing happens when given an arith-on-tensors op. There are some canonicalizer and folder tests, but that’s it. So I would claim this is not used or properly exercised upstream.

For more context, I posted this after answering for the n-th time the question along the lines of “why doesn’t arith-to-llvm convert my addf-on-tensors to llvm? mlir is broken!”. This is the kind of complexity I’m concerned with: one has to understand that arith-on-tensors is different from arith-on-non-tensors to be able to use it. By “use” I mean being able to produce a binary, which is what most users want in the end. Ultimately, it’s a tradeoff between supporting specific downstream users (one so far) and cognitive load on the entire community. I’m not saying we should necessarily push this down to Triton forking the dialect, but I do think there is a better balance to be found between the needs of downstreams and upstream coherence.
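To make the recurring confusion concrete, here is a minimal sketch (the op and pass names are real; the point is that the two forms look identical but behave very differently):

```mlir
// The scalar form is converted by --convert-arith-to-llvm as expected.
%0 = arith.addf %a, %b : f32

// The tensor form is syntactically identical but has no lowering in
// arith-to-llvm: the op survives the pass, and the pipeline fails later.
%1 = arith.addf %t0, %t1 : tensor<4xf32>
```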

I’d like to understand more how this is used. Quickly scanning through the Triton codebase, I only see a handful of matchers that use arith ops. Most of them look like they could rather easily be updated to, e.g., isa<arith::AddFOp, tt::TensorAddFOp>, or use a couple of helpers with a TypeSwitch, which would be simpler than both heavy templating and forking the dialect.

1 Like

I don’t have a strong opinion here, but I’d like to convey my thought behind this problem.

On one hand, using arith / math tensor ops is handy in not having to generalize code (though you still need to check the type, shape, and element type every time). On the other hand, without formal semantics for tensors (dynamic shapes, unknown rank, type cast/broadcast/reduction semantics, etc.) you get “implementation-defined behavior”, and downstream implementations diverge from upstream and from each other.
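A hedged sketch of the kind of case I mean, where the current op definitions leave the behavior implementation-defined (the ops are real; the open questions are the point):

```mlir
// Elementwise on static shapes is unambiguous:
%0 = arith.addf %a, %b : tensor<4xf32>

// With dynamic shapes, nothing in the op definition says what happens
// when the runtime sizes of %c and %d differ: is it UB, a runtime
// error, or implicit broadcast? Different downstreams answer this
// question differently.
%1 = arith.addf %c, %d : tensor<?xf32>
```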

It’s not because it works for some cases that it’s safe, as @mehdi_amini alludes to above. The main decision is between a safe arith on tensors and no tensor support, not between what we have today and no tensor support. So the result of this thread should not be to do nothing.

We started adding arithmetic and maths operations to linalg to fix this the other way, which I believed was the consensus until the Triton case came up. So I want to make sure we’re not missing anything.

IIUC, Triton uses pointers and offsets as tensor representations, so the operations themselves can be thought of as acting on the “memory that those pointers refer to”, which is already odd in the usual MLIR progression, but should not harm the tensor story. Now, if there is a tensor usage that gets reduced to loops on pointers, then this is the clear case for linalg → loops + arith.
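For reference, a sketch of what the linalg-based alternative looks like (linalg.map and its shorthand form are real; the exact printed syntax may vary across MLIR versions):

```mlir
// Elementwise add expressed at the linalg level: the tensor op carries
// an explicit scalar payload, so lowering to loops + scalar arith is
// mechanical, and bufferization knows what to do with it.
%0 = linalg.map { arith.addf }
       ins(%a, %b : tensor<4xf32>, tensor<4xf32>)
       outs(%init : tensor<4xf32>)
```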

Also, the Triton transform example above actually matches on mlir::Operation, not mlir::arith::SelectOp, so you still wouldn’t need to generalize the function signatures (not that hard to do either as @ftynse suggests). The main problem there is that you have to pick the information from different APIs if they’re ops in different dialects, but that’s also not too convoluted (and you can have facade helpers to extract information).

Either way, you’ll have to support shape analysis, dynamic types, etc. whether you use tensors, memrefs or pointers and offsets, it’s just a matter of where the logic lies.

Whatever the result of this, we really should have formal semantics for shapes and casts, or downstream dialects will start to diverge more and more and it will be harder for us to fix this as time goes on.

1 Like

Correct, there is no path to produce a binary. Such a path could be added if this is the main concern, but it might look artificial. The hard part is that bufferization and distribution are highly opinionated transformations, so making the Triton ones live upstream is not easy. We could use the upstream bufferization, but then it would not be representative of any real end-to-end flow.

Of course it can be done, but having duplicated IR would still have a high cost in my opinion: we would have to make sure both IRs are always in sync, have the same semantics, folding, etc. My gut feeling is that it would negate the advantages of using the arith dialect, and forking would end up being the practical solution.

We are talking about elementwise operations: they require shapes to match for all operands and results, and each element can be computed independently. How do dynamic shapes/unknown rank/broadcast/reduction matter in this case, and what makes them unsafe?

This is the consensus for the linalg-based path, but it is significantly different from what Triton does, and I don’t think there has to be a single solution to both problems here.

That’s not an accurate description. Triton has pointers and explicit loads with memory semantics, but this is unrelated. Tensors are used to represent intermediate results that may get materialized later, either in registers or in “shared memory”, which is temporary memory disjoint from the one used by load/store with pointers.

I agree, but again this feels unrelated to supporting elementwise arith operations on tensors.

1 Like

I don’t quite get what you mean by “it can just be added”: the issue is that this is mixing two levels of abstraction. We have a pass, “arith-to-llvm”, which should be just a dialect conversion, but it intrinsically cannot support tensors; there is no way to turn this into a full pipeline.

Yeah, I agree. Arith or whatever it becomes is at a very different level of abstraction: not only are tensors undefined (except by lowering) with respect to memory, they imply allocation and deallocation for general purpose use (not saying there can’t be some special uses of them as an abstract type that are much more restricted, just what they are in generality). This is very different from what arith is or can be with respect to llvm.

I think the Triton case is an example of using the tensor type system to represent abstract expressions of some algebra and doing that without a fixed idea of what it lowers to. This is spiritually close to what TOSA and the *HLOs do. But these op sets are not practically close or well aligned, based on my read.

In contrast, linalg takes the abstract tensor type system and lowers it in a specific way with respect to an eventual memory representation and allocation policies. This is pretty different from what Triton is doing.

This conversation does convince me that we really do need to fix this, and I expect Triton (and torch-mlir and others) will have to adapt (now or eventually when arith bitrots beyond use, which is where we are headed if we don’t clean it up and align it with the rest of the system). The question is how.

The tensor dialect sits at the level of abstraction desired: it expresses operations on the purely abstract datatype prior to it being imbued with a lowering strategy (which is what linalg and even DPS are). As I suggested upthread, a new op here like elementwise_map would fill the gap and let the vocabulary of arith/math/whatever be expressed in an abstract way. Likewise, that one op in Triton would cover the case as well if we don’t think this use case needs to be expressed upstream.

This would make the IR completely unusable and unreadable, since almost every op in Triton would end up being an elementwise_map without providing any extra information except that the operands are tensors. The point of using arith is to have an IR that works the same way on both tensor and scalar types.

My suggestion is to make elementwise arith accept any type; there is a need for that, as mentioned by @mehdi_amini upthread, and the problem of arith having multiple uses can be solved by using the LLVM dialect for the low-level cases. Arith would then become a dialect that does arithmetic on abstract data independently of representation, as described by @stellaraccident. I think that makes arith much more usable.

It does provide extra information: it indicates precisely what the semantics of the map are, whereas arith just kind of randomly selects a specific subset of the tensor datatype and says that is what is supported. The next person who comes along asking why arith doesn’t support broadcasting, encodings, or all of the other random things that tensors model gets told “make a new op”, vs “well, we decided 5 years ago that arith should do this one thing, but we don’t know why and it isn’t really consistent.”

Would this really be unreadable?

%0 = tensor.elementwise_map "arith.addf" %arg1, %arg2 : tensor<1x2xf32>

So no folders, no canonicalization, and no verifiers of any kind?

(edited with an example)

The arith dialect is already quite low-level: it has the addf/addi separation, fastmath flags for float ops (and we probably also want nuw/nsw for integer ops). Extending it to support any type would just be awkward. Using the llvm dialect for this is also not ideal, as llvm is not the only possible target, as was mentioned previously in the thread. IMO, “arith on any type” belongs to the frontend-dialect realm, but it will be hard to find common ground across projects here, I’m afraid.

I don’t have a strong opinion about elementwise tensor support specifically, as we are not using it.

2 Likes

Yeah, this is correct and is representative of almost all of the code both implementing it and lowering to and from it. I’m not arguing that a more frontend-level, very abstract dialect doesn’t have a use, but I’m not going to spend time trying to reach a consensus on what such a thing is in the absence of much stronger requirements or priors. In my experience, stronger constraints, even at the frontend, almost always produce a more robust system, and even if we had such a dialect, I would probably opt to do something project specific simply because I could constrain and mold it to the need (versus relying on something hyper-abstract and ill-defined).

The one thing I can offer is that the level of abstraction being asked for is exactly where the tensor dialect sits. You can’t go lower than this point and still represent a tensor abstractly in any way. We can do any number of things at or above that level.

In my opinion, the path forward for arith and math is to finish aligning them with the llvm dialect and ultimately evaluate whether they are redundant with it and whether a better factoring can exist at that level.

(none of this is to say that we have to land a change “tomorrow” that breaks people and existing users without a viable path forward – but I agree with Renato: we have to resolve this finally)

The goal here wasn’t to use it as a high-level dialect. It would still be a low-level dialect, but there are multiple cases where using arith for a low-level representation with custom types is useful. For example, some hardware has custom floating-point types that currently cannot be used; the same problem exists with wmma-like types, where we have to use some hacky op to represent elementwise operations.

In Triton, arith is not meant to be used as a high-level dialect. It sounds like using tensors at the low level is unconventional, which is what makes Triton’s usage look odd to others.

For elementwise ops, the op doesn’t work on just a subset of the type. The argument about broadcast also applies to vectors, yet we wouldn’t want to support vector broadcasts or disallow vectors.

That sounds reasonable to me. If arith/math are not meant to exist in their current form, then all of this doesn’t matter. In the meantime, it would be great if we could avoid prematurely breaking the Triton flow.

Chatted a bit with @ThomasRaoux offline; my take right now is that conceptually Triton is closer to using multi-dimensional vectors than tensors, and this is why handling the arithmetic uniformly is useful: it is really modeling HW operations on register files with a statically known size (there are no unknown dimensions involved).

The only reason to use the tensor type instead of vector in Triton (at this level of abstraction) that I can think of is that tensor supports an “encoding” attribute.

The kind of swizzling of data happening on GPUs makes it very convenient to decouple the layout of the data in the register file (or the shared memory) from the operation to execute.
Maybe one direction to look at is whether we could add an optional layout attribute on the vector type in the future.
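To illustrate the two options (the encoding attribute on tensor is real upstream MLIR; the layout-carrying vector type below is purely hypothetical, sketching the direction suggested, and #layout is a placeholder for a dialect-specific layout attribute):

```mlir
// Today: Triton-style layouts ride on tensor's encoding attribute.
%0 = arith.addf %a, %b : tensor<16x16xf32, #layout>

// Hypothetical: the same information carried as an optional layout on
// the vector type, which does not exist upstream today.
%1 = arith.addf %c, %d : vector<16x16xf32, #layout>
```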

4 Likes

I was similarly looking at something like this recently and came to the same conclusion for my case. tensor has a lot of abstract and undefined corners that make it ill-suited for this level of programming. Of course, at some level it is just a type and can be used in various ways, but it definitely “wants” to be something quite different.

This could be interesting. Do you have a specific example? AVX-512 has a swizzling field in many instructions, but we never found a need to model that explicitly at the op level because the backend was able to do a pretty good job here. The permutation maps that we have in some ops, plus shuffle operations, have been enough so far, but maybe that’s not the case here.

The difference is that with a system like AVX, you provide LLVM with the “tile view” of the operation, because you have a register to hold the data and LLVM can reshuffle the data at will.
In comparison, on GPUs, LLVM does not have this possibility: it gets only the single-thread view, and the “tile” is actually managed implicitly by a group of “threads”. The data layout is already materialized and fixed, and not expressed as a “vector” in LLVM.
So basically, the equivalent of an AVX or AMX operation in LLVM is a tile-level operation in Triton, which requires the cooperation of multiple “threads” that are expressed individually in LLVM (SIMT model…).

For what’s involved in practice, I think the best example is Cutlass/Cute concepts which illustrate the way data swizzling is decoupled from the thread mapping, maybe this video is good intro: https://www.youtube.com/watch?v=PWWOGrLZtZg

2 Likes

@harsh-nod ^

I believe that the proposal would also help reduce complexity in the sparse tensor dialect: sparsification is largely built around linalg ops, so we need to rewrite the elementwise arith/math ops on sparse tensors into linalg ops before sparsification anyway.

1 Like

strong +1 to this

+2