I have a question about approaches to fusing only specific loop dimensions in Linalg. Any pointers are appreciated. For example, in the code snippet below, I’d like to fuse only the outermost dimension (the one of size 2), but not the remaining dimensions. What would be the recommended way of doing this in Linalg, if it is supported? One possibility I see is to tile-and-fuse that dimension with a tile size of 1. But driving fusion through tiling does not feel very natural, so I want to understand whether there is a better approach for this in Linalg’s design. Thanks.
- Before fusion
// Indexing maps used by the linalg.generic ops in @test_before_fusion.
// #map0: identity over a 2-D iteration space.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1: identity over a 3-D iteration space.
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// #map2: 3-D iteration space projected onto (d0, d1); an operand indexed with
// this map is read with the same element for every value of d2.
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Unfused form: three elementwise linalg.generic ops, each operating on whole
// tensors. Computes (%arg0 + %arg1) * (%arg2 - %arg3), where the 2x512 sum is
// broadcast along the last dimension of the 2x512x512 difference.
func @test_before_fusion(%arg0: tensor<2x512xbf16>, %arg1: tensor<2x512xbf16>, %arg2: tensor<2x512x512xbf16>, %arg3: tensor<2x512x512xbf16>) -> tensor<2x512x512xbf16> {
// Output tensor for the 2-D addition.
%0 = linalg.init_tensor [2, 512] : tensor<2x512xbf16>
// %1 = %arg0 + %arg1, elementwise over (d0, d1).
%1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<2x512xbf16>, tensor<2x512xbf16>) outs(%0 :tensor<2x512xbf16>) {
^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
%6 = arith.addf %arg4, %arg5 : bf16
linalg.yield %6 : bf16
} -> tensor<2x512xbf16>
// Output tensor for the 3-D subtraction.
%2 = linalg.init_tensor [2, 512, 512] : tensor<2x512x512xbf16>
// %3 = %arg2 - %arg3, elementwise over (d0, d1, d2).
%3 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2, %arg3 : tensor<2x512x512xbf16>,tensor<2x512x512xbf16>) outs(%2 : tensor<2x512x512xbf16>) {
^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
%6 = arith.subf %arg4, %arg5 : bf16
linalg.yield %6 : bf16
} -> tensor<2x512x512xbf16>
// Output tensor for the final multiplication.
%4 = linalg.init_tensor [2, 512, 512] : tensor<2x512x512xbf16>
// %5 = %1 * %3. %1 is indexed with #map2, so it is broadcast along d2; this op
// consumes the results of both generics above and shares d0 with each of them.
%5 = linalg.generic {indexing_maps = [#map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %3 : tensor<2x512xbf16>, tensor<2x512x512xbf16>) outs(%4: tensor<2x512x512xbf16>) {
^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
%6 = arith.mulf %arg4, %arg5 : bf16
linalg.yield %6 : bf16
} -> tensor<2x512x512xbf16>
return %5 : tensor<2x512x512xbf16>
}
- After fusion
// Indexing maps used by the linalg.generic ops in @test_after_fusion. The alias
// names are reused from the "before" snippet but each map has one dimension
// fewer, because the outermost dimension is now handled by the scf.for loop.
// #map0: identity over a 1-D iteration space.
#map0 = affine_map<(d0) -> (d0)>
// #map1: identity over a 2-D iteration space.
#map1 = affine_map<(d0, d1) -> (d0, d1)>
// #map2: 2-D iteration space projected onto d0; an operand indexed with this
// map is read with the same element for every value of d1.
#map2 = affine_map<(d0, d1) -> (d0)>
// Desired fused form: a single scf.for over the outermost dimension (size 2)
// whose body contains all three elementwise ops, each operating on the slice
// selected by the loop index. The inner dimensions are not fused.
func @test_after_fusion(%arg0: tensor<2x512xbf16>, %arg1: tensor<2x512xbf16>, %arg2: tensor<2x512x512xbf16>, %arg3: tensor<2x512x512xbf16>) -> tensor<2x512x512xbf16> {
// Full-size result tensor, zero-filled and used as the initial loop-carried value.
// NOTE(review): each of the two outer slices is overwritten by the insert_slice
// in the loop body, so the fill looks redundant — confirm before removing.
%0 = linalg.init_tensor [2, 512, 512] : tensor<2x512x512xbf16>
%cst = arith.constant 0.000000e+00 : bf16
%1 = linalg.fill(%cst, %0) : bf16, tensor<2x512x512xbf16> -> tensor<2x512x512xbf16>
// Loop bounds: %arg4 iterates over 0 and 1 (lower bound 0, upper bound 2, step 1).
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
// %arg4 is the index into the outermost dimension; %arg5 carries the result
// tensor accumulated so far.
%2 = scf.for %arg4 = %c0 to %c2 step %c1 iter_args(%arg5 = %1) -> (tensor<2x512x512xbf16>) {
// Rank-reducing slices: row %arg4 of each 2x512 input, as a 512-element tensor.
%3 = tensor.extract_slice %arg0[%arg4, 0] [1, 512] [1, 1] : tensor<2x512xbf16> to tensor<512xbf16>
%4 = tensor.extract_slice %arg1[%arg4, 0] [1, 512] [1, 1] : tensor<2x512xbf16> to tensor<512xbf16>
%5 = linalg.init_tensor [512] : tensor<512xbf16>
// %6 = %3 + %4, elementwise over the 512-element slice.
%6 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} ins(%3, %4 : tensor<512xbf16>, tensor<512xbf16>) outs(%5 : tensor<512xbf16>) {
^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
%14 = arith.addf %arg6, %arg7 : bf16
linalg.yield %14 : bf16
} -> tensor<512xbf16>
// Rank-reducing slices: 512x512 plane %arg4 of each 2x512x512 input.
%7 = tensor.extract_slice %arg2[%arg4, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<2x512x512xbf16> to tensor<512x512xbf16>
%8 = tensor.extract_slice %arg3[%arg4, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<2x512x512xbf16> to tensor<512x512xbf16>
%9 = linalg.init_tensor [512, 512] : tensor<512x512xbf16>
// %10 = %7 - %8, elementwise over the 512x512 slice.
%10 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%7, %8 : tensor<512x512xbf16>, tensor<512x512xbf16>) outs(%9 : tensor<512x512xbf16>) {
^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
%14 = arith.subf %arg6, %arg7 : bf16
linalg.yield %14 : bf16
} -> tensor<512x512xbf16>
%11 = linalg.init_tensor [512, 512] : tensor<512x512xbf16>
// %12 = %6 * %10. %6 is indexed with #map2, so it is broadcast along the last
// dimension of the 512x512 slice.
%12 = linalg.generic {indexing_maps = [#map2, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%6, %10 : tensor<512xbf16>, tensor<512x512xbf16>) outs(%11 : tensor<512x512xbf16>) {
^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
%14 = arith.mulf %arg6, %arg7 : bf16
linalg.yield %14 : bf16
} -> tensor<512x512xbf16>
// Write the 512x512 product into plane %arg4 of the loop-carried result and
// pass the updated tensor to the next iteration.
%13 = tensor.insert_slice %12 into %arg5[%arg4, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<512x512xbf16> into tensor<2x512x512xbf16>
scf.yield %13 : tensor<2x512x512xbf16>
}
return %2 : tensor<2x512x512xbf16>
}