[RFC]: Extend Linalg elemwise named ops semantics
Written in collaboration with @rengolin, @MaheshRavishankar, @banach-space
Context and Background
In Oct 2024, Renato put out [RFC]: Op Explosion in Linalg highlighting the problem – “We have too many similar operations in Linalg (conv, matmul) that cannot be decomposed because they were defined in OpDSL. We discussed about this in the past and the end result is to move it out of OpDSL (Python, YAML), into ODS (table-gen, C++) so that we can common up implementation details”.
The proposal identified matmul, contraction, convolution, and element-wise ops, amongst others, as areas needing attention. Subsequent to Renato's RFC, Rolf put out [RFC]: introduce linalg.contract. More threads on the same topic can be found here: [RFC]: Linalg Operation Tree, and move matmul to ODS.
Problem Description
Linalg operations such as linalg.add or linalg.exp currently do not support broadcast, transpose (projected permutation in general), or type conversion. User-defined affine maps cannot be supplied (unlike linalg.generic and now linalg.matmul), and the current semantics of element-wise ops require the same type (i.e. shape and element type) for all operands and results. While the following is currently valid IR –
%add = linalg.add
ins(%x, %y : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
outs(%z : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
The following, which tries to express a transpose (and a type conversion) of the first operand before the operation, is currently not valid IR –
%add_transpose_a = linalg.add
ins(%x, %y : tensor<8x4x16xf16>, tensor<4x8x16xf32>)
outs(%z: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
Any required transpose, broadcast, type conversion etc. therefore must be done prior to calling this op or afterwards.
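For illustration, here is a minimal sketch (value names are illustrative) of how the transposed, type-converted add currently has to be written, with a separate linalg.transpose and an element-wise arith.extf feeding a plain linalg.add:
// Transpose %x into the output layout first (swap the first two dims).
%x_init = tensor.empty() : tensor<4x8x16xf16>
%x_t = linalg.transpose
         ins(%x : tensor<8x4x16xf16>)
         outs(%x_init : tensor<4x8x16xf16>) permutation = [1, 0, 2]
// Extend f16 to f32 ...
%x_f32 = arith.extf %x_t : tensor<4x8x16xf16> to tensor<4x8x16xf32>
// ... and only then perform the add on matching shapes and types.
%add = linalg.add
         ins(%x_f32, %y : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
         outs(%z : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>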
What is the benefit of extending the semantics of element-wise ops such as linalg.add? Today the graph IR carries several ops – input reshapes and type conversions, the actual computation (e.g. linalg.add), and conversion of the result to the desired output type – as a series of linalg.generic and other ops, and then relies on folding optimizations to produce a compact, low-entropy IR. The same computation can be expressed compactly if we define a new linalg.elemwise whose semantics overcome the above limitations. Also, linalg.elemwise will be defined in ODS (in the spirit of ‘move out of OpDSL’).
Proposal: Introduce a new linalg.elemwise
This RFC proposes a new linalg.elemwise defined in ODS with semantics that overcome the above limitations. We seek feedback; to add value to this proposal, a prototype implementation of what is proposed here follows this RFC.
Syntax
elemwise-op ::= `linalg.elemwise`
`func_type` `=` func-op
(`comp_type` `=` comp-type)?
(`indexing_maps` `=` `[` affine-map+ `]`)?
`ins` `(` $A `,` $B `:` tensor-or-memref-type `,`
tensor-or-memref-type `)`
`outs` `(` $C `:` tensor-or-memref-type `)`
(`->` tensor-type)?
func-op ::= `#linalg.elemwise_fn` `<` func-op-kind `>`
func-op-kind ::= `exp` | .. | `erf` | `add` | .. | `powf` | `select`
comp-type ::= `#linalg.comp_type` `<` comp-kind `>`
comp-kind ::= `i1` | `i8` | .. | `f64`
The func_type attribute describes the operation kind (e.g. add, exp). The comp_type attribute defines the type on which the element-wise operation is performed; if all operand and result element types match, there is no need to specify this attribute. The affine maps for operands and results must be projected permutations with no zero constants.
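As a hypothetical sketch of the proposed textual form, a binary add with f16 inputs computed in f32, a transposed first operand, and a broadcast second operand could look as follows (the exact semantics of comp_type are still open, see below):
%r = linalg.elemwise
       func_type = #linalg.elemwise_fn<add>
       comp_type = #linalg.comp_type<f32>
       indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,  // transposed A
                        affine_map<(d0, d1) -> (d0)>,      // broadcast B
                        affine_map<(d0, d1) -> (d0, d1)>]  // identity output
       ins(%a, %b : tensor<8x4xf16>, tensor<4xf16>)
       outs(%c : tensor<4x8xf32>) -> tensor<4x8xf32>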
Verifier
The verifier checks that:
- The func_type is a valid unary, binary, or ternary operation.
- The semantics of comp_type (extension/truncation/type conversion) probably needs some discussion here, so it is kept open for now.
- The indexing_maps attribute of the input operands consists only of affine maps that are projected permutations. The indexing map of the output must be the identity (an example the verifier would reject is sketched below).
- The number of input operands matches the arity (unary, binary, or ternary) inferred from the func_type.
- Some more checks are done via the already available verifyStructuredOpInterface.
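For example, the following hypothetical IR would be rejected by the proposed verifier, because the indexing map of the output is a transpose rather than the identity:
%bad = linalg.elemwise
         func_type = #linalg.elemwise_fn<add>
         indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d1, d0)>]  // output map is not identity
         ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
         outs(%c : tensor<8x4xf32>) -> tensor<8x4xf32>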
Inference and Deduction
- When user-defined indexing_maps are not provided, identity maps are inferred for all operands. The default indexing maps are N identity maps, where N depends on the arity of the element-wise op. The number of dims is inferred from the rank of the output type. With the default identity maps, the input and output shapes must all match (a sketch follows this list).
- The iterator_types are always inferred as all ‘parallel’ for element-wise ops. Iterator types are needed for constructing the structured op that lowers to linalg.generic. The number of dims of the iterator types is inferred from the indexing_maps dims.
- The arity is inferred from func_type.
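As a sketch of the default inference, a rank-3 binary op written without indexing_maps behaves as if three identity maps had been supplied, with all-‘parallel’ iterators:
// No indexing_maps given: behaves as if
//   indexing_maps = [#id, #id, #id] with
//   #id = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// had been supplied, and iterator_types = ["parallel", "parallel", "parallel"] is inferred.
%sum = linalg.elemwise
         func_type = #linalg.elemwise_fn<add>
         ins(%a, %b : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
         outs(%c : tensor<2x3x4xf32>) -> tensor<2x3x4xf32>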
Einsum Notation
Elementwise binary op –
C[I^C] = A[I^A] ⊕ B[I^B]
where I^A, I^B, and I^C are multi-indices, i.e. sequences/ordered sets of dimension identifiers (meant to range over index ranges) corresponding to the co-domains of the respective affine maps, and ⊕ is the selected kind of element-wise operation.
Similarly, an elementwise unary op –
C[I^C] = λ A[I^A]
where λ is the selected unary operator.
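For instance, the transposed-broadcast add in the worked example below reads, in this notation,
C[i, j, k] = A[i, k, j] ⊕ B[i, k]
with I^C = (i, j, k), I^A = (i, k, j), I^B = (i, k), and ⊕ = add.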
Worked out Example
Below we show a working example IR:
#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#transpose = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @test(%arg0: tensor<4x16x8xf32>,
%arg1: tensor<4x16xf32>,
%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
%e = tensor.empty() : tensor<4x16x8xf32>
// unary op (exp). no user-defined map.
%exp = linalg.elemwise
func_type=#linalg.elemwise_fn<exp>
ins(%arg0 : tensor<4x16x8xf32>) outs(%e: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
// binary op, user defined map
%add = linalg.elemwise
func_type=#linalg.elemwise_fn<add>
indexing_maps = [#transpose, #broadcast, #identity]
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>) outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
return %add : tensor<4x8x16xf32>
}
Lowering of Example to linalg.generic
The above lowers to the following linalg.generic ops (via mlir-opt --linalg-generalize-named-ops test.mlir):
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
module {
func.func @test(%arg0: tensor<4x16x8xf32>, %arg1: tensor<4x16xf32>, %arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
%0 = tensor.empty() : tensor<4x16x8xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x16x8xf32>) outs(%0 : tensor<4x16x8xf32>) {
^bb0(%in: f32, %out: f32):
%3 = math.exp %in : f32
linalg.yield %3 : f32
} -> tensor<4x16x8xf32>
%2 = linalg.generic {indexing_maps = [#map1, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>) outs(%arg2 : tensor<4x8x16xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%3 = arith.addf %in, %in_0 : f32
linalg.yield %3 : f32
} -> tensor<4x8x16xf32>
return %2 : tensor<4x8x16xf32>
}
}
Actions
- First PR: I have an implementation and some tests of things working; e.g. the above test example and its lowering are from my working diff. I will put it up in a couple of days. It does not have lots of tests or cover all bases, as some implementation details are likely to change based on inputs here.
- Gather feedback: get feedback on this RFC and revise the implementation.
- Discuss with the community the timeline and order of the next steps, e.g. adding transformations that leverage the new linalg.elemwise.