Our std elementwise ops like std.addf accept tensors as operands. However, this fact has not been used at all AFAIK, so it needs a little love to make it great.
I have recently started writing a conversion to linalg from std elementwise ops on tensors, and it is working really nicely (one small MatchAnyOpTypeTag pattern can convert all ~30 ops to linalg): https://reviews.llvm.org/D90354
So I’d like to formalize a couple things to make transformations like these easier/possible to write.
Currently, elementwise ops on tensors don’t have a clearly specified behavior in the case of a dynamic shape mismatch. E.g. consider
%0 = addf %lhs, %rhs : tensor<?xf32> // Generic form: %0 = "std.addf"(%lhs, %rhs) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
Due to dynamic shapes, at runtime %lhs could have type tensor<2xf32> and %rhs could have type tensor<7xf32>. (This is not an issue for vectors, since they are statically shaped.)
Proposal: We specify that std elementwise ops on tensors have undefined behavior in the case of dynamic shape mismatches of the operands. This is consistent with being able to lower them to linalg and other code generation systems.
Rejected alternative: in case of dynamic shape mismatch, the ops must safely terminate the program. This is problematic, because then we need to reify error handling logic before lowering them to code generation systems.
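To make the trade-off between these two options concrete, here is a minimal Python sketch that models a 1-D tensor as a list. The function names are hypothetical, and this is not how std.addf or any lowering is actually implemented; it only illustrates where the extra error-handling logic would have to live.

```python
def addf_unchecked(lhs, rhs):
    """Models the proposal: equal shapes are a precondition, never checked.

    The loop bound comes from lhs alone. If the shapes differ, the outcome
    is whatever the loop happens to do (in this model: an IndexError, or
    trailing elements of rhs silently ignored). Callers must not rely on it.
    """
    return [lhs[i] + rhs[i] for i in range(len(lhs))]


def addf_checked(lhs, rhs):
    """Models the rejected alternative: a mismatch must terminate safely.

    The shape comparison and the error path below are the logic that would
    have to be reified before handing the op to a code generator.
    """
    if len(lhs) != len(rhs):
        raise RuntimeError(f"shape mismatch: {len(lhs)} vs {len(rhs)}")
    return [a + b for a, b in zip(lhs, rhs)]
```

With matching shapes both functions compute the same result; they differ only in what happens on a mismatch, which is exactly the part the proposal leaves undefined.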
Rejected alternative: Disallow tensor types in std elementwise ops. While it is not hard to directly create a linalg op with a std elementwise op in its payload, this does incorporate an early decision to commit to linalg. With std elementwise ops on tensors, bidirectional conversions between the elementwise ops of HLO and TOSA become pretty easy (possibly even removing the need for some of the ops in those respective dialects).
A good amount of thought went into the design of our std elementwise ops, and I can’t think of a compelling reason why any aspect of their design rationale (such as splitting addf from addi) would not apply equally to the tensor case (especially as it already applies to the vector case!).
In my current std.elementwise → linalg patch, I need to have a big nasty “isa” check enumerating all the elementwise ops that are handled (see function isStdElementwiseOpOnRankedTensors).
This is problematic, since it is difficult to keep up to date, and also the “elementwise” property is useful for other analyses/transformations.
Proposal: Add an Elementwise trait with the following semantics for any op that has this trait:
- If any result is a tensor, there must be at least one operand which is a tensor.
- If any operand is a tensor, then there must be at least one result which is a tensor.
- The static types of all tensor operands and results must have the same shape (element type can vary, such as std.select predicate having i1 element type).
- The dynamic shapes of all tensor operands and results must be the same, otherwise the op has undefined behavior.
- (“systematic scalarizability” property) All Elementwise ops that have operands/results of tensor types must also be valid with the same operands/results changed to their respective element types. This creates the “scalarized” form of the op.
- The semantics of the op on tensors must be the same as applying the scalarized op at each corresponding element of the tensor operands/results in parallel.
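The static rules in this list can be sketched as a small verifier. This is a rough Python model, not MLIR code: TensorType, verify_elementwise, and scalarize are hypothetical names, a bare string stands for a scalar element type, and "same shape" is interpreted as literal equality of the static shape (with None standing for a dynamic `?` dimension), which is an assumption about how the rule would be read.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorType:
    shape: Tuple[Optional[int], ...]  # None models a dynamic '?' dimension
    elem: str                         # element type, e.g. "f32" or "i1"


def verify_elementwise(operands, results):
    """Return None if the static rules hold, else an error message."""
    op_tensors = [t for t in operands if isinstance(t, TensorType)]
    res_tensors = [t for t in results if isinstance(t, TensorType)]
    if res_tensors and not op_tensors:
        return "a tensor result requires at least one tensor operand"
    if op_tensors and not res_tensors:
        return "a tensor operand requires at least one tensor result"
    # Element types may differ (e.g. an i1 predicate); only shapes must agree.
    if len({t.shape for t in op_tensors + res_tensors}) > 1:
        return "all tensor operands and results must have the same shape"
    return None


def scalarize(types):
    """Replace every tensor type by its element type (the scalarized form)."""
    return [t.elem if isinstance(t, TensorType) else t for t in types]
```

For example, with `t = TensorType((None,), "f32")`, the call `verify_elementwise(["i1", t, t], [t])` accepts a select-like op with a scalar predicate, while `verify_elementwise([t, t], ["f32"])` is rejected. The dynamic-shape rule is not checked here, since by the proposal it cannot be verified statically.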
The above would apply to vectors as well (need to wordsmith the definition to include this), and can be used to derive transformations such as a “scalarization” pass or implementing interfaces such as
Some care has been taken in this definition. For example, a std.select op with scalar predicate and tensor true/false operands still abides by the Elementwise trait, because the semantics are still the same as applying the scalarized op at each corresponding element. Also, we don’t need to specify which operands/results are allowed to be tensors: the “systematic scalarizability” property is all we need.
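The "apply the scalarized op at each corresponding element" semantics can be sketched in a few lines of Python. This is only an illustrative model: lists stand in for 1-D tensors, apply_elementwise is a hypothetical helper, and equal tensor lengths are assumed rather than checked, in line with the undefined-behavior proposal.

```python
def apply_elementwise(scalar_op, *operands):
    """Apply scalar_op at each element; lists model 1-D tensors.

    Operands that are not lists are scalars and are passed unchanged at
    every element. The extent is read from the first tensor operand only,
    since equal shapes are assumed (a mismatch is undefined in the proposal).
    """
    tensors = [x for x in operands if isinstance(x, list)]
    if not tensors:
        return scalar_op(*operands)  # fully scalar form of the op
    n = len(tensors[0])
    return [
        scalar_op(*(x[i] if isinstance(x, list) else x for x in operands))
        for i in range(n)
    ]


def select(pred, true_value, false_value):
    """Scalarized form of a select op."""
    return true_value if pred else false_value
```

Under this model, `apply_elementwise(select, True, [1.0, 2.0], [3.0, 4.0])` (scalar predicate) and `apply_elementwise(select, [True, False], [1.0, 2.0], [3.0, 4.0])` (tensor predicate) both go through the same code path, which is why no per-operand rule about which operands may be tensors is needed.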
Additionally, any op (not just std ops) can implement Elementwise, and the above convert-std-to-linalg pass can convert it to linalg (in fact, we can rename the pass “elementwise-to-linalg”). For example, a downstream dialect might want to have an “add” op that lowers to addf or addi depending on the type; such an op can be easily adapted to be Elementwise.
Alternative considered: Have Elementwise be an OpInterface (instead of a trait) with a method “createScalarElementwiseComputation” that replaces the “systematic scalarizability” property above with a totally customizable hook. This could work as well, but seems bulkier for little added benefit.
The added benefit would be that it would allow e.g. a tensor “maxf” op to implement createScalarElementwiseComputation that creates scalar std.cmpf/std.select in the method. However, it seems preferable to adopt the proposal here, and instead legalize the tensor “maxf” op into tensor “std.cmpf” + tensor “std.select”, which then both implement Elementwise and can be trivially fused. Alternatively, the tensor “maxf” op could also adopt Elementwise by simply defining its behavior on scalars, which seems like very little effort.
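As a rough illustration of the preferred legalization, the Python sketch below models a tensor "maxf" as a tensor compare followed by a tensor select. The function names are hypothetical, lists stand in for 1-D tensors of equal length, the comparison is assumed to behave like an ordered greater-than, and NaN and signed-zero behavior is ignored.

```python
def cmpf_ogt(lhs, rhs):
    """Tensor compare: elementwise 'greater than', yielding booleans."""
    return [a > b for a, b in zip(lhs, rhs)]


def select(pred, true_values, false_values):
    """Tensor select: pick from true_values where pred holds."""
    return [t if p else f for p, t, f in zip(pred, true_values, false_values)]


def maxf_legalized(lhs, rhs):
    """Tensor maxf expressed as tensor compare + tensor select."""
    return select(cmpf_ogt(lhs, rhs), lhs, rhs)


def maxf_scalar(a, b):
    """The alternative: define maxf directly on scalars."""
    return a if a > b else b
```

Both routes agree element by element: `maxf_legalized(lhs, rhs)` gives the same list as `[maxf_scalar(a, b) for a, b in zip(lhs, rhs)]`, so either one is enough for an elementwise-to-linalg style conversion to work on.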