I’m working on upstreaming the lowering in IREE to MHLO. When lowering mhlo.reduce to Linalg, @herhut and I had a discussion. After prototyping it, I hit some issues, and I’m wondering whether they are real issues and whether they should be fixed.
The goal is to lower a reduce op to Linalg generic, e.g.:

```mlir
func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32> {
  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xi32>, tensor<i32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}
```
to

```mlir
...
%3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xi32>) outs(%2 : tensor<?xi32>) {
^bb0(%arg2: i32, %arg3: i32):  // no predecessors
  %4 = addi %arg2, %arg3 : i32
  linalg.yield %4 : i32
} -> tensor<?xi32>
...
```
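The `#map0`/`#map1` definitions are elided above. Presumably (an assumption on my part, since the snippet doesn’t show them) they are the standard maps for a reduction over dimension 1 of a rank-2 input. A small C++ sketch of how they would be built:

```cpp
#include <utility>

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;

// Presumed indexing maps for the dimension-1 reduction above:
//   #map0 = affine_map<(d0, d1) -> (d0, d1)>  // read every input element
//   #map1 = affine_map<(d0, d1) -> (d0)>      // one output per parallel index
std::pair<AffineMap, AffineMap> buildReductionMaps(MLIRContext *ctx) {
  AffineMap inputMap = AffineMap::getMultiDimIdentityMap(2, ctx);
  AffineMap outputMap = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                       getAffineDimExpr(0, ctx));
  return {inputMap, outputMap};
}
```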
The conversion to the generic op will clone the region of the reduce op, rewrite the block arguments into scalar form, and lower the mhlo ops within the region to std ops.
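Concretely, a minimal sketch of such a one-step conversion pattern, assuming MLIR’s dialect conversion framework (the pattern name and the `createInitTensor` helper are hypothetical, builder and accessor spellings vary across MLIR versions, and the maps are hardwired to the rank-2 example):

```cpp
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

using namespace mlir;

// Hypothetical helper: materialize the output tensor pre-filled with the
// reduce op's init value (how that happens is beside the point here).
Value createInitTensor(ConversionPatternRewriter &rewriter, Location loc,
                       mhlo::ReduceOp op);

// Defined in the previous sketch.
std::pair<AffineMap, AffineMap> buildReductionMaps(MLIRContext *ctx);

struct ReduceOpToLinalgGeneric : public OpConversionPattern<mhlo::ReduceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mhlo::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto maps = buildReductionMaps(rewriter.getContext());
    SmallVector<AffineMap, 2> indexingMaps = {maps.first, maps.second};
    SmallVector<StringRef, 2> iterators = {"parallel", "reduction"};

    // Create the generic op with an empty body.
    Value initTensor = createInitTensor(rewriter, loc, op);
    auto generic = rewriter.create<linalg::GenericOp>(
        loc, op.getOperation()->getResultTypes(),
        /*inputs=*/ValueRange{operands[0]},
        /*outputs=*/ValueRange{initTensor}, indexingMaps, iterators);

    // Move the reduce body into the generic op...
    Region &region = generic.region();
    rewriter.inlineRegionBefore(op.body(), region, region.end());

    // ...and rewrite the tensor<i32> block arguments into scalar i32, the
    // calling convention linalg.generic expects. The mhlo ops left in the
    // body are lowered to std ops by companion patterns in the same
    // conversion.
    TypeConverter::SignatureConversion signatureConverter(
        region.front().getNumArguments());
    for (auto en : llvm::enumerate(region.front().getArgumentTypes()))
      signatureConverter.addInputs(
          en.index(), en.value().cast<ShapedType>().getElementType());
    rewriter.applySignatureConversion(&region, signatureConverter);

    rewriter.replaceOp(op, generic.getResults());
    return success();
  }
};
```

The interesting part is the signature conversion: it is what switches the cloned region from tensor<i32> arguments to scalars.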
This is a big step, so we tried to break it into smaller pieces:
- Create a `linalg.generic` op and clone the region. In this step, the body starts with a `scalar_to_0d_tensor` operation that turns each scalar back into a 0-d tensor, and then the normal HLO follows:
```mlir
%3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xi32>) outs(%2 : tensor<?xi32>) {
^bb0(%arg2: i32, %arg3: i32):  // no predecessors
  %4 = scalar_to_0d_tensor %arg2 : i32 to tensor<i32>
  %5 = scalar_to_0d_tensor %arg3 : i32 to tensor<i32>
  %6 = mhlo.add %4, %5 : tensor<i32>
  %7 = tensor.extract %6[] : tensor<i32>
  linalg.yield %7 : i32
} -> tensor<?xi32>
```
- Apply the mhlo-to-scalar std lowering to the body of these regions. If all operations can be converted, we get `extract(scalar_to_0d_tensor(x))` pairs that canonicalize to `x` (see the pattern sketch after this list).
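For illustration, that cleanup could be an ordinary rewrite pattern; `ScalarToZeroDimTensorOp` below stands for the hypothetical `scalar_to_0d_tensor` op from step one (a sketch, not existing upstream code):

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Folds tensor.extract(scalar_to_0d_tensor(x))[] back to x: the 0-d tensor
// holds exactly the wrapped scalar, so extracting it is the identity.
// ScalarToZeroDimTensorOp is the op we would have to add in step one.
struct FoldExtractOfScalarToTensor
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const override {
    auto producer = extract.tensor().getDefiningOp<ScalarToZeroDimTensorOp>();
    if (!producer)
      return failure();
    rewriter.replaceOp(extract, producer.getOperand());
    return success();
  }
};
```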
The issue in step one is that we are lacking a `tensor_reshape`-like operation. I tried `tensor::CastOp`, and it complains: `'tensor<1xf32>' and result type 'tensor<f32>' are cast incompatible`. One workaround is to cast twice through an unranked tensor, but I don’t think this is the way to go:
```mlir
%3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xi32>) outs(%2 : tensor<?xi32>) {
^bb0(%arg2: i32, %arg3: i32):  // no predecessors
  %4 = tensor.from_elements %arg2 : tensor<1xi32>
  %5 = tensor.cast %4 : tensor<1xi32> to tensor<*xi32>
  %6 = tensor.cast %5 : tensor<*xi32> to tensor<i32>
  %7 = tensor.from_elements %arg3 : tensor<1xi32>
  %8 = tensor.cast %7 : tensor<1xi32> to tensor<*xi32>
  %9 = tensor.cast %8 : tensor<*xi32> to tensor<i32>
  %10 = mhlo.add %6, %9 : tensor<i32>
  %11 = tensor.extract %10[] : tensor<i32>
  linalg.yield %11 : i32
} -> tensor<?xi32>
```
The issue in step two is the order in which rewrites happen. There could be many mhlo ops in a function, and I’m not sure there is a guaranteed order. If we configure legality by saying that mhlo ops are legal inside a generic op, this works only if the framework guarantees a pre-order rewrite traversal. Otherwise, the system could rewrite the “inner” mhlo ops to Linalg generic first, because at that point they would still be inside another mhlo op and not yet inside a Linalg generic.
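To make the ordering concern concrete, here is roughly how that legality would be configured (a sketch, assuming the standard `ConversionTarget` API; the legal-dialect list is abridged):

```cpp
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

using namespace mlir;

// mhlo ops are legal only while nested inside a linalg.generic; all other
// mhlo ops must be converted. With a pre-order traversal, mhlo.reduce is
// rewritten first and its body ops become legal (to be handled by the scalar
// lowering afterwards). Without that guarantee, the driver may visit the
// inner mhlo.add while it still sits inside mhlo.reduce, see it as illegal,
// and wrap it in its own linalg.generic.
void configureStepTwoTarget(ConversionTarget &target) {
  target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
  target.addDynamicallyLegalDialect<mhlo::MhloDialect>([](Operation *op) {
    return op->getParentOfType<linalg::GenericOp>() != nullptr;
  });
}
```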
In the end, I landed the commit with the original pass, which converts the op to Linalg in one step.