Lack of support when lowering mhlo.reduce to Linalg

I’m working on upstreaming the lowering in IREE to MHLO. @herhut and I discussed how to lower mhlo.reduce to Linalg. After prototyping the approach, I hit some issues. I’m wondering whether they are real issues and, if so, whether they should be fixed.

The goal is to lower a reduce op to Linalg generic, e.g.,

func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32> {
  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xi32>, tensor<i32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

to

...

%3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xi32>) outs(%2 : tensor<?xi32>) {
    ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
      %4 = addi %arg2, %arg3 : i32
      linalg.yield %4 : i32
    } -> tensor<?xi32>

...

The generic op will clone the region of the reduce op, rewrite the block arguments in scalar form, and lower the mhlo ops within the region to std ops.
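For readers less familiar with linalg.generic, here is a rough Python model of what the target op above computes. It is only a sketch under assumptions that the snippet does not show: that #map0 is the identity map (d0, d1) -> (d0, d1), that #map1 is (d0, d1) -> (d0), and that %2 has been pre-filled with the init value. The function name reduce_dim1 is made up for illustration.

```python
# Toy model of the linalg.generic above.
# Assumed (not shown in the post): #map0 = (d0, d1) -> (d0, d1),
# #map1 = (d0, d1) -> (d0), and %2 is pre-filled with the init value.
def reduce_dim1(inp, init, body):
    """inp: list of rows; init: scalar init value; body: scalar combiner."""
    out = [init for _ in inp]            # stands in for %2
    for d0 in range(len(inp)):           # "parallel" iterator
        for d1 in range(len(inp[d0])):   # "reduction" iterator
            # ^bb0(%arg2, %arg3): %arg2 is the input element inp[d0][d1],
            # %arg3 is the current accumulator out[d0].
            out[d0] = body(inp[d0][d1], out[d0])
    return out


result = reduce_dim1([[1, 2, 3], [4, 5, 6]], 0, lambda a, b: a + b)
print(result)  # [6, 15]
```

The point of the model is that the region body only ever sees scalars, which is why the 0-d tensor block arguments of mhlo.reduce have to be rewritten.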

This is a big step, so we tried to break it into pieces.

  1. Create a linalg.generic op and clone the region. In this step, the body starts with a scalar_to_0d_tensor operation that turns the scalar into a 0d tensor, then the normal HLO follows.
%3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xi32>) outs(%2 : tensor<?xi32>) {
    ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
      %4 = scalar_to_0d_tensor %arg2 : i32 to tensor<i32>
      %5 = scalar_to_0d_tensor %arg3 : i32 to tensor<i32>
      %6 = mhlo.add %4, %5 : tensor<i32>
      %7 = tensor.extract %6[] : tensor<i32>
      linalg.yield %7 : i32
    } -> tensor<?xi32>
  2. Apply the mhlo-to-scalar standard lowering to the bodies of these regions. If all operations can be converted, we get extract(scalar_to_0d_tensor(x)) pairs that canonicalize to x.
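The second step and the final fold can be sketched with a small Python model of the region body. This is an illustration only, not the actual MLIR patterns: Value, Op, lower_body and canonicalize are hypothetical names, and the lowering is assumed to extract each 0-d operand, apply the scalar op, and re-wrap the result.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Value:
    """A block argument such as %arg2."""
    name: str


@dataclass(frozen=True)
class Op:
    """An op result, identified by the op name and its operands."""
    name: str
    operands: tuple


def lower_body(expr):
    """Assumed lowering: rewrite mhlo ops on 0-d tensors to scalar ops."""
    if isinstance(expr, Value):
        return expr
    operands = tuple(lower_body(o) for o in expr.operands)
    if expr.name == "mhlo.add":
        # Extract each 0-d operand, add the scalars, re-wrap the result.
        scalars = tuple(Op("tensor.extract", (o,)) for o in operands)
        return Op("scalar_to_0d_tensor", (Op("addi", scalars),))
    return Op(expr.name, operands)


def canonicalize(expr):
    """Fold extract(scalar_to_0d_tensor(x)) to x, bottom-up."""
    if isinstance(expr, Value):
        return expr
    operands = tuple(canonicalize(o) for o in expr.operands)
    if (expr.name == "tensor.extract" and isinstance(operands[0], Op)
            and operands[0].name == "scalar_to_0d_tensor"):
        return operands[0].operands[0]
    return Op(expr.name, operands)


a, b = Value("%arg2"), Value("%arg3")
# Body produced by step 1: the value passed to linalg.yield.
step1 = Op("tensor.extract", (Op("mhlo.add", (
    Op("scalar_to_0d_tensor", (a,)),
    Op("scalar_to_0d_tensor", (b,)),
)),))
print(canonicalize(lower_body(step1)))
```

After both steps, the yielded value is a plain addi on the two block arguments, which matches the target form shown earlier.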

The issue in step one is that we are lacking a tensor_reshape-like operation. I tried tensor::CastOp, and it complains that operand type 'tensor<1xf32>' and result type 'tensor<f32>' are cast incompatible. One workaround is to cast twice, going through an unranked tensor, but I don’t think this is the way to go.

%3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xi32>) outs(%2 : tensor<?xi32>) {
    ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
      %4 = tensor.from_elements %arg2 : tensor<1xi32>
      %5 = tensor.cast %4 : tensor<1xi32> to tensor<*xi32>
      %6 = tensor.cast %5 : tensor<*xi32> to tensor<i32>
      %7 = tensor.from_elements %arg3 : tensor<1xi32>
      %8 = tensor.cast %7 : tensor<1xi32> to tensor<*xi32>
      %9 = tensor.cast %8 : tensor<*xi32> to tensor<i32>
      %10 = mhlo.add %6, %9 : tensor<i32>
      %11 = tensor.extract %10[] : tensor<i32>
      linalg.yield %11 : i32
    } -> tensor<?xi32>

The issue in step two is the order in which rewrites happen. There could be many mhlo ops in a function, and I’m not sure whether there is a guaranteed order. If we configure legality so that mhlo ops are legal inside a generic op, this works only if the framework guarantees a pre-order rewrite traversal. Otherwise, the system could rewrite the “inner” mhlo ops to Linalg generic first, because at that point they would still be inside another mhlo op and not yet inside a Linalg generic.
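To make the ordering concern concrete, here is a toy Python model of the two traversal orders. It makes no claim about what the MLIR conversion framework actually guarantees. The legality rule, the to_generic stand-in for the conversion patterns, and the placeholder op name scalar.add are all assumptions made for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    body: list = field(default_factory=list)  # ops nested in this op's region


def is_legal(op, parent):
    # Assumed rule: an mhlo op is legal only directly inside linalg.generic.
    if op.name.startswith("mhlo."):
        return parent is not None and parent.name == "linalg.generic"
    return True


def to_generic(op):
    # Stand-in for the patterns: mhlo.reduce keeps its region; any other
    # mhlo op becomes an elementwise generic around a placeholder scalar op.
    if op.name == "mhlo.reduce":
        return Op("linalg.generic", op.body)
    return Op("linalg.generic", [Op("scalar." + op.name[len("mhlo."):])])


def convert(op, parent=None, pre_order=True):
    if pre_order and not is_legal(op, parent):
        op = to_generic(op)              # rewrite the parent before its children
    op.body = [convert(child, op, pre_order) for child in op.body]
    if not pre_order and not is_legal(op, parent):
        op = to_generic(op)              # rewrite the parent after its children
    return op


def show(op):
    inner = ", ".join(show(c) for c in op.body)
    return f"{op.name}{{{inner}}}" if op.body else op.name


def make_reduce():
    return Op("mhlo.reduce", [Op("mhlo.add")])


print(show(convert(make_reduce(), pre_order=True)))
# linalg.generic{mhlo.add}
print(show(convert(make_reduce(), pre_order=False)))
# linalg.generic{linalg.generic{scalar.add}}
```

In the pre-order run, the inner mhlo.add is left alone for the later scalar lowering. In the post-order run, it is visited while its parent is still mhlo.reduce, so it is turned into a nested generic, which is the outcome described above.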

In the end, I landed the commit with the original pass, which converts the op to Linalg in one step.

These are a few of the larger issues I pointed out with the design of reduce-like ops in MHLO and MLHLO about six months ago in this thread: https://groups.google.com/a/tensorflow.org/g/mlir/c/Ip55os0xgfU/m/iI2URCzqAgAJ
The design of mlhlo reduce regions is clearly broken, and there was consensus on that. For mhlo reduce ops, which are what is relevant to this thread, if the reduction function modeled by the regions of these ops is actually meant to capture computation on scalars (elements) of tensors, we should simply be using elemental types instead of force-fitting 0-d tensors! Here is my comment from 15-Jul-2020 on that thread:

If you don’t want to stay 1:1 with “all tensor values” dialects, you could use scalar elt types, and it isn’t really an issue to do the type and op conversion when you go from XLA to mhlo one. And for all the transformations and rewrites that you want to perform on mhlo and mlhlo, will you benefit from scalar types and std ops in reduction functions or 0-d tensor ops? And if you want to unify with TCP and Linalg at some point, the latter are going to (and already) use only std ops and scalar types in their regions - not 0-d tensor wrappers.

Of course, if you get rid of 0-d tensors on the HLO blocks as well, you won’t need any conversion at all on the block (neither for types nor block signature) except for the terminator replacement FWIW.

That said, even with the current form of mhlo.reduce, I really think you don’t need any of the cast or cast-like abc_to_xyz ops to go from form A to form B that you show above. The block signature conversions, when coupled with conversions of mhlo ops on 0-d tensors to std dialect ops, will yield the desired results. These were exactly the kind of conversions I had implemented to go from mhlo to mlhlo: the region is converted to all scalar elemental types + std dialect ops in the fixed mlhlo design (these were the pending commits to upstream that were referred to in that thread). All of this works for all the reduce-like ops as well as select and scatter. You shouldn’t be converting arithmetic ops on 0-d tensors to linalg.generic ones, if that’s what you meant.

Yes, I agree that using element types (instead of 0-d tensors) would be much simpler. We wouldn’t have to rewrite the region except for the terminator, mhlo.return. We could just inline the region, get the terminator, and replace it with linalg.yield.
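As a rough illustration of how little work would remain in that case, here is a toy Python sketch. It assumes a hypothetical scalar-typed reduce region, represented as a list of (op name, operands, result) tuples that ends with the mhlo.return terminator used in the first example of this thread. The helper name inline_scalar_region is made up.

```python
def inline_scalar_region(region_ops):
    """region_ops: list of (op_name, operands, result) tuples, terminator last."""
    *body, terminator = region_ops
    name, operands, _ = terminator
    if name != "mhlo.return":
        raise ValueError("expected the region to end with mhlo.return")
    # Body ops are assumed to be scalar ops already, so they are inlined
    # unchanged; only the terminator is swapped for linalg.yield.
    return body + [("linalg.yield", operands, None)]


region = [
    ("addi", ("%arg3", "%arg4"), "%1"),
    ("mhlo.return", ("%1",), None),
]
for op in inline_scalar_region(region):
    print(op)
```

No type conversion, block signature conversion, or cast-like helper op appears anywhere in this sketch, which is the simplification being discussed.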