The arithmetic dialect supports elementwise operations on tensors. However, arithmetic operations with tensor operands are not bufferizable: if the IR still contains arithmetic operations with tensor operands when the one-shot-bufferization pass runs, the pass fails.
Proposal
We would like to contribute a pass that replaces arith ops on tensors with arith ops on scalars embedded in the body of a linalg::GenericOp. If this pass runs before the one-shot-bufferization pass, the arith ops on tensors are lowered to linalg ops, whose tensor operands can then be bufferized.
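To make the proposed rewrite concrete, here is a minimal Python model of what the replacement computes. It is not MLIR code and not the prototype itself: a tensor-level op such as `arith.addf` on `tensor<2x3xf32>` becomes a loop over the iteration space whose body applies the scalar op, which is what a linalg.generic with identity indexing maps and parallel iterators expresses. The tensor representation and the names `generic_elementwise` and `flat_index` are hypothetical.

```python
import itertools
import math
import operator
from typing import Callable, List, Tuple

# Toy model (an assumption for illustration): a tensor is (shape, row-major data).
Tensor = Tuple[Tuple[int, ...], List[float]]


def flat_index(shape: Tuple[int, ...], idx: Tuple[int, ...]) -> int:
    """Row-major offset of multi-index `idx` in a tensor of `shape`."""
    off = 0
    for dim, i in zip(shape, idx):
        off = off * dim + i
    return off


def generic_elementwise(body: Callable[..., float], *operands: Tensor) -> Tensor:
    """Mimics a linalg.generic with identity indexing maps and all-parallel
    iterators: the scalar `body` (e.g. a scalar arith.addf) runs once per point
    of the iteration space. Operand shapes are assumed equal here; checking
    that is a separate concern."""
    shape = operands[0][0]
    out = [0.0] * math.prod(shape)
    for idx in itertools.product(*(range(d) for d in shape)):
        k = flat_index(shape, idx)
        out[k] = body(*(data[k] for _, data in operands))
    return shape, out


# Models `arith.addf %a, %b : tensor<2x3xf32>` after the rewrite.
a = ((2, 3), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
b = ((2, 3), [10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
print(generic_elementwise(operator.add, a, b))
# ((2, 3), [11.0, 22.0, 33.0, 44.0, 55.0, 66.0])
```

Because the body only ever sees scalars, the surrounding loop (the generic op) is the only thing that touches tensors, and that op is something bufferization already knows how to handle.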
Prototype
We’ve built a prototype that implements functionality similar to the TosaToLinalg conversion, but it differs in a few key aspects:
Instead of translating the TOSA dialect, the prototype translates only a subset of the arith dialect to linalg.
Whenever an arith operand has dynamic dimensions, a comparison and an assertion are emitted to check at run time that the corresponding dimensions are equal.
All tensor operands of an operation must have the same rank.
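The last two constraints in the list above can be sketched as a compile-time step and a run-time step. This is a hedged Python model, not the prototype's code: `DYN`, `plan_dim_checks` and `emit_runtime_checks` are hypothetical names, and in MLIR the run-time step might be expressed with ops such as `tensor.dim`, `arith.cmpi` and `cf.assert` (the post does not say which ops the prototype emits).

```python
from typing import List, Optional, Sequence

# Hypothetical marker for a dynamic dimension (written `?` in an MLIR tensor type).
DYN = None
StaticShape = Sequence[Optional[int]]


def plan_dim_checks(static_a: StaticShape, static_b: StaticShape) -> List[int]:
    """Compile-time step: reject operands of different rank or with conflicting
    static sizes, and return the dims whose equality must be checked at run time."""
    if len(static_a) != len(static_b):
        raise ValueError(f"rank mismatch: {len(static_a)} vs {len(static_b)}")
    dims_to_check = []
    for i, (da, db) in enumerate(zip(static_a, static_b)):
        if da is DYN or db is DYN:
            dims_to_check.append(i)  # size only known at run time
        elif da != db:
            raise ValueError(f"static size mismatch in dim {i}: {da} vs {db}")
    return dims_to_check


def emit_runtime_checks(dims: List[int], actual_a: Sequence[int],
                        actual_b: Sequence[int]) -> None:
    """Run-time step: one comparison plus one assertion per dynamic dimension."""
    for i in dims:
        if actual_a[i] != actual_b[i]:
            raise AssertionError(
                f"dimension {i} differs at run time: {actual_a[i]} vs {actual_b[i]}")


# tensor<?x4xf32> combined with tensor<3x?xf32>: both dims need a run-time check.
checks = plan_dim_checks((DYN, 4), (3, DYN))
print(checks)  # [0, 1]
emit_runtime_checks(checks, (3, 4), (3, 4))  # sizes agree, so nothing is raised
```

Fully static dimensions never reach the run-time step, so the assertions are only paid for where the types leave the sizes unknown.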
The current prototype could be refactored to share common code with TosaToLinalg, but we would need some direction on where that common infrastructure should live.
Hi, not sure if this fits the bill, but have you looked at the lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp pass? I believe it converts the strict subset of things that we believed we knew how to handle at the time the ElementwiseMappable trait was defined.
I think this is exactly what I was hoping to implement. I’m wondering, though: what are the semantics for tensors with dynamic dimensions whose sizes might differ at run time? When I translate an example using -convert-elementwise-to-linalg:
AFK, but I believe the semantics of ElementwiseMappable are that mismatched dimensions are UB (same as in linalg). I remember there being a pass (or maybe just a discussion of one) that would add shape asserts at the linalg level. Most frontends add those at the top level, though, since in full generality that is where such things can be reasoned about.
In hlo land, we have broadcasting operations that support operands with (to some degree) different shapes; these are modeled in chlo, for example the BroadcastAddOp. When we lower them to mhlo, which requires operands to have matching shapes, shape constraints are inserted.
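As a rough illustration of that lowering pattern (a Python model, not chlo or mhlo code), a broadcasting add can be split into a shape-constraint check, an explicit broadcast of each operand to a common shape, and a plain same-shape elementwise add. NumPy-style, trailing-aligned broadcasting is assumed here, and `broadcast_shape`, `broadcast_to` and `broadcast_add` are hypothetical helpers.

```python
import itertools
from typing import List, Sequence, Tuple

Tensor = Tuple[Tuple[int, ...], List[float]]  # (shape, row-major data); toy model


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Shape constraint: align trailing dims; each pair must be equal or contain 1.
    Raising here plays the role of a failed shape constraint."""
    rank = max(len(a), len(b))
    pa = (1,) * (rank - len(a)) + tuple(a)
    pb = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for i, (da, db) in enumerate(zip(pa, pb)):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"not broadcastable in dim {i}: {da} vs {db}")
    return tuple(out)


def broadcast_to(shape: Sequence[int], data: List[float],
                 target: Tuple[int, ...]) -> List[float]:
    """Materialize `data` at `target` shape; size-1 dims are always read at index 0."""
    padded = (1,) * (len(target) - len(shape)) + tuple(shape)
    result = []
    for idx in itertools.product(*(range(d) for d in target)):
        off = 0
        for dim, i in zip(padded, idx):
            off = off * dim + (0 if dim == 1 else i)
        result.append(data[off])
    return result


def broadcast_add(a: Tensor, b: Tensor) -> Tensor:
    (sa, da), (sb, db) = a, b
    target = broadcast_shape(sa, sb)   # 1. check the shape constraint
    xa = broadcast_to(sa, da, target)  # 2. expand both operands to one shape
    xb = broadcast_to(sb, db, target)
    return target, [x + y for x, y in zip(xa, xb)]  # 3. plain elementwise add


print(broadcast_add(((2, 3), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), ((3,), [10.0, 20.0, 30.0])))
# ((2, 3), [11.0, 22.0, 33.0, 14.0, 25.0, 36.0])
```

Only step 3 is a same-shape elementwise op of the kind arith's ElementwiseMappable semantics cover; steps 1 and 2 are what a higher-level dialect or frontend has to supply.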
You could do something similar for the arith dialect, but according to the documentation of the ElementwiseMappable trait (and as @stellaraccident mentioned), broadcasting is not supported. arith really only supports elementwise operations in the most elemental sense; everything else is undefined.