# [RFC] Primitive Ops: add MapOp, ReductionOp, TransposeOp, BroadcastOp to Linalg

## Motivation

`linalg.generic` is a powerful Swiss Army knife of an operation. It can model elementwise operations, reductions, static broadcasts and reshapes, transposes, convolutions, matmuls, and more. There are several problems with it, however, and I think they can be addressed relatively easily.

### Layering

There are examples of compilation pipelines in TF and IREE that go from a front-end dialect, e.g. mHLO, directly to `linalg.generic` and then perform tiling, fusion, and other related transforms like tree reductions. In order to do so, `linalg.generic` ops are pattern-matched to understand what they actually model. Is it a reduction? Is it an elementwise op? Is it an elementwise op that also broadcasts one of the inputs? Is it a reduction that had multiple ops fused into it?

We lose the structure and then try to recover it with pattern matching. This is unnecessary and indicates a missing layer of ops.
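As an illustration, here is roughly what a pass has to pattern-match to recognize "add with a broadcast of the second input" (a sketch; the `%lhs`, `%rhs`, and `%out` values are hypothetical):

```
// An elementwise add where %rhs is implicitly broadcast along d0: a pass
// must inspect both the indexing maps and the body to recognize this.
%0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%lhs, %rhs : tensor<8x16xf32>, tensor<16xf32>)
    outs(%out : tensor<8x16xf32>) {
  ^bb0(%l: f32, %r: f32, %o: f32):
    %1 = arith.addf %l, %r : f32
    linalg.yield %1 : f32
} -> tensor<8x16xf32>
```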

### Readability

Ideally, the IR that we generate should be readable and concise, not only for the “initiated”. Here is an example of a simple `add` op on tensors that was lowered to `linalg.generic`:

```
%add = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%lhs, %rhs : tensor<10xf32>, tensor<10xf32>)
    outs(%out : tensor<10xf32>) {
  ^bb0(%l: f32, %r: f32, %o: f32):
    %0 = arith.addf %l, %r : f32
    linalg.yield %0 : f32
} -> tensor<10xf32>
```

And this is an example of the same op converted to the missing `linalg.map`:

```
%add = linalg.map ins(%lhs: tensor<10xf32>, %rhs: tensor<10xf32>)
    outs(%out: tensor<10xf32>) optional-attrs
    (%l: f32, %r: f32) {
      %0 = arith.addf %l, %r : f32
      linalg.yield %0 : f32
    }
```

I argue that having simpler operations that model `elementwise`, `reduction`, `transpose`, and `static broadcast` would improve readability and reduce the need for pattern matching in transformation passes, for example, when clustering operations to perform fusion.

Therefore, it might be a good time to reconsider whether the front-end dialects should be converted to `linalg.generic` or to a primitive set of ops, possibly within the Linalg dialect, that can be trivially lowered to `linalg.generic`.

### Linalg Named Ops

All of the above is suspiciously similar to Linalg Named Ops, a way to define structured ops that can be converted to `linalg.generic` using YAML. Unfortunately, it requires defining a lot of operations for every possible rank and combination of what gets reduced, transposed, etc. For example, if we want to define an op for reduction, then not one, but many ops have to be created: `reduction_1d`, `reduction_2d`, `column_reduction_2d`, `row_reduction_2d`, …

## Design

Add `MapOp`, `IotaOp`, `ReductionOp`, `TransposeOp`, and `BroadcastOp` to LinalgStructuredOps.td. All of the ops will implement `LinalgOpInterface` and `DestinationStyleOpInterface`, which will allow for tiling, fusion, lowering to loops, and bufferization.

These ops can be seen as complementary to Linalg Named Ops, but defined in TableGen and C++ instead of being generated out of the YAML config.

### Map

This is an n-ary operation where the shapes of the inputs and the output are the same. It does not have any implicit broadcasting behaviour.

```
%add = linalg.map ins(%lhs: tensor<10xf32>, %rhs: tensor<10xf32>)
    outs(%out: tensor<10xf32>) optional-attrs
    (%l: f32, %r: f32) {
      %0 = arith.addf %l, %r : f32
      linalg.yield %0 : f32
    }
```

`linalg.map` has a region with block arguments corresponding to every input. Note that, unlike `linalg.generic`, it does not have a block argument for the output.
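For instance, a fused multiply-add is a single ternary map with three block arguments, one per input (a sketch in the proposed syntax; the `%a`, `%b`, and `%c` operands are hypothetical):

```
%fma = linalg.map
    ins(%a: tensor<10xf32>, %b: tensor<10xf32>, %c: tensor<10xf32>)
    outs(%out: tensor<10xf32>)
    (%x: f32, %y: f32, %z: f32) {
      // The body combines the three per-element values;
      // there is no block argument for %out.
      %0 = arith.mulf %x, %y : f32
      %1 = arith.addf %0, %z : f32
      linalg.yield %1 : f32
    }
```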

**Bike-shed:** Should it be called `linalg.map`, `linalg.cwise`, `linalg.pointwise`, or `linalg.elementwise`?

### Iota

Iota is an operation that does not have any inputs, but it uses the values of the induction variables to populate the output.

```
%iota = linalg.iota outs(%output:tensor<8x16xf32>) dim = [0] optional-attrs
```

**Bike-shed:** Do we really need this operation? Can’t we just use `linalg.map` with `linalg.index` inside to model it?

```
%iota = linalg.map outs(%out: tensor<8x16xf32>) optional-attrs () {
  // linalg.index yields an `index`; cast it to the element type.
  %0 = linalg.index 0 : index
  %1 = arith.index_cast %0 : index to i64
  %2 = arith.sitofp %1 : i64 to f32
  linalg.yield %2 : f32
}
```

### Transpose

The transpose operation specifies a permutation of the input dimensions.

```
%transpose = linalg.transpose
    ins(%input: tensor<16x64xf32>)
    outs(%output: tensor<64x16xf32>)
    permutation = [1, 0] optional-attrs
```
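For reference, the trivial lowering of this op to `linalg.generic` would look roughly as follows; the iteration domain follows the output shape, and the input map carries the permutation:

```
%transpose = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%input : tensor<16x64xf32>)
    outs(%output : tensor<64x16xf32>) {
  ^bb0(%in: f32, %o: f32):
    linalg.yield %in : f32
} -> tensor<64x16xf32>
```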

### Reduction

The reduction operation specifies which dimensions are reduced and carries a body region that contains the combiner.

```
%sum = linalg.reduction
    ins(%input: tensor<16x64xf32>)
    outs(%output: tensor<16xf32>)
    dimensions = [1] optional-attrs
    (%in: f32, %out: f32) {
      %0 = arith.addf %in, %out : f32
      linalg.yield %0 : f32
    }
```
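The trivial lowering to `linalg.generic` marks the reduced dimension with a `"reduction"` iterator and drops it from the output map, roughly:

```
%sum = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%input : tensor<16x64xf32>)
    outs(%output : tensor<16xf32>) {
  ^bb0(%in: f32, %out: f32):
    %0 = arith.addf %in, %out : f32
    linalg.yield %0 : f32
} -> tensor<16xf32>
```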

### Broadcast

This is a static broadcast operation, i.e. there is no ambiguity at compile time about which size-1 dimensions should be expanded.

```
%bcast = linalg.broadcast
    ins(%input: tensor<16xf32>)
    outs(%output: tensor<16x64xf32>)
    dimensions = [0] optional-attrs
```
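Assuming `dimensions` lists which output dimension each input dimension maps to (an interpretation of the proposed syntax, not spelled out above), the trivial lowering to `linalg.generic` would be roughly:

```
%bcast = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%input : tensor<16xf32>)
    outs(%output : tensor<16x64xf32>) {
  ^bb0(%in: f32, %o: f32):
    linalg.yield %in : f32
} -> tensor<16x64xf32>
```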

**Question:** Do we need size-1 expansion in this version of broadcasting?

A dynamic broadcast operation cannot be modeled with `linalg.generic`, and it is too specific to TensorFlow; that’s why it should not be part of Linalg.

### Reshape

Reshapes are already modeled with `tensor.expand_shape`, `tensor.collapse_shape`, or a combination of a `tensor.collapse_shape` that reshapes to 1D followed by a `tensor.expand_shape` that expands to the target shape.
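For example, reshaping `tensor<4x6xf32>` to `tensor<2x12xf32>` has no single-op reassociation, so it goes through 1D (a sketch; the exact `tensor.expand_shape` syntax varies between MLIR versions):

```
// Collapse both dimensions into one, then expand to the target shape.
%flat = tensor.collapse_shape %input [[0, 1]]
    : tensor<4x6xf32> into tensor<24xf32>
%reshaped = tensor.expand_shape %flat [[0, 1]]
    : tensor<24xf32> into tensor<2x12xf32>
```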

There are dynamic reshapes that cannot be modeled with `linalg.generic`, e.g. reshaping `tensor<?x?xf32>` to `tensor<?x?x?xf32>`. Such an op cannot implement `LinalgOpInterface`, but it can be converted to `tensor.reshape`.

It is an interesting question whether we need a dst-style `linalg.reshape` op at all, or whether we can live with `tensor.expand_shape`, `tensor.collapse_shape`, and `tensor.reshape`. Usually, `linalg.reshape` does not allocate during bufferization and just becomes a memref descriptor operation, so I am not sure how important it is to have it in dst-style.

```
%reshape = linalg.reshape
    ins(%input: tensor<?x?xf32>)
    outs(%output: tensor<?x?x?xf32>) optional-attrs
```

It is also quite different from the ops above w.r.t. tiling and fusion, since reshapes are only tileable to size-1 tiles, i.e. to point level.

### Matmul, Fill and Convolutions

These ops are already defined in Linalg, and HLO gets converted to them instead of `linalg.generic`. They are already readable and useful.

There are many operations for convolutions, e.g. `Conv1DNwcWcfOp`, `Conv2DNhwcHwcfOp`, `Conv3DNdhwcDhwcfOp`, `DepthwiseConv1DNwcWcmOp`, and more. It might be possible to find a smaller set of ops to cover them all.
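For instance, initializing the accumulator with `linalg.fill` and then computing a `linalg.matmul` already reads naturally in destination style (a sketch; the `%a`, `%b`, and `%init` values are hypothetical):

```
// Zero-fill the destination, then accumulate the matmul into it.
%cst = arith.constant 0.0 : f32
%acc = linalg.fill ins(%cst : f32)
    outs(%init : tensor<16x32xf32>) -> tensor<16x32xf32>
%matmul = linalg.matmul
    ins(%a, %b : tensor<16x8xf32>, tensor<8x32xf32>)
    outs(%acc : tensor<16x32xf32>) -> tensor<16x32xf32>
```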