[TOSA][TosaToLinalg] Is it necessary to fold broadcast dims when converting to Linalg?

Hi,

I have a use case that converts TOSA binary ops to Linalg, and I found that the rank of an argument can shrink during the TosaToLinalg conversion.

For example,

Starting from:
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32>  {
...
  %2 = memref.alloc() : memref<1x8x8x8xf32>
...
  %6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
  %7 = "tosa.add"(%6, %arg2) : (tensor<1x8x8x8xf32>, tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32>
  return %7 : tensor<1x8x8x8xf32>
}
// -----// IR Dump After TosaToLinalg //----- //
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32>  {
...
  %2 = memref.alloc() : memref<1x8x8x8xf32>
...
  %6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
  %7 = linalg.init_tensor [1, 8, 8, 8] : tensor<1x8x8x8xf32>
  %8 = tensor.collapse_shape %arg2 [[0, 1, 2], [3]] : tensor<1x1x1x8xf32> into tensor<1x8xf32>
  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %8 : tensor<1x8x8x8xf32>, tensor<1x8xf32>) outs(%7 : tensor<1x8x8x8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
    %10 = arith.addf %arg3, %arg4 : f32
    linalg.yield %10 : f32
  } -> tensor<1x8x8x8xf32>
  return %9 : tensor<1x8x8x8xf32>
}

This leads to further reshapes:

// -----// IR Dump After LinalgElementwiseOpFusion //----- //
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32>  {
...
  %2 = memref.alloc() : memref<1x8x8x8xf32>
...
  %6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
  %7 = tensor.expand_shape %6 [[0, 1, 2], [3], [4], [5]] : tensor<1x8x8x8xf32> into tensor<1x1x1x8x8x8xf32>
  %8 = linalg.init_tensor [1, 1, 1, 8, 8, 8] : tensor<1x1x1x8x8x8xf32>
  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7, %arg2 : tensor<1x1x1x8x8x8xf32>, tensor<1x1x1x8xf32>) outs(%8 : tensor<1x1x1x8x8x8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
    %11 = arith.addf %arg3, %arg4 : f32
    linalg.yield %11 : f32
  } -> tensor<1x1x1x8x8x8xf32>
  %10 = tensor.collapse_shape %9 [[0, 1, 2], [3], [4], [5]] : tensor<1x1x1x8x8x8xf32> into tensor<1x8x8x8xf32>
  return %10 : tensor<1x8x8x8xf32>
}

Nothing is wrong with this, but I’m wondering whether these shape changes are necessary.
Elementwise binary operators in TOSA require their arguments to have the same rank, with dim_size=1 for the broadcast dims, so the above applies to all broadcast arguments.
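To make the folding concrete, here is a minimal Python sketch of how the broadcast dims of a same-rank operand could be grouped into a reassociation like the `[[0, 1, 2], [3]]` seen in the TosaToLinalg dump above. The helper name `fold_broadcast_dims` and its grouping rule are assumptions for illustration only; this is not the TosaToLinalg implementation.

```python
def fold_broadcast_dims(operand_shape, result_shape):
    """Hypothetical model of folding broadcast dims out of an operand.

    A dim is treated as a broadcast dim when the operand has size 1 there
    and the result does not. Each broadcast dim is merged into the group of
    the preceding kept dim (or the next kept dim if none precedes it).
    """
    assert len(operand_shape) == len(result_shape), "TOSA requires equal ranks"
    groups = []   # reassociation groups, as used by tensor.collapse_shape
    kept = []     # loop dims that remain in the operand's indexing map
    pending = []  # broadcast dims seen before the first kept dim
    for d, (o, r) in enumerate(zip(operand_shape, result_shape)):
        if o == 1 and r != 1:
            if groups:
                groups[-1].append(d)
            else:
                pending.append(d)
        else:
            groups.append(pending + [d])
            pending = []
            kept.append(d)
    collapsed_shape = [operand_shape[d] for d in kept]
    return groups, kept, collapsed_shape


groups, kept, shape = fold_broadcast_dims([1, 1, 1, 8], [1, 8, 8, 8])
print(groups)  # [[0, 1, 2], [3]]
print(kept)    # [0, 3], i.e. the map (d0, d1, d2, d3) -> (d0, d3)
print(shape)   # [1, 8]
```

Under this model, the leading dim survives because the result also has size 1 there, so it is not a broadcast dim.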

Instead, can we keep the original rank by mapping the broadcast dims to 0, as below?

// -----// IR Dump After TosaToLinalg //----- //
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32> {
...
  %2 = memref.alloc() : memref<1x8x8x8xf32>
...
  %6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
  %7 = linalg.init_tensor [1, 8, 8, 8] : tensor<1x8x8x8xf32>
  %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %arg2 : tensor<1x8x8x8xf32>, tensor<1x1x1x8xf32>) outs(%7 : tensor<1x8x8x8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
    %9 = arith.addf %arg3, %arg4 : f32
    linalg.yield %9 : f32
  } -> tensor<1x8x8x8xf32>
  return %8 : tensor<1x8x8x8xf32>
}
// -----// IR Dump After LinalgElementwiseOpFusion //----- //
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32> {
...
  %2 = memref.alloc() : memref<1x8x8x8xf32>
...
  %6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
  %7 = linalg.init_tensor [1, 8, 8, 8] : tensor<1x8x8x8xf32>
  %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %arg2 : tensor<1x8x8x8xf32>, tensor<1x1x1x8xf32>) outs(%7 : tensor<1x8x8x8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
    %9 = arith.addf %arg3, %arg4 : f32
    linalg.yield %9 : f32
  } -> tensor<1x8x8x8xf32>
  return %8 : tensor<1x8x8x8xf32>
}

So we don’t need any change in the shape or extra operations to handle the change.
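As a sanity check on the claim that both forms compute the same thing, the following Python sketch emulates the elementwise add with each indexing map. The function `generic_add` and the dict-based tensors are illustrative assumptions, not MLIR APIs.

```python
from itertools import product


def generic_add(lhs, rhs, shape, rhs_index):
    """Emulate an elementwise add over the iteration space `shape`.

    `lhs` and the output use the identity map; `rhs` is read through
    `rhs_index`, which plays the role of the operand's affine_map.
    """
    out = {}
    for idx in product(*(range(n) for n in shape)):
        out[idx] = lhs[idx] + rhs[rhs_index(*idx)]
    return out


shape = (1, 8, 8, 8)
lhs = {idx: float(sum(idx)) for idx in product(*(range(n) for n in shape))}
bias = [float(10 * i) for i in range(8)]
rhs_4d = {(0, 0, 0, i): bias[i] for i in range(8)}  # models tensor<1x1x1x8xf32>
rhs_2d = {(0, i): bias[i] for i in range(8)}        # models tensor<1x8xf32>

# Same-rank operand with broadcast dims mapped to constant 0.
same_rank = generic_add(lhs, rhs_4d, shape, lambda d0, d1, d2, d3: (d0, 0, 0, d3))
# Collapsed operand with the broadcast dims folded away.
folded = generic_add(lhs, rhs_2d, shape, lambda d0, d1, d2, d3: (d0, d3))

assert same_rank == folded
```

The two variants differ only in how the broadcast operand is addressed, which is why keeping the rank needs no extra reshape ops.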

I don’t know much about the decisions made at the TOSA-to-Linalg level, but at the elementwise op fusion level the goal is to operate on higher-dimensional operations, which uncovers more opportunities for fusion. I can answer specific questions about this, but this discourse post has more details of the current state of things.

One of the things that causes a lot of pain is the rank-1 broadcasting semantics. This pass handles dropping unit-dimensions and represents broadcast using a more canonical representation at the Linalg level. For example, in the case above, instead of

 %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %arg2 : tensor<1x8x8x8xf32>, tensor<1x1x1x8xf32>) outs(%7 : tensor<1x8x8x8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
    %9 = arith.addf %arg3, %arg4 : f32
    linalg.yield %9 : f32
  } -> tensor<1x8x8x8xf32>

A more natural representation is

 %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %arg2 : tensor<1x8x8x8xf32>, tensor<1x8xf32>) outs(%7 : tensor<1x8x8x8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
    %9 = arith.addf %arg3, %arg4 : f32
    linalg.yield %9 : f32
  } -> tensor<1x8x8x8xf32>

This transformation is what the DropUnitDims pass achieves.
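The rewrite between the two forms above can be pictured with a small Python sketch that removes the constant-0 results from an indexing map together with the matching unit dims of the operand. The helper `drop_constant_zero_results` is hypothetical and covers only the step shown in this example; it does not reproduce everything the actual DropUnitDims pass does.

```python
def drop_constant_zero_results(map_results, operand_shape):
    """Drop map results that are the constant 0, plus their unit dims.

    `map_results` lists one entry per operand dim: either a loop-dim name
    such as "d0" or the integer 0 for a broadcast (constant-indexed) dim.
    """
    new_results, new_shape = [], []
    for expr, size in zip(map_results, operand_shape):
        if expr == 0:
            # A constant-0 index is only valid on a dim of extent 1.
            assert size == 1, "constant-0 index on a non-unit dim"
            continue
        new_results.append(expr)
        new_shape.append(size)
    return new_results, new_shape


results, shape = drop_constant_zero_results(["d0", 0, 0, "d3"], [1, 1, 1, 8])
print(results)  # ['d0', 'd3']
print(shape)    # [1, 8]
```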

Thanks for the comment @MaheshRavishankar, that really helps.

To explain a bit more about the use case:
We’re fusing linalg.generic with a custom conv2d operation and running into problems with the various patterns on the in/out arguments.
DropUnitDims looks very helpful for canonicalizing the in/out pattern around the linalg.generic as well; I’ll try using it before aligning the linalg.generic to the custom fusion.

For now, I’ll send a patch to preserve the same rank in TosaToLinalg as above, which doesn’t affect the result of DropUnitDims. TOSA broadcast requires same-ranked input arguments as its canonical form, and keeping that also helps further down the lowering path.

Some more details would help. Broadly, fusing linalg.generic with conv2d-like operations falls outside the purview of elementwise op fusion. These should be fused using tile + fuse. I can give you more pointers if this is of interest to you.

Sure, that’d be very helpful. For now, we use a custom tile and reconfigure the linalg.generic to align into that tile. (I don’t think we’ll cover the fusion this time but hopefully you can find some related information in tomorrow’s ODM talk - Coordinate Transformations in AMD’s rocMLIR)

Our fusion is still in an early phase and we want to use more of the standard Linalg fusion in the future. More information would help us better understand how upstream code expects fusion to work and how our project can contribute back. Thanks.

This method (llvm-project/TileUsingInterface.h at 92233159035d1b50face95d886901cf99035bd99 · llvm/llvm-project · GitHub) does tile and fuse of operations. It is geared to handle things like fusing a linalg.matmul producer with a linalg.generic consumer that performs elementwise operations. See some examples here (llvm-project/tile-and-fuse-using-interface.mlir at main · llvm/llvm-project · GitHub). The same method should also work for fusing a linalg.conv producer with consumer linalg.generic elementwise ops. There are no lit tests explicitly testing these, but they do work (they are tested downstream). If you run into issues, please report them and I can take a look.
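For readers new to the idea, here is a toy Python model of tile + fuse for a matmul producer feeding an elementwise consumer. It only illustrates the loop structure; the function names and tile sizes are assumptions, and it does not use or mirror the TileUsingInterface API.

```python
def unfused(a, b, bias):
    """Materialize the full matmul result, then run the elementwise add."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
         for i in range(m)]
    return [[c[i][j] + bias[j] for j in range(n)] for i in range(m)]


def tiled_and_fused(a, b, bias, tile_m=2, tile_n=2):
    """Tile the consumer and compute the producer tile inside the same loop."""
    m, k, n = len(a), len(b), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for i0 in range(0, m, tile_m):
        for j0 in range(0, n, tile_n):
            rows = range(i0, min(i0 + tile_m, m))
            cols = range(j0, min(j0 + tile_n, n))
            # Producer: only the matmul tile that this consumer tile needs.
            tile = {(i, j): sum(a[i][p] * b[p][j] for p in range(k))
                    for i in rows for j in cols}
            # Consumer: elementwise add applied directly to the tile.
            for (i, j), value in tile.items():
                out[i][j] = value + bias[j]
    return out


a = [[float(i + p) for p in range(3)] for i in range(5)]
b = [[float(p * j + 1) for j in range(4)] for p in range(3)]
bias = [0.5, 1.5, 2.5, 3.5]
assert unfused(a, b, bias) == tiled_and_fused(a, b, bias)
```

In the fused version the intermediate matmul result only ever exists one tile at a time, which is the property that tile + fuse is after.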