Motivation and use cases of the FoldWithProducerReshapeOpByExpansion folding pattern in LinalgFusionOfTensorOps

I see some folding patterns in Linalg's FusionOnTensors.cpp. As I understand it, the linalg.tensor_expand_shape reshape that expands the output tensor (3D -> 6D) is folded away, i.e. the reshapes are moved above the linalg.generic so that they expand the input tensors rather than the output tensor, and the linalg.generic then operates on the higher-dimensional tensors.

I want to understand the motivation and use cases for this expansion-based folding, and how it helps hardware and networks perform better.

// *** IR Dump Before LinalgFusionOfTensorOps ***
#map0 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
module  {
 func @reshape_as_consumer_permutation(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x2x?x3x4x?xf32> {
   %0 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?xf32>) outs(%arg0 : tensor<?x?x?xf32>) {
   ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
     %2 = addf %arg2, %arg3 : f32
     linalg.yield %2 : f32
   } -> tensor<?x?x?xf32>
   %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
   return %1 : tensor<?x2x?x3x4x?xf32>
 }
}

// *** IR Dump After LinalgFusionOfTensorOps ***
#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
module  {
 func @reshape_as_consumer_permutation(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x2x?x3x4x?xf32> {
   %c0 = constant 0 : index
   %c2 = constant 2 : index
   %c5 = constant 5 : index
   %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2], [3, 4], [5]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
   %1 = linalg.tensor_expand_shape %arg1 [[0, 1, 2], [3]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
   %2 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
   %3 = memref.dim %2, %c0 : tensor<?x2x?x3x4x?xf32>
   %4 = memref.dim %2, %c2 : tensor<?x2x?x3x4x?xf32>
   %5 = memref.dim %2, %c5 : tensor<?x2x?x3x4x?xf32>
   %6 = linalg.init_tensor [%3, 2, %4, 3, 4, %5] : tensor<?x2x?x3x4x?xf32>
   %7 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0, %1 : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>) outs(%6 : tensor<?x2x?x3x4x?xf32>) {
   ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
     %8 = addf %arg2, %arg3 : f32
     linalg.yield %8 : f32
   } -> tensor<?x2x?x3x4x?xf32>
   return %7 : tensor<?x2x?x3x4x?xf32>
 }
}

mlir-opt -print-ir-before-all -print-ir-after-all -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" consumer_reshape.mlir

It would be hard to answer how this specific approach “helps HW perform better”. In the progressive lowering, this is very high up in the translation. This is meant to assist in fusing operations operating at tensor level into a “single operation”. The assumption being this single operation will eventually be executed in a “fused” manner, i.e. in a single CUDA kernel for example.

To give a specific example of where this helps, consider the following:

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @fusion_example(
    %arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %d0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
  %init1 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
  %0 = linalg.generic {
      iterator_types = ["parallel", "parallel"],
      indexing_maps = [#map0, #map0, #map0]}
      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%init1 : tensor<?x?xf32>) {
      ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
        %add = addf %arg3, %arg4 : f32
        linalg.yield %add : f32
    } -> tensor<?x?xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1], [2]] :
       tensor<?x?xf32> into tensor<?x?x?xf32>
  %2 = memref.dim %arg2, %c0 : tensor<?x?x?xf32>
  %3 = memref.dim %arg2, %c1 : tensor<?x?x?xf32>
  %4 = memref.dim %arg2, %c2 : tensor<?x?x?xf32>
  %5 = linalg.init_tensor [%2, %3, %4] : tensor<?x?x?xf32>
  %6 = linalg.generic {
      iterator_types = ["parallel", "parallel", "parallel"],
      indexing_maps = [#map1, #map1, #map1]}
      ins(%1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
      outs(%5 : tensor<?x?x?xf32>) {
      ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
        %add = addf %arg3, %arg4 : f32
        linalg.yield %add : f32
    } -> tensor<?x?x?xf32>
   return %6 : tensor<?x?x?xf32>
}

One way to fuse this is to increase the dimensionality of the first linalg.generic. This would result in indexing maps of the form

affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1, d2)>

for the 2D operands. This is undesirable: keeping the indexing maps invertible (pure permutations of the loop dimensions, with no expressions like d0 * s0 + d1) is essential for composing with other transformations (like fusion, and possibly tiling as well). To avoid this, you can instead increase the dimensionality of the tensor operands themselves, reshaping each 2D tensor into a 3D tensor, so that all indexing maps are of the form

affine_map<(d0, d1, d2) -> (d0, d1, d2)>

At the tensor level this is a perfectly valid transformation. With it, both linalg.generic ops become 3D, the intermediate linalg.tensor_expand_shape becomes dead, and the two operations can fuse. This does introduce reshapes on the operands of the first linalg.generic, but if you look at what the transformation did, it moved the linalg.tensor_expand_shape from in between the two operations to above the first operation (allowing the two operations to fuse). Applied in a fixed-point manner, you could expect the linalg.tensor_expand_shape ops to propagate all the way to the “top” of the function.
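For the example above, the fused result could look roughly like the following. This is only an illustrative sketch, not actual pass output: the SSA names (%e0, %e1, %init, %fused) are made up, and the dim computations feeding the init_tensor are elided.

  %e0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]] :
        tensor<?x?xf32> into tensor<?x?x?xf32>
  %e1 = linalg.tensor_expand_shape %arg1 [[0, 1], [2]] :
        tensor<?x?xf32> into tensor<?x?x?xf32>
  // %init is built from the dims of %arg2, as before.
  %fused = linalg.generic {
      iterator_types = ["parallel", "parallel", "parallel"],
      indexing_maps = [#map1, #map1, #map1, #map1]}
      ins(%e0, %e1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
      outs(%init : tensor<?x?x?xf32>) {
      ^bb0(%a : f32, %b : f32, %c : f32, %out : f32):
        %s0 = addf %a, %b : f32
        %s1 = addf %s0, %c : f32
        linalg.yield %s1 : f32
    } -> tensor<?x?x?xf32>

The reshapes now sit on the function arguments, and only a single 3D linalg.generic remains.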

(Conversely, the same reasoning introduces linalg.tensor_collapse_shape at the results, but a fixed-point iteration propagates these all the way to the “bottom” of the function).

So this is all theory, but it does work well in practice: it is used in IREE's compilation flow and produces superior fusion. The remaining thing to handle is the reshapes at the boundaries. The reshapes (i.e. linalg.tensor_expand_shape/linalg.tensor_collapse_shape) either split consecutive dimensions or collapse consecutive dimensions. When you bufferize these onto a linearized memory representation, the expand and collapse become pure metadata changes, effectively no-ops.
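For instance, after bufferization a reshape of a contiguous buffer only changes the type/layout metadata; no data is copied. A sketch with static shapes, using the memref counterpart of the reshape ops (%buf stands for some contiguous buffer; the exact op syntax depends on the MLIR version):

  // Reinterpret a contiguous 6x10 buffer as 2x3x10: same underlying
  // memory, only the shape/stride metadata changes.
  %view = memref.expand_shape %buf [[0, 1], [2]] :
          memref<6x10xf32> into memref<2x3x10xf32>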

I hope that conveyed some of the reasoning behind this, but I am happy to explain more.

Thanks Mahesh,
I can see that the above example use case justifies the need for and motivation behind folding the reshapes in both directions.
Regarding “applied in a fixed-point manner / fixed-point iteration”: what exactly does that mean? Is it a natural progression, i.e. once the desired condition is met and the pattern is triggered, the expand keeps moving up recursively, and similarly for the collapse?

Just an observation: when this reshape expand %1 was inserted to couple the 2D output with the 3D input, wouldn't it have been possible to decide right there whether the two ops were fusible and initiate the fold in place? Or is that just a design choice made by MLIR?

%1 = linalg.tensor_expand_shape %0 [[0, 1], [2]] : tensor<?x?xf32> into tensor<?x?x?xf32>

By fixed-point iteration I meant that the same pattern needs to be applied repeatedly until it can't be applied again. I am not sure I fully understand the question above.
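Schematically (not actual IR), one round of the pattern plus fusion looks like this, and the process repeats until no reshape sits between two fusable generics:

  generic_A -> expand_shape -> generic_B      // pattern fires: reshape is pushed above generic_A
  expand_shape(s) -> generic_A' -> generic_B  // generic_A' and generic_B now fuse
  expand_shape(s) -> fused_generic            // pattern no longer applies: fixed point reached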

That is what you would get as “input”. The goal is to get a more “canonical” representation of the program by fusing trivially fusable operations as much as possible. Reshapes interspersed in the program interfere with that fusion; this pass is essentially taking the reshapes out of the way.