I see some fold patterns in Linalg's FusionOnTensors.cpp, and as I understand it, the linalg.tensor_expand_shape reshape that expands the output tensor (3D -> 6D) is folded into its producer: the reshapes are moved above the linalg.generic so that the input tensors are expanded instead of the output tensor. As a result, the linalg.generic operates on tensors of higher rank.
I want to understand the motivation for and uses of this tensor-expand folding pattern, and how it helps hardware and networks perform better.
// *** IR Dump Before LinalgFusionOfTensorOps ***
#map0 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
module {
func @reshape_as_consumer_permutation(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x2x?x3x4x?xf32> {
%0 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?xf32>) outs(%arg0 : tensor<?x?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
%2 = addf %arg2, %arg3 : f32
linalg.yield %2 : f32
} -> tensor<?x?x?xf32>
%1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
return %1 : tensor<?x2x?x3x4x?xf32>
}
}
// *** IR Dump After LinalgFusionOfTensorOps ***
#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
module {
func @reshape_as_consumer_permutation(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x2x?x3x4x?xf32> {
%c0 = constant 0 : index
%c2 = constant 2 : index
%c5 = constant 5 : index
%0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2], [3, 4], [5]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
%1 = linalg.tensor_expand_shape %arg1 [[0, 1, 2], [3]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
%2 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
%3 = memref.dim %2, %c0 : tensor<?x2x?x3x4x?xf32>
%4 = memref.dim %2, %c2 : tensor<?x2x?x3x4x?xf32>
%5 = memref.dim %2, %c5 : tensor<?x2x?x3x4x?xf32>
%6 = linalg.init_tensor [%3, 2, %4, 3, 4, %5] : tensor<?x2x?x3x4x?xf32>
%7 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0, %1 : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>) outs(%6 : tensor<?x2x?x3x4x?xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
%8 = addf %arg2, %arg3 : f32
linalg.yield %8 : f32
} -> tensor<?x2x?x3x4x?xf32>
return %7 : tensor<?x2x?x3x4x?xf32>
}
}
mlir-opt -print-ir-before-all -print-ir-after-all -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" consumer_reshape.mlir