RFC: remove arith/math ops on tensors

I propose to remove the automatic elementwise extension of arith and math ops to tensors. Equivalent operations can be represented as named Linalg ops, or as TOSA ops if/when needed. Having these in arith needlessly complicates the dialect itself and its lowering schemes. Conceptually, the tensor forms are both more complicated than elementwise application to vectors and insufficiently expressive: Linalg generics can also represent broadcasts and permutations, which may be desirable for tensor-level operations.
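For illustration, here is roughly what the two spellings look like today (the shapes and value names are just an example):

```mlir
// Elementwise addition spelled directly with arith on tensors (what this RFC
// proposes to remove).
%0 = arith.addf %a, %b : tensor<4x8xf32>

// Roughly the same computation as a named Linalg op, in destination-passing
// style.
%empty = tensor.empty() : tensor<4x8xf32>
%1 = linalg.add ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
                outs(%empty : tensor<4x8xf32>) -> tensor<4x8xf32>
```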

4 Likes

Why not move the tensor functionality of these to the tensor dialect?

Significantly more work with no clear use case. We already have linalg.add, linalg.mul, etc. as well as tosa.add, tosa.mul, etc. Even more duplication doesn’t look warranted.

So addition on vectors would remain where it is while addition on tensors would be moved? What complicates the dialect and lowering schemes for tensors but not for vectors?

Vector addition doesn’t need to be bufferized or converted to another dialect to become lowerable to LLVM or SPIR-V. Vector types only have static shapes, so there is always the naive “unroll-to-scalar” approach to lowering, which works systematically for vectors but doesn’t work for tensors, which can have dynamic shapes, dynamic rank, and sparse layouts. We can also have tensors-of-vectors (yikes), but not vectors-of-vectors.
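To make the contrast concrete (the types here are just examples):

```mlir
// Static shape: a vector add can always be unrolled or mapped directly to an
// LLVM/SPIR-V vector add.
%v = arith.addf %x, %y : vector<4xf32>

// Dynamic shape: a tensor add has no such fallback; it has to be bufferized
// or rewritten into loops before it can reach LLVM or SPIR-V.
%t = arith.addf %a, %b : tensor<?x?xf32>
```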

Tensors are just a different abstraction, although they are mechanically similar to vectors, and we do different things on them. Having to care about another abstraction in the arith dialect adds complexity.

2 Likes

Note that that’s not the case for scalable vectors: scalable dimensions are very similar to dynamic dimensions in tensors. We have actually been replacing plain “unrolling” with scf.for loops for this very reason.

Strong -1 on my side. We do use this in Triton, and replacing it with linalg is not possible as linalg uses destination-passing style.
I don’t fully understand how this adds complexity in the dialect. I see how this is hard for bufferization, but bufferization doesn’t have to handle the arith dialect if it doesn’t want to.
How much code would that remove from the arith dialect itself?

1 Like

I am not totally convinced here, mostly because the alternatives are pretty load-bearing right now: TOSA and Linalg aren’t really designed at their core to provide what arith provides on tensors (linalg is a pretty big hammer when all you need is an add!).

On the other hand we’re getting closer to a state where arith is redundant with LLVM (I should follow up on [RFC] arithmetic vs llvm dialect …), and only multi-dimensional vectors would remain as a difference here.

So I agree that the current state is mixing two levels of abstraction in the same dialect, and it’s pretty undesirable, but without a simple “tensor arith” dialect, I am a bit concerned about the impact on the ecosystem (do we expect users like Triton to take a dependency on linalg just to express basic arithmetic on tensors?).

This isn’t as much about the dialect as it is about any user of the dialect (even though I’d have to check canonicalizations and folders). It makes any dialect conversion more complex because you can’t write code that assumes all you get as input is scalar. Any notion of “legality” becomes dynamic: you can’t just say “arith is legal”, or “I support arith”, or “my op conversion can handle arith.add”; it always needs to check whether it is in the tensor domain or not.
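A small sketch of what I mean (the types are illustrative):

```mlir
// Both ops below are "arith.addf". A conversion that declares arith legal (or
// illegal) still has to inspect the operand types to know which case it is in.
%s = arith.addf %x, %y : f32              // plain scalar, trivially lowerable
%t = arith.addf %a, %b : tensor<?x?xf32>  // needs bufferization / another path
```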

1 Like

Right, I wonder if this is also true for N-D vectors, but I see your point. I still think there are no reasonable alternatives, and that removal would cause projects like Triton (and I expect Triton not to be the only one) to fork this dialect downstream.

Minor note: I recently came across the fact that torch-mlir is still using this, but I think it is at the level of an implementation detail rather than load-bearing. Would need to look more before removing.

FWIW, I agree with the proposal and I think Alex has made the argument very well. That said, I can only comment from an architecture/layering perspective - I’m not invested in the arith/math dialects and don’t have practical knowledge about what impact the change would have on existing clients.

-Chris

I agree as well – we made a mistake when we did it this way and it is an unfortunate coupling of concepts.

If this is really being relied on in a load bearing way, it would be good to figure out how to meet the need in a better way, imo.

(We can fix the torch-mlir usage. Just not sure when/how urgent)

I’m not sure I understand why using arith ops with tensor types cannot be a legitimate use case. It sounds like people assume those kinds of operations are meant to represent graph-level tensor operations, but they don’t have to be, and that’s not how Triton uses them.

Triton uses tensors to represent immutable data that is later distributed to registers or memory, and being able to represent operations working on scalars or tensors with the same dialect is critical to keep things simple. I understand this is different from what linalg or torch-mlir does, but I don’t understand why this is a model we would not want to support.

In the past I had hoped things would actually move in the other direction, where elementwise arith ops would work on any type rather than being limited to vector/tensor/scalar. This was something we ran into when plumbing WMMA operations for Vulkan, where we wanted to be able to apply elementwise arith operations to the gpu WMMA type. Another idea that came up recently was to be able to use this dialect with parametric vector types (that would potentially live downstream), as described by a Modular engineer at the LLVM conference.

I think that my thinking on this derives from the upstream tensor type meaning many things in practice depending on context, while the semantics it is given in arith and math cover a very narrow subset of that, so much so as to seem somewhat arbitrary. “Why not also support broadcasting (and then of what kind), encodings, etc.?” become natural next questions. Then you get into very non-trivial and disjoint lowerings…

I’ve held that perspective since we defined it this way, and it has always seemed ad hoc and under-defined. It is possible that the situation has evolved practically, though, I guess.

I’m not going to claim to understand the ways that scalable vectors change the vector type, but leaving that aside, for that type the semantics are well defined.

I completely buy that there is utility in having ops to do what Triton is doing on tensors. I’m just not sure they are these ops…

I don’t think folks are against this as a use case; we’re rather considering whether this use case should be conflated with the “low-level” LLVM-style arith (which, for my part, is something I would rather merge into an “LLVM arith dialect”).

Can you expand on this, actually? I don’t quite get why operating on scalars with the same op really helps you. (Also, other representations may use rank-0 tensors instead of scalars to preserve the uniformity of the IR on tensors.)

There is a tradeoff in complexity (as I mentioned in my previous answer to you, where I provided the cons), so if you can detail the opposite view it’d be nice.

Because it allows writing transformations that work on some ops independently of the type (a random example from Triton). This is the same reason why you wouldn’t want a separate dialect for vector arithmetic. If we end up with a duplicated arith dialect just for tensors, then you would need to templatise the patterns to work with both kinds of ops.
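For example (just a sketch, the types and values are illustrative), a rewrite or fold like x + 0 -> x can be written once against arith.addi and match both forms, because it is the same op:

```mlir
// Scalar form: x + 0 folds to x.
%zero = arith.constant 0 : i32
%0 = arith.addi %x, %zero : i32

// Tensor form: the same pattern applies without any extra code in the
// transformation, since the op (and its folder) is shared.
%zeros = arith.constant dense<0> : tensor<16xi32>
%1 = arith.addi %t, %zeros : tensor<16xi32>
```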

I’m especially concerned about having to stop using the arith dialect, because other dialects like SCF depend on arith for transformations.

1 Like

So if I understand your point, the problem is that the semantics are not clearly defined? For elementwise operations the semantics seem fairly straightforward, right? The op returns a new tensor with the same shape, where each element is computed the way the equivalent scalar op would compute it. And anything beyond elementwise should live in the tensor dialect.

I completely buy that there is utility in having ops to do what Triton is doing on tensors. I’m just not sure they are these ops…

What makes you think that?

… And they represent disjoint lowering paths to LLVM. Arith is currently playing double duty, being both a low-level, LLVM-aligned dialect and a mathy dialect that operates on abstract data types that have no defined representation (or relationship with memory). I don’t think anyone is saying that these abstract things shouldn’t be represented. Just that maybe these are not the same things.

It is just a consequence of my argument and rhetorical: if these cases are not the same and both are valid, it would imply that we are missing something if we are forcing them into the same representation. I didn’t mean it as a pronouncement.

Does the comment upthread about linalg.generic being able to express this have any merit? I agree that it is a very large and specific hammer. But what if we had something like tensor.map_elementwise, which was not DPS and was represented more like some of the sugared reducers in high-level tensor opsets (i.e. it yields the result of a binary op valid for its element type and has a clever printer/parser for compact usage)? That would get us out of the weird zone where you have to swallow linalg/DPS or TOSA (which could do it but is quite far in terms of applicability and consistency with the existing low-level ops).
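For reference, this is roughly what the linalg.generic spelling of a plain elementwise add looks like today (shapes are illustrative), which is why it feels like such a big hammer; a hypothetical tensor.map_elementwise would presumably keep little more than the region and drop the maps, iterators, and DPS init:

```mlir
#id = affine_map<(d0, d1) -> (d0, d1)>

// Elementwise add of two 4x8 tensors expressed as a linalg.generic.
%empty = tensor.empty() : tensor<4x8xf32>
%0 = linalg.generic
       {indexing_maps = [#id, #id, #id],
        iterator_types = [#linalg.iterator_type<parallel>,
                          #linalg.iterator_type<parallel>]}
       ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
       outs(%empty : tensor<4x8xf32>) {
^bb0(%lhs: f32, %rhs: f32, %out: f32):
  %sum = arith.addf %lhs, %rhs : f32
  linalg.yield %sum : f32
} -> tensor<4x8xf32>
```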

1 Like

I agree with that; the double use of Arith is what seems to be causing the disconnect. That being said, I think we should be concerned about removing one of the use cases of Arith without having a concrete alternative. @mehdi_amini has been suggesting using the LLVM dialect directly for the low-level part; I’m not sure if that addresses the concerns? I don’t have a clear vision of what it would look like.

It sounds like we agree that neither linalg nor TOSA is a practical alternative. I’m happy to discuss further what kind of dialect could be a solution, although I don’t know if this is the right next step. I would prefer not having to move more of my work downstream.

1 Like