RFC: Split Elementwise trait, or create AbstractElementwise?

While working with Tosa transformations, we’ve come across situations in which we’d like to know whether an op is element-wise, but can’t use the existing Elementwise trait, because we’re still at a level at which operand shapes may differ due to implicit broadcasting. We may even have scalars that are later expanded to tensors, or unranked tensors whose shapes are inferred later.

We’d like to create an AbstractElementwise trait, which indicates that an op works in an element-wise manner upon its eventual operands. I prototyped a couple of ways to do it:

  • Add a simple trait in mlir/IR/OpBase.td and mlir/Dialect/Traits.h, as a non-core but generally available trait. As a variant, we could make it Tosa-specific, but it seems it would be useful elsewhere.

  • Split the Elementwise trait into AbstractElementwise and CompatibleShapes, with the latter acquiring the verification function. There are some static_asserts in mlir/IR/OpDefinition.h that I could use some help with, but Elementwise should continue to behave as it does now.

Let me know which way you prefer, or if there is a better one.


Thanks for posting! Would you happen to have some IR for these cases that we could look at?

https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir contains some examples:

  • ranked vs unranked:
    %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>

  • broadcast dimension:
    %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>

  • shaped vs scalar:
    %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>

A couple of the tests are named test_binary_scalar_f32 and test_binary_broadcast_f32.
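To make the distinction concrete, here is a minimal, self-contained C++ sketch that compares a strict "same shape" rule, in the spirit of what Elementwise enforces, with NumPy-style broadcast compatibility. The `Shape` type, `kDynamic`, and both functions are hypothetical illustrations; they are not MLIR’s `ShapedType` API and do not reproduce the actual `verifyElementwise()` logic.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical shape model, for illustration only (not MLIR's ShapedType API):
// std::nullopt models an unranked tensor (tensor<*xf32>) and kDynamic an
// unknown dimension (the "?" in tensor<?xf32>).
using Shape = std::optional<std::vector<int64_t>>;
constexpr int64_t kDynamic = -1;

static bool isValidDim(int64_t d) { return d >= 0 || d == kDynamic; }

// Strict rule in the spirit of Elementwise: same rank, and each dimension
// pair equal unless one side is unknown. Unranked shapes are not rejected,
// since nothing is known about them yet.
bool sameShapeCompatible(const Shape &a, const Shape &b) {
  if (!a || !b) return true;
  if (a->size() != b->size()) return false;
  for (std::size_t i = 0; i < a->size(); ++i) {
    int64_t x = (*a)[i], y = (*b)[i];
    assert(isValidDim(x) && isValidDim(y));
    if (x != kDynamic && y != kDynamic && x != y) return false;
  }
  return true;
}

// NumPy-style broadcast rule: align trailing dimensions; each pair must be
// equal, contain a 1, or contain an unknown dimension. Leading dimensions of
// the higher-rank shape are unconstrained.
bool broadcastCompatible(const Shape &a, const Shape &b) {
  if (!a || !b) return true;
  const std::size_t na = a->size(), nb = b->size();
  const std::size_t n = na < nb ? na : nb;
  for (std::size_t i = 0; i < n; ++i) {
    int64_t x = (*a)[na - 1 - i], y = (*b)[nb - 1 - i];
    assert(isValidDim(x) && isValidDim(y));
    if (x == kDynamic || y == kDynamic || x == 1 || y == 1 || x == y) continue;
    return false;
  }
  return true;
}
```

Under this model, tensor<4xf32> vs tensor<1xf32> (the broadcast case) fails the strict rule but passes the broadcast rule, and so would the scalar case once %0 is inferred as, say, tensor<4xf32> and paired with tensor<f32>.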

A third way: make the question of whether Elementwise insists on matching shapes a characteristic of the dialect. That is, add something like allowsBroadcasts() to Dialect, defaulting to false, override it in TosaDialect, and then skip verifyElementwise() if op->getDialect()->allowsBroadcasts() returns true.
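The third option could look roughly like the following C++ mock. `DialectModel`, `TosaDialectModel` and `verifyElementwiseModel` are hypothetical stand-ins for `mlir::Dialect`, `TosaDialect` and `verifyElementwise()`; `allowsBroadcasts()` is the hook proposed above, not an existing MLIR API.

```cpp
// Hypothetical stand-ins for mlir::Dialect and the Elementwise verifier; the
// allowsBroadcasts() hook is the proposal above, not an existing MLIR API.
struct DialectModel {
  virtual ~DialectModel() = default;
  // Default: Elementwise ops in this dialect must have matching shapes.
  virtual bool allowsBroadcasts() const { return false; }
};

struct TosaDialectModel : DialectModel {
  // TOSA ops broadcast implicitly, so opt out of the strict shape check.
  bool allowsBroadcasts() const override { return true; }
};

// Stand-in for verifyElementwise(): `shapesMatch` is whatever the existing
// strict check would report for the op.
bool verifyElementwiseModel(const DialectModel &dialect, bool shapesMatch) {
  if (dialect.allowsBroadcasts())
    return true;  // skip the matching-shapes check entirely
  return shapesMatch;
}
```

One consequence of keying this on the dialect is that it is all-or-nothing: every Elementwise op in the dialect skips the check, which is coarser than the per-dimension information discussed later in the thread.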

Sorry to pop this back up to the top question, but could you elaborate more on the types of transforms you are doing which would benefit from such an indication?

There is a lot of history with respect to broadcasting decisions and I’ve generally found that it is important to dig a bit deeper when talking about needs here.

There are two areas that, to my knowledge, are underspecified in the TOSA dialect (not the specification, but the implementation in MLIR):

  • What level of shape polymorphism is allowed for unranked tensors. AFAIK, these exist purely as a convenience for interop with contexts in which ranks are not yet known, so that MLIR-based shape inference can specialize them.
  • Which unknown dimensions of a tensor can be considered polymorphic over both a concrete extent and an expansion dim (a 1-dim paired with a non-1 dim in the other operand).

For the second, any further information needs to be at a dimension level, not just a bit on the op.
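As a rough illustration of why this information lives at the dimension level, the following hypothetical C++ sketch classifies one aligned dimension pair of a broadcasting binary op. `DimRelation`, `classify()` and `kDynamic` are illustrative names only; `kDynamic` stands for an unknown extent.

```cpp
#include <cstdint>

constexpr int64_t kDynamic = -1;  // hypothetical marker for an unknown dim

// What can be concluded about one aligned dimension pair (lhs, rhs) of a
// broadcasting binary op before shapes are fully known.
enum class DimRelation {
  Equal,        // both sides have the same extent; no broadcast
  ExpandLhs,    // lhs is 1 and is expanded to rhs's extent
  ExpandRhs,    // rhs is 1 and is expanded to lhs's extent
  Ambiguous,    // an unknown dim could be a matching extent or a 1 to expand
  Incompatible  // two different static extents, neither of them 1
};

DimRelation classify(int64_t lhs, int64_t rhs) {
  if (lhs == kDynamic || rhs == kDynamic) {
    // ?, ?     : could be equal, or either side could be 1.
    // ?, N>1   : ? may be N (Equal) or 1 (ExpandLhs).
    // ?, 1     : ? may be 1 (Equal) or any other extent (ExpandRhs).
    return DimRelation::Ambiguous;
  }
  if (lhs == rhs) return DimRelation::Equal;
  if (lhs == 1) return DimRelation::ExpandLhs;
  if (rhs == 1) return DimRelation::ExpandRhs;
  return DimRelation::Incompatible;
}
```

Only the Ambiguous case depends on facts not yet known, and it can arise independently in each dimension of the same op.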

In general, I believe more design work is needed on this dialect to comprehensively support dynamic shapes, and I’m a bit hesitant to add partial fixes such as this trait without understanding more of the motivation, which I’m not quite grasping.

We’re partitioning a function in a way that gathers a convolution op and its adjacent element-wise ops, anticipating future fusion. I will admit that I’m not familiar with the details after that point, but if we assume a correct program, won’t the shapes have to conform eventually?

For our purposes, we just want to know this characteristic of the op, but it’s too early to have the shapes expanded. I suppose we could also go with a big isa<>, but that’s what ElementwiseMappable was created to avoid.

I came across this thread rather late, not having caught the TOSA reference sooner. Interesting discussion topic indeed.

We’re completely fine with the dialect being enhanced as suitable; our own focus has been on keeping the dialect functionally aligned to the spec, with less emphasis on maintaining its integration with compiler artifacts. On the latter, we welcome input that enables the right constructs.

Shape polymorphism rules for unranked tensors are a tricky topic. I see that the tosa-make-broadcastable pass (which was contributed with the dialect and implements NumPy-aligned rank broadcasting) makes no attempt to process the shaped vs scalar example, but readily performs rank broadcasting if the problem is reformulated as a simpler one with an unknown dim size.
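NumPy-aligned rank broadcasting can be sketched as left-padding the lower-rank shape with 1s. This is a hypothetical C++ illustration of the rule, not the tosa-make-broadcastable pass’s implementation; `Shape` and `alignRanks()` are assumed names.

```cpp
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical shape model: std::nullopt means unranked (tensor<*xf32>).
using Shape = std::optional<std::vector<int64_t>>;

// NumPy-style rank alignment: left-pad the lower-rank shape with 1s so both
// operands have the same rank. Returns std::nullopt if either rank is
// unknown, since the amount of padding cannot be computed.
std::optional<std::pair<std::vector<int64_t>, std::vector<int64_t>>>
alignRanks(const Shape &a, const Shape &b) {
  if (!a || !b) return std::nullopt;
  std::vector<int64_t> x = *a, y = *b;
  if (x.size() < y.size())
    x.insert(x.begin(), y.size() - x.size(), int64_t{1});
  else if (y.size() < x.size())
    y.insert(y.begin(), x.size() - y.size(), int64_t{1});
  return std::make_pair(x, y);
}
```

With an unranked operand the padding amount is unknown, so the sketch returns std::nullopt; this is consistent with the pass leaving the shaped vs scalar example (tensor<*xf32> vs tensor<f32>) alone, whereas a ranked tensor<?xf32> operand could still be aligned.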

It seems there are two questions here: (a) whether the op is elementwise, and (b) which broadcasting forms are allowed in relation to unranked tensors.

I believe that ElementwiseMappable was created for a very narrow purpose: to note that an op could operate on either a scalar or a vector (incl. tensor) in a strictly element-by-element way. IIRC, even that degree of freedom was not without controversy. Bringing broadcasting and shape polymorphism into it is quite a bit further still.

What you may be reaching for here is an op interface (possibly TOSA-specific) so that you can reason abstractly about an entire set of ops.

What I’m not clear on is exactly what this interface would be named or do. When coming to that point, it is often useful to start with the switchy version, or with a helper that is pretty verbose but does what needs to be done. Then look at that and ask whether it should be an interface. Doing it this way has a way of making it concrete. Also, your viewpoint often refines as you work with it, which is hard to do if you jump prematurely to a trait or interface for something that is still evolving.
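A "switchy" first version might look like the hypothetical helper below. Real code would more likely use `llvm::isa<>` on the TOSA op classes; matching on op-name strings is only a device to keep the sketch self-contained, and the list covers just the ops mentioned in this thread.

```cpp
#include <string_view>

// A deliberately verbose, "switchy" helper. Real code would test op classes
// with llvm::isa<...>; string op names are used here only to keep the sketch
// self-contained. The list is illustrative and incomplete: it contains only
// the ops mentioned in this thread.
bool isTosaElementwiseByName(std::string_view opName) {
  // Unary ops: one result element per input element.
  if (opName == "tosa.log") return true;
  // Binary ops: index-aligned, but operands may broadcast.
  if (opName == "tosa.add" || opName == "tosa.sub") return true;
  // Everything else (e.g. convolutions) is treated as not elementwise.
  return false;
}
```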

My two cents.

If you need to reason about this kind of behavior, have you considered something like Linalg? There are lowerings from TOSA to Linalg, and there you don’t need traits; an analysis of the operation itself would probably give you the information you are looking for (sorry for being vague, but I am not clear enough about the transformation you want to do at the TOSA level to comment more).

I apologize for coming to this late as well. I am working with @pcf.

What we really want to reason about is how a set of ops produce results. For many(, many) ops, each result element is computed independently, based on the corresponding (index-aligned) primary input element. Thus for unary ops it aligns with what I understand as ‘Elementwise’. The wrinkle is binary ops, where a secondary input is also index-aligned except along broadcast dimensions.
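The index alignment described here can be sketched as a mapping from an output index to the secondary input’s index, pinning each broadcast (size-1) dimension to 0. This is a hypothetical C++ illustration that assumes static shapes already rank-aligned with the output; `inputIndexFor()` is not an existing API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// For an index-aligned op with a broadcasting secondary input, the input
// element feeding output element `outIdx` is found by pinning every
// broadcast (size-1) dimension of the input to 0. Assumes the input has
// already been rank-aligned with the output (same rank, static shapes).
std::vector<int64_t> inputIndexFor(const std::vector<int64_t> &outIdx,
                                   const std::vector<int64_t> &inShape) {
  assert(outIdx.size() == inShape.size());
  std::vector<int64_t> inIdx(outIdx.size());
  for (std::size_t d = 0; d < outIdx.size(); ++d)
    inIdx[d] = inShape[d] == 1 ? 0 : outIdx[d];
  return inIdx;
}
```

For example, with an input of shape 1x8, output element [2, 5] reads input element [0, 5]; without the broadcast dimension, the mapping is the identity, which is the plain Elementwise case.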

I would have preferred two attributes/traits: Elementwise for general index-aligned ops (accepting broadcasting inputs), and InputBroadcastable for that special case. But we cannot overload Elementwise, since it requires equivalent dimensionality, and so we proposed this alternative.

Hi folks. Just echoing what Stella said: I don’t think we need to generalize Elementwise here (I’m the one who added that trait). Either a big isa<> or a TOSA-local “Elementwise” trait is all that is needed for now, given the subtleties of defining the broadcasting behavior. Those subtleties make a more abstract elementwise trait less useful across a wide variety of transformations.