I am trying to use the transform dialect to fuse a separable convolution that is implemented as two linalg.generic ops.
Some example IR below, where I have two such convolutions one after the other for a total of four generics.
I added an attribute (you can see it in the IR below) to the first generic of each pair, assuming it would help me match them, but now I can't see a way to actually use it.
// Indexing maps for a 6-D iteration space: 4 parallel dims (d0..d3) plus
// 2 reduction dims (d4, d5).
// #map  : input read with a sliding window -- (d0, d1 + d4, d2 + d5, d3).
// #map1 : kernel read at the reduction dims -- (d4, d5, d3).
// #map2 : output written at the parallel dims -- (d0, d1, d2, d3).
// NOTE(review): layout is presumably (batch, height, width, channel) given the
// 1x540x960x1 shapes -- confirm with the producer of this IR.
#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
module @jit_execute_model attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
// Public entry point: delegates to @compiled_predict_step.
func.func public @main(%arg0: tensor<1x540x960x1xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<1x534x954x1xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = call @compiled_predict_step(%arg0) : (tensor<1x540x960x1xf32>) -> tensor<1x534x954x1xf32>
return %0 : tensor<1x534x954x1xf32>
}
// Two separable convolutions applied back-to-back, each lowered to a pair of
// 1-D linalg.generic convolutions (one per spatial axis). The first generic of
// each pair carries attrs = {separable_blur = true}.
func.func private @compiled_predict_step(%arg0: tensor<1x540x960x1xf32>) -> tensor<1x534x954x1xf32> {
// Kernels: %cst is 5x1x1 (second axis of pair 2), %cst_0 is 1x5x1 (first axis
// of pair 2), %cst_1 is 1x3x1 (first axis of pair 1), %cst_2 is 3x1x1
// (second axis of pair 1).
%cst = arith.constant dense<[[[1.000000e+00]], [[2.000000e+00]], [[3.000000e+00]], [[4.000000e+00]], [[5.000000e+00]]]> : tensor<5x1x1xf32>
%cst_0 = arith.constant dense<[[[1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00], [5.000000e+00]]]> : tensor<1x5x1xf32>
%cst_1 = arith.constant dense<[[[1.000000e+00], [2.000000e+00], [3.000000e+00]]]> : tensor<1x3x1xf32>
%cst_2 = arith.constant dense<[[[1.000000e+00]], [[2.000000e+00]], [[3.000000e+00]]]> : tensor<3x1x1xf32>
%cst_3 = arith.constant 0.000000e+00 : f32
// Pair 1, stage 1: 1x3 kernel, 960 -> 958 along the third dim.
// Accumulator is zero-initialized by linalg.fill before the reduction.
// Tagged with separable_blur = true for transform-dialect matching.
%0 = tensor.empty() : tensor<1x540x958x1xf32>
%1 = linalg.fill ins(%cst_3 : f32) outs(%0 : tensor<1x540x958x1xf32>) -> tensor<1x540x958x1xf32>
%2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %cst_1 : tensor<1x540x960x1xf32>, tensor<1x3x1xf32>) outs(%1 : tensor<1x540x958x1xf32>) attrs = {separable_blur = true} {
^bb0(%in: f32, %in_4: f32, %out: f32):
// Multiply-accumulate body shared by all four generics.
%12 = arith.mulf %in, %in_4 : f32
%13 = arith.addf %12, %out : f32
linalg.yield %13 : f32
} -> tensor<1x540x958x1xf32>
// Pair 1, stage 2: 3x1 kernel, 540 -> 538 along the second dim. Untagged.
%3 = tensor.empty() : tensor<1x538x958x1xf32>
%4 = linalg.fill ins(%cst_3 : f32) outs(%3 : tensor<1x538x958x1xf32>) -> tensor<1x538x958x1xf32>
%5 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%2, %cst_2 : tensor<1x540x958x1xf32>, tensor<3x1x1xf32>) outs(%4 : tensor<1x538x958x1xf32>) {
^bb0(%in: f32, %in_4: f32, %out: f32):
%12 = arith.mulf %in, %in_4 : f32
%13 = arith.addf %12, %out : f32
linalg.yield %13 : f32
} -> tensor<1x538x958x1xf32>
// Pair 2, stage 1: 1x5 kernel, 958 -> 954 along the third dim.
// Tagged with separable_blur = true.
%6 = tensor.empty() : tensor<1x538x954x1xf32>
%7 = linalg.fill ins(%cst_3 : f32) outs(%6 : tensor<1x538x954x1xf32>) -> tensor<1x538x954x1xf32>
%8 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%5, %cst_0 : tensor<1x538x958x1xf32>, tensor<1x5x1xf32>) outs(%7 : tensor<1x538x954x1xf32>) attrs = {separable_blur = true} {
^bb0(%in: f32, %in_4: f32, %out: f32):
%12 = arith.mulf %in, %in_4 : f32
%13 = arith.addf %12, %out : f32
linalg.yield %13 : f32
} -> tensor<1x538x954x1xf32>
// Pair 2, stage 2: 5x1 kernel, 538 -> 534 along the second dim. Untagged.
%9 = tensor.empty() : tensor<1x534x954x1xf32>
%10 = linalg.fill ins(%cst_3 : f32) outs(%9 : tensor<1x534x954x1xf32>) -> tensor<1x534x954x1xf32>
%11 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%8, %cst : tensor<1x538x954x1xf32>, tensor<5x1x1xf32>) outs(%10 : tensor<1x534x954x1xf32>) {
^bb0(%in: f32, %in_4: f32, %out: f32):
%12 = arith.mulf %in, %in_4 : f32
%13 = arith.addf %12, %out : f32
linalg.yield %13 : f32
} -> tensor<1x534x954x1xf32>
return %11 : tensor<1x534x954x1xf32>
}
}
Before I tried to integrate this into a larger pipeline, I was matching a simpler example that contained only a single pair of generics, like this:
// Transform script: select only the two linalg.generic ops that carry the
// custom unit-valued marker attribute, then split them into individual
// handles (one per convolution pair).
transform.sequence failures(propagate) {
^bb1(%variant_op: !transform.any_op):
  // Filter by the custom attribute: of the four linalg.generic ops in the
  // payload, only the two tagged with `separable_blur = true` are matched.
  // Without this filter the match yields four ops and the two-result
  // split_handle below fails.
  %ops = transform.structured.match ops{["linalg.generic"]} attributes{separable_blur = true} in %variant_op : (!transform.any_op) -> !transform.any_op
  // Exactly two ops carry the attribute, so a two-way split is valid.
  // %conv0/%conv1 are the first generic of each separable pair; the second
  // generic of each pair is its sole user and can be retrieved with
  // transform.get_consumers_of_result for fusion.
  %conv0, %conv1 = transform.split_handle %ops
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  // A transform.sequence region must be explicitly terminated.
  transform.yield
}
Can anybody guide me on how to match my two pairs of convolutions (using my attribute or otherwise) so that I can fuse each pair together?
I think at this point it may not make any difference, but I intend to use this transform in IREE.