Hi,
I have a use case converting TOSA binary ops to Linalg, and I found that the rank of an argument can get shrunk in the TosaToLinalg conversion.
For example:
- Starting from:
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32> {
...
%2 = memref.alloc() : memref<1x8x8x8xf32>
...
%6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
%7 = "tosa.add"(%6, %arg2) : (tensor<1x8x8x8xf32>, tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32>
return %7 : tensor<1x8x8x8xf32>
}
- // -----// IR Dump After TosaToLinalg //----- //
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32> {
...
%2 = memref.alloc() : memref<1x8x8x8xf32>
...
%6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
%7 = linalg.init_tensor [1, 8, 8, 8] : tensor<1x8x8x8xf32>
%8 = tensor.collapse_shape %arg2 [[0, 1, 2], [3]] : tensor<1x1x1x8xf32> into tensor<1x8xf32>
%9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %8 : tensor<1x8x8x8xf32>, tensor<1x8xf32>) outs(%7 : tensor<1x8x8x8xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%10 = arith.addf %arg3, %arg4 : f32
linalg.yield %10 : f32
} -> tensor<1x8x8x8xf32>
return %9 : tensor<1x8x8x8xf32>
}
This leads to further reshapes later on:
- // -----// IR Dump After LinalgElementwiseOpFusion //----- //
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32> {
...
%2 = memref.alloc() : memref<1x8x8x8xf32>
...
%6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
%7 = tensor.expand_shape %6 [[0, 1, 2], [3], [4], [5]] : tensor<1x8x8x8xf32> into tensor<1x1x1x8x8x8xf32>
%8 = linalg.init_tensor [1, 1, 1, 8, 8, 8] : tensor<1x1x1x8x8x8xf32>
%9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7, %arg2 : tensor<1x1x1x8x8x8xf32>, tensor<1x1x1x8xf32>) outs(%8 : tensor<1x1x1x8x8x8xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%11 = arith.addf %arg3, %arg4 : f32
linalg.yield %11 : f32
} -> tensor<1x1x1x8x8x8xf32>
%10 = tensor.collapse_shape %9 [[0, 1, 2], [3], [4], [5]] : tensor<1x1x1x8x8x8xf32> into tensor<1x8x8x8xf32>
return %10 : tensor<1x8x8x8xf32>
}
Nothing is wrong there, but I'm wondering whether these shape changes are necessary.
TOSA requires the arguments of elementwise binary operators to have the same rank, with dim size 1 for the broadcast dims, so the above applies to every broadcast argument.
Instead, can we keep the original rank by mapping the broadcast dims to the constant 0, as below?
- // -----// IR Dump After TosaToLinalg //----- //
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32> {
...
%2 = memref.alloc() : memref<1x8x8x8xf32>
...
%6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
%7 = linalg.init_tensor [1, 8, 8, 8] : tensor<1x8x8x8xf32>
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %arg2 : tensor<1x8x8x8xf32>, tensor<1x1x1x8xf32>) outs(%7 : tensor<1x8x8x8xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%9 = arith.addf %arg3, %arg4 : f32
linalg.yield %9 : f32
} -> tensor<1x8x8x8xf32>
return %8 : tensor<1x8x8x8xf32>
}
- // -----// IR Dump After LinalgElementwiseOpFusion //----- //
func.func @test_fusion(%arg0: tensor<1x8x8x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<1x1x1x8xf32>) -> tensor<1x8x8x8xf32> {
...
%2 = memref.alloc() : memref<1x8x8x8xf32>
...
%6 = bufferization.to_tensor %2 : memref<1x8x8x8xf32>
%7 = linalg.init_tensor [1, 8, 8, 8] : tensor<1x8x8x8xf32>
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %arg2 : tensor<1x8x8x8xf32>, tensor<1x1x1x8xf32>) outs(%7 : tensor<1x8x8x8xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%9 = arith.addf %arg3, %arg4 : f32
linalg.yield %9 : f32
} -> tensor<1x8x8x8xf32>
return %8 : tensor<1x8x8x8xf32>
}
So we wouldn't need any shape changes or extra operations to handle the broadcast.
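For reference, here is a minimal sketch of how the lowering could build such a rank-preserving indexing map. This is hypothetical illustration code, not the current TosaToLinalg implementation; operandTy, rank, and rewriter are assumed to be in scope inside the conversion pattern:

// Hypothetical sketch: build an indexing map that keeps the operand's
// original rank and pins each broadcast dim (static size 1) to the
// constant index 0; non-broadcast dims use the matching loop dimension.
SmallVector<AffineExpr> exprs;
for (auto it : llvm::enumerate(operandTy.getShape())) {
  exprs.push_back(it.value() == 1
                      ? rewriter.getAffineConstantExpr(0)
                      : rewriter.getAffineDimExpr(it.index()));
}
auto map = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, exprs,
                          rewriter.getContext());

Since the operand rank already matches the loop nest rank, this would avoid emitting the tensor.collapse_shape/tensor.expand_shape pairs around the linalg.generic entirely.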