Matching Separable Convolutions using Transform Dialect (with attributes?)

I am trying to use the transform dialect to fuse a separable convolution that is implemented as two linalg generics.

Some example IR is below; it has two such convolutions one after the other, for a total of four generics.

I had added an attribute (you can see it in the IR) to the first generic of each pair, on the assumption that it would help me match, but now I can't see a way to use it.

#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
module @jit_execute_model attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x540x960x1xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<1x534x954x1xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @compiled_predict_step(%arg0) : (tensor<1x540x960x1xf32>) -> tensor<1x534x954x1xf32>
    return %0 : tensor<1x534x954x1xf32>
  }
  func.func private @compiled_predict_step(%arg0: tensor<1x540x960x1xf32>) -> tensor<1x534x954x1xf32> {
    %cst = arith.constant dense<[[[1.000000e+00]], [[2.000000e+00]], [[3.000000e+00]], [[4.000000e+00]], [[5.000000e+00]]]> : tensor<5x1x1xf32>
    %cst_0 = arith.constant dense<[[[1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00], [5.000000e+00]]]> : tensor<1x5x1xf32>
    %cst_1 = arith.constant dense<[[[1.000000e+00], [2.000000e+00], [3.000000e+00]]]> : tensor<1x3x1xf32>
    %cst_2 = arith.constant dense<[[[1.000000e+00]], [[2.000000e+00]], [[3.000000e+00]]]> : tensor<3x1x1xf32>
    %cst_3 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1x540x958x1xf32>
    %1 = linalg.fill ins(%cst_3 : f32) outs(%0 : tensor<1x540x958x1xf32>) -> tensor<1x540x958x1xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %cst_1 : tensor<1x540x960x1xf32>, tensor<1x3x1xf32>) outs(%1 : tensor<1x540x958x1xf32>) attrs =  {separable_blur = true} {
    ^bb0(%in: f32, %in_4: f32, %out: f32):
      %12 = arith.mulf %in, %in_4 : f32
      %13 = arith.addf %12, %out : f32
      linalg.yield %13 : f32
    } -> tensor<1x540x958x1xf32>
    %3 = tensor.empty() : tensor<1x538x958x1xf32>
    %4 = linalg.fill ins(%cst_3 : f32) outs(%3 : tensor<1x538x958x1xf32>) -> tensor<1x538x958x1xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%2, %cst_2 : tensor<1x540x958x1xf32>, tensor<3x1x1xf32>) outs(%4 : tensor<1x538x958x1xf32>) {
    ^bb0(%in: f32, %in_4: f32, %out: f32):
      %12 = arith.mulf %in, %in_4 : f32
      %13 = arith.addf %12, %out : f32
      linalg.yield %13 : f32
    } -> tensor<1x538x958x1xf32>
    %6 = tensor.empty() : tensor<1x538x954x1xf32>
    %7 = linalg.fill ins(%cst_3 : f32) outs(%6 : tensor<1x538x954x1xf32>) -> tensor<1x538x954x1xf32>
    %8 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%5, %cst_0 : tensor<1x538x958x1xf32>, tensor<1x5x1xf32>) outs(%7 : tensor<1x538x954x1xf32>) attrs =  {separable_blur = true} {
    ^bb0(%in: f32, %in_4: f32, %out: f32):
      %12 = arith.mulf %in, %in_4 : f32
      %13 = arith.addf %12, %out : f32
      linalg.yield %13 : f32
    } -> tensor<1x538x954x1xf32>
    %9 = tensor.empty() : tensor<1x534x954x1xf32>
    %10 = linalg.fill ins(%cst_3 : f32) outs(%9 : tensor<1x534x954x1xf32>) -> tensor<1x534x954x1xf32>
    %11 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%8, %cst : tensor<1x538x954x1xf32>, tensor<5x1x1xf32>) outs(%10 : tensor<1x534x954x1xf32>) {
    ^bb0(%in: f32, %in_4: f32, %out: f32):
      %12 = arith.mulf %in, %in_4 : f32
      %13 = arith.addf %12, %out : f32
      linalg.yield %13 : f32
    } -> tensor<1x534x954x1xf32>
    return %11 : tensor<1x534x954x1xf32>
  }
}

Before I tried to integrate this into a larger pipeline, I was matching a simpler pipeline with two functions inside, like this:

transform.sequence failures(propagate) {
^bb1(%variant_op: !transform.any_op):
  %ops = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!transform.any_op) -> !transform.any_op

  %conv0, %conv1 = transform.split_handle %ops
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  transform.yield
}

Can anybody guide me on how (using my attribute or otherwise) to match my two pairs of convolutions so that I can fuse them together?

I think at this point it may not make any difference, but I intend to use this transform in IREE.

To add a little more to this: I was trying to continue using transform.structured.match, but this does not seem to align with the tutorial I found here:

In particular, transform.structured.match does not implement MatchOpInterface, and when looking through the transform.match.* ops I didn't see anything that would let me make use of my attribute.
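
For what it's worth, I believe the top-level transform.structured.match op does take an attributes clause; a rough sketch of the kind of filter I mean (assuming my attribute is still present at that point), though it does not seem usable inside the matcher framework:

// Rough sketch: select only the linalg.generic ops that still carry my
// separable_blur marker attribute.
%tagged = transform.structured.match ops{["linalg.generic"]}
    attributes{separable_blur = true} in %variant_op
    : (!transform.any_op) -> !transform.any_op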

One more thing I have noticed is that IREE eventually strips my attributes anyway and also does some dim flattening, so the real IR I eventually have to transform looks like this:

util.func public @main(%arg0: !hal.buffer_view {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (!hal.buffer_view {jax.result_info = "", mhlo.layout_mode = "default"}) attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @main(%input0: tensor<1x540x960x1xf32> {mhlo.layout_mode = \22default\22, mhlo.sharding = \22{replicated}\22}) -> (%output0: tensor<1x534x954x1xf32> {jax.result_info = \22\22, mhlo.layout_mode = \22default\22})"}} {
  %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32> loc(unknown)
  %cst_0 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> loc(unknown)
  %cst_1 = arith.constant 0.000000e+00 : f32 loc(unknown)
  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<1x540x960x1xf32> loc("bob.mlir":5:3)
  %collapsed = tensor.collapse_shape %0 [[0, 1], [2, 3]] : tensor<1x540x960x1xf32> into tensor<540x960xf32> loc(callsite(callsite("bob.mlir":17:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %1 = tensor.empty() : tensor<540x958xf32> loc(callsite(callsite("bob.mlir":16:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<540x958xf32>) -> tensor<540x958xf32> loc(callsite(callsite("bob.mlir":17:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1 + d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed, %cst_0 : tensor<540x960xf32>, tensor<3xf32>) outs(%2 : tensor<540x958xf32>) {
  ^bb0(%in: f32 loc("bob.mlir":18:10), %in_2: f32 loc("bob.mlir":18:20), %out: f32 loc("bob.mlir":18:32)):
    %14 = arith.mulf %in, %in_2 : f32 loc(callsite(callsite("bob.mlir":19:13 at "bob.mlir":6:10) at "bob.mlir":5:3))
    %15 = arith.addf %14, %out : f32 loc(callsite(callsite("bob.mlir":20:13 at "bob.mlir":6:10) at "bob.mlir":5:3))
    linalg.yield %15 : f32 loc(callsite(callsite("bob.mlir":21:7 at "bob.mlir":6:10) at "bob.mlir":5:3))
  } -> tensor<540x958xf32> loc(callsite(callsite("bob.mlir":17:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %4 = tensor.empty() : tensor<538x958xf32> loc(callsite(callsite("bob.mlir":24:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %5 = linalg.fill ins(%cst_1 : f32) outs(%4 : tensor<538x958xf32>) -> tensor<538x958xf32> loc(callsite(callsite("bob.mlir":25:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d2, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %cst_0 : tensor<540x958xf32>, tensor<3xf32>) outs(%5 : tensor<538x958xf32>) {
  ^bb0(%in: f32 loc("bob.mlir":26:10), %in_2: f32 loc("bob.mlir":26:20), %out: f32 loc("bob.mlir":26:32)):
    %14 = arith.mulf %in, %in_2 : f32 loc(callsite(callsite("bob.mlir":27:13 at "bob.mlir":6:10) at "bob.mlir":5:3))
    %15 = arith.addf %14, %out : f32 loc(callsite(callsite("bob.mlir":28:13 at "bob.mlir":6:10) at "bob.mlir":5:3))
    linalg.yield %15 : f32 loc(callsite(callsite("bob.mlir":29:7 at "bob.mlir":6:10) at "bob.mlir":5:3))
  } -> tensor<538x958xf32> loc(callsite(callsite("bob.mlir":25:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %7 = tensor.empty() : tensor<538x954xf32> loc(callsite(callsite("bob.mlir":32:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %8 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<538x954xf32>) -> tensor<538x954xf32> loc(callsite(callsite("bob.mlir":33:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1 + d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6, %cst : tensor<538x958xf32>, tensor<5xf32>) outs(%8 : tensor<538x954xf32>) {
  ^bb0(%in: f32 loc("bob.mlir":34:10), %in_2: f32 loc("bob.mlir":34:20), %out: f32 loc("bob.mlir":34:32)):
    %14 = arith.mulf %in, %in_2 : f32 loc(callsite(callsite("bob.mlir":35:13 at "bob.mlir":6:10) at "bob.mlir":5:3))
    %15 = arith.addf %14, %out : f32 loc(callsite(callsite("bob.mlir":36:13 at "bob.mlir":6:10) at "bob.mlir":5:3))
    linalg.yield %15 : f32 loc(callsite(callsite("bob.mlir":37:7 at "bob.mlir":6:10) at "bob.mlir":5:3))
  } -> tensor<538x954xf32> loc(callsite(callsite("bob.mlir":33:10 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %10 = tensor.empty() : tensor<534x954xf32> loc(callsite(callsite("bob.mlir":40:11 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %11 = linalg.fill ins(%cst_1 : f32) outs(%10 : tensor<534x954xf32>) -> tensor<534x954xf32> loc(callsite(callsite("bob.mlir":41:11 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d2, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%9, %cst : tensor<538x954xf32>, tensor<5xf32>) outs(%11 : tensor<534x954xf32>) {
  ^bb0(%in: f32 loc("bob.mlir":42:10), %in_2: f32 loc("bob.mlir":42:20), %out: f32 loc("bob.mlir":42:32)):
    %14 = arith.mulf %in, %in_2 : f32 loc(callsite(callsite("bob.mlir":43:13 at "bob.mlir":6:10) at "bob.mlir":5:3))
    %15 = arith.addf %14, %out : f32 loc(callsite(callsite("bob.mlir":44:13 at "bob.mlir":6:10) at "bob.mlir":5:3))
    linalg.yield %15 : f32 loc(callsite(callsite("bob.mlir":45:7 at "bob.mlir":6:10) at "bob.mlir":5:3))
  } -> tensor<534x954xf32> loc(callsite(callsite("bob.mlir":41:11 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %expanded = tensor.expand_shape %12 [[0, 1], [2, 3]] : tensor<534x954xf32> into tensor<1x534x954x1xf32> loc(callsite(callsite("bob.mlir":41:11 at "bob.mlir":6:10) at "bob.mlir":5:3))
  %13 = hal.tensor.export %expanded "output0" : tensor<1x534x954x1xf32> -> !hal.buffer_view loc("bob.mlir":5:3)
  util.return %13 : !hal.buffer_view loc("bob.mlir":5:3)
} loc("bob.mlir":5:3)


So maybe all I can really match on is the affine maps?

In IREE we have a transform op to match on DAGs, so you can match on the exact linalg op you are interested in. Grep IREE for cast_compatible_dag_from_root. Some examples:

    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {
    ^bb0(%lhs: tensor<2x?x?x1280xf16>, %rhs: tensor<3x3x1280x640xf16>, %out: tensor<2x64x64x640xf32>):
      %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> }
        ins(%lhs, %rhs : tensor<2x?x?x1280xf16>, tensor<3x3x1280x640xf16>)
        outs(%out : tensor<2x64x64x640xf32>) -> tensor<2x64x64x640xf32>
    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
      ^bb0(%lhs: tensor<?x?xf16>, %rhs: tensor<?x?xf16>, %out: tensor<?x?xf32>):
      %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                            affine_map<(d0, d1, d2) -> (d1, d2)>,
                                            affine_map<(d0, d1, d2) -> (d0, d1)>],
                           iterator_types = ["parallel", "parallel", "reduction"]}
          ins(%lhs, %rhs : tensor<?x?xf16>, tensor<?x?xf16>) outs(%out : tensor<?x?xf32>) {
        ^bb0(%in: f16, %in_0: f16, %acc: f32):
          %8 = arith.extf %in : f16 to f32
          %9 = arith.extf %in_0 : f16 to f32
          %10 = arith.mulf %8, %9 : f32
          %11 = arith.addf %acc, %10 : f32
          linalg.yield %11 : f32
        } -> tensor<?x?xf32>
    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)

Something like the following should work, based on the tutorial chapter cited:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @match_separable_convolution(%entry: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_op) {

    %c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
    %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
    %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>
    %c5 = transform.param.constant 5 : i64 -> !transform.param<i64>

    %matched_entry = transform.match.structured %entry : (!transform.any_op) -> (!transform.any_op) {
    ^bb0(%s: !transform.any_op):
      %batch, %oi, %oc, %fl, %ic, %depth, %strides, %dilations = transform.match.structured.classify_convolution_dims %s : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
      %c12 = transform.merge_handles %c1, %c2 : !transform.param<i64>
      %c45 = transform.merge_handles %c4, %c5 : !transform.param<i64>
      %c11 = transform.merge_handles %c1, %c1 : !transform.param<i64>
      
      // Batch dimensions: 0
      transform.match.param.cmpi eq %batch, %c0 : !transform.param<i64>
      // Output image dimensions: 1, 2
      transform.match.param.cmpi eq %oi, %c12 : !transform.param<i64>
      // Filter loop dimensions: 4, 5
      transform.match.param.cmpi eq %fl, %c45 : !transform.param<i64>
      // Depth dimensions: 3
      transform.match.param.cmpi eq %depth, %c3 : !transform.param<i64>
      
      // No channels
      %num_ic = transform.num_associations %ic : (!transform.param<i64>) -> !transform.param<i64>
      transform.match.param.cmpi eq %num_ic, %c0 : !transform.param<i64>
      %num_oc = transform.num_associations %oc : (!transform.param<i64>) -> !transform.param<i64>
      transform.match.param.cmpi eq %num_oc, %c0 : !transform.param<i64>

      // No strides or dilations.
      transform.match.param.cmpi eq %strides, %c11 : !transform.param<i64>
      transform.match.param.cmpi eq %dilations, %c11 : !transform.param<i64>
      transform.match.structured.yield %s : !transform.any_op
    }

    %0 = transform.get_producer_of_operand %entry[0] : (!transform.any_op) -> !transform.any_op
    
    // TODO: similar dimension-matching structure as for %entry

    transform.yield %matched_entry, %0 : !transform.any_op, !transform.any_op
  }

  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %result:2 = transform.collect_matching @match_separable_convolution in %root : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %num = transform.num_associations %result#0 : (!transform.any_op) -> !transform.param<i64>
    transform.debug.emit_param_as_remark %num : !transform.param<i64>
    transform.yield
  }
}

This is based on the generic convolution-like matcher op that we have, and it checks for the form specific to your convolution. It's a bit verbose, so I only spelled out the conditions for the first generic of the producer-consumer pair. Note that this will currently detect three pairs, (%3, %6), (%6, %9), and (%9, %12), because their dimensional structures are identical.

You could add a matcher op that checks for a specific attribute; nobody has needed it enough to implement one. That being said, I discourage using attributes for matching since, as you have observed yourself, compiler passes tend to drop them unceremoniously.

transform.structured.match is a shorthand version of transform.collect_matching; there is no intention for it to be usable inside more advanced matchers.
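
For illustration, the contrast is roughly the following (a sketch; @my_matcher stands in for a named matcher sequence such as the one above):

// Shorthand: gathers every linalg.generic under %root into one handle.
%generics = transform.structured.match ops{["linalg.generic"]} in %root
    : (!transform.any_op) -> !transform.any_op

// General form: runs the named matcher on ops under %root and collects what
// the matcher yields on each successful match.
%matches = transform.collect_matching @my_matcher in %root
    : (!transform.any_op) -> !transform.any_op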

Thanks both for these useful replies; I will experiment a bit with both.

@kuhar

I have played a bit with your suggestion and it looks promising, but I have hit a bit of a block. I wonder, can you help me?

I am passing IREE the flag
--iree-flow-dispatch-use-transform-dialect=iree-conv-flow-td-seq.mlir

with the flow spec

#map = affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0 + d2, d1)>

module  attributes { transform.with_named_sequence }  {

    transform.named_sequence @match_separable(%root: !transform.any_op {transform.readonly}) -> (!transform.any_value, !transform.any_value) {
    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
          ^bb0(%arg0: tensor<540x960xf32>, %cst_1: tensor<3xf32>, %cst_2: tensor<3xf32>):
              %cst_3 = arith.constant 0.000000e+00 : f32
              %0 = tensor.empty() : tensor<540x958xf32>
              %1 = linalg.fill ins(%cst_3 : f32) outs(%0 : tensor<540x958xf32>) -> tensor<540x958xf32>
              %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %cst_1 : tensor<540x960xf32>, tensor<3xf32>) outs(%1 : tensor<540x958xf32>) {
              ^bb0(%in: f32, %in_4: f32, %out: f32):
                %12 = arith.mulf %in, %in_4 : f32
                %13 = arith.addf %12, %out : f32
                linalg.yield %13 : f32
              } -> tensor<540x958xf32>
              %3 = tensor.empty() : tensor<538x958xf32>
              %4 = linalg.fill ins(%cst_3 : f32) outs(%3 : tensor<538x958xf32>) -> tensor<538x958xf32>
              %5 = linalg.generic {indexing_maps = [#map3, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%2, %cst_2 : tensor<540x958xf32>, tensor<3xf32>) outs(%4 : tensor<538x958xf32>) {
              ^bb0(%in: f32, %in_4: f32, %out: f32):
                %12 = arith.mulf %in, %in_4 : f32
                %13 = arith.addf %12, %out : f32
                linalg.yield %13 : f32
              } -> tensor<538x958xf32>
        } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
        transform.yield %ins, %outs : !transform.any_value, !transform.any_value
    }

    transform.named_sequence @create_dispatch_region(%ins: !transform.any_value {transform.readonly},
                                                 %out: !transform.any_value {transform.readonly}) {
        %root = transform.get_defining_op %out : (!transform.any_value) -> !transform.any_op
        %module = transform.util.get_nearest_symbol_table %root : (!transform.any_op) -> !transform.any_op
        %region_op = transform.iree.wrap_in_dispatch_region %root { generateWorkload = false } : (!transform.any_op) -> !transform.any_op
        //%region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %root into %region_op : (!transform.any_op, !transform.any_op) -> !transform.any_op
        transform.yield
    }


    transform.named_sequence @__transform_main(%module: !transform.any_op) {
        %funcs = transform.structured.match ops{["util.func"]} in %module : (!transform.any_op) -> !transform.any_op
            // For each function in the module, run the matcher on all contained
            // operations.
            transform.foreach %funcs : !transform.any_op {
              ^bb1(%func: !transform.any_op):
                transform.foreach_match in %func
                    // <matcher name> -> <rewriter name>
                    // Multiple matcher-action pairs can be specified comma separated,
                    // here we are only doing a single kind of match and replace.
                    @match_separable -> @create_dispatch_region
                  : (!transform.any_op) -> (!transform.any_op)
            }
            // Cleanup leftover dead code; cast_and_call does not do replacement, only
            // rewires uses.
            transform.yield
    }
}

with the same sample payload as in my first post above (the four linalg generics carrying the separable_blur attributes).

(IREE collapses the unit dims, hence my transform spec has adjusted dims.)

This successfully creates a dispatch region as I wanted, but I can't merge more ops into it due to an invalidated handle:

error: op uses a handle invalidated by a previously executed transform op
        %region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %root into %region_op : (!transform.any_op, !transform.any_op) -> !transform.any_op

Is there an obvious way to get handles to more of the ops in the DAG matched by cast_compatible_dag_from_root?

Ah, maybe I solved my own problem. This is my new spec; does this look sane?

#map = affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0 + d2, d1)>

module  attributes { transform.with_named_sequence }  {

    transform.named_sequence @match_separable(%root: !transform.any_op {transform.readonly}) -> (!transform.any_value, !transform.any_value) {
    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
          ^bb0(%arg0: tensor<540x960xf32>, %cst_1: tensor<3xf32>, %cst_2: tensor<3xf32>):
              %cst_3 = arith.constant 0.000000e+00 : f32
              %0 = tensor.empty() : tensor<540x958xf32>
              %1 = linalg.fill ins(%cst_3 : f32) outs(%0 : tensor<540x958xf32>) -> tensor<540x958xf32>
              %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %cst_1 : tensor<540x960xf32>, tensor<3xf32>) outs(%1 : tensor<540x958xf32>) {
              ^bb0(%in: f32, %in_4: f32, %out: f32):
                %12 = arith.mulf %in, %in_4 : f32
                %13 = arith.addf %12, %out : f32
                linalg.yield %13 : f32
              } -> tensor<540x958xf32>
              %3 = tensor.empty() : tensor<538x958xf32>
              %4 = linalg.fill ins(%cst_3 : f32) outs(%3 : tensor<538x958xf32>) -> tensor<538x958xf32>
              %5 = linalg.generic {indexing_maps = [#map3, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%2, %cst_2 : tensor<540x958xf32>, tensor<3xf32>) outs(%4 : tensor<538x958xf32>) {
              ^bb0(%in: f32, %in_4: f32, %out: f32):
                %12 = arith.mulf %in, %in_4 : f32
                %13 = arith.addf %12, %out : f32
                linalg.yield %13 : f32
              } -> tensor<538x958xf32>
        } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
        transform.yield %ins, %outs : !transform.any_value, !transform.any_value
    }

    transform.named_sequence @create_dispatch_region(%ins: !transform.any_value {transform.readonly},
                                                 %out: !transform.any_value {transform.readonly}) {
        %root = transform.get_defining_op %out : (!transform.any_value) -> !transform.any_op
        %module = transform.util.get_nearest_symbol_table %root : (!transform.any_op) -> !transform.any_op
        //%prev = transform.get_parent_op %root :(!transform.any_op) -> !transform.any_op
        %prev_value = transform.get_operand %root[0] : (!transform.any_op) -> !transform.any_value
        %prev = transform.get_defining_op %prev_value : (!transform.any_value) -> !transform.any_op
        transform.print %prev  {name = "prev"}: !transform.any_op
        %region_op = transform.iree.wrap_in_dispatch_region %root { generateWorkload = false } : (!transform.any_op) -> !transform.any_op

        %region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %prev into %region_op : (!transform.any_op, !transform.any_op) -> !transform.any_op
        transform.yield
    }


    transform.named_sequence @__transform_main(%module: !transform.any_op) {
        %funcs = transform.structured.match ops{["util.func"]} in %module : (!transform.any_op) -> !transform.any_op
            // For each function in the module, run the matcher on all contained
            // operations.
            transform.foreach %funcs : !transform.any_op {
              ^bb1(%func: !transform.any_op):
                transform.foreach_match in %func
                    // <matcher name> -> <rewriter name>
                    // Multiple matcher-action pairs can be specified comma separated,
                    // here we are only doing a single kind of match and replace.
                    @match_separable -> @create_dispatch_region
                  : (!transform.any_op) -> (!transform.any_op)
            }
            // Cleanup leftover dead code; cast_and_call does not do replacement, only
            // rewires uses.
            transform.yield
    }
}