Fusing Convolution with Relu (Conv + Relu -> ConvRelu) in Linalg

It seems Linalg only fuses elementwise operators (add/mul/sub, etc.) inside linalg.generic.

What needs to be done to fuse a convolution followed by a Relu?

I have this TF MLIR, obtained from a GraphDef:

module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 561 : i32}}  {
  func @main() -> tensor<*xf32> {
    %cst = "tf.Const"() {device = "", value = dense<[[[[0.61719793], [0.793435096], [0.121170968], [0.573920846], [0.445487499], [0.226183072], [0.973448157], [0.851443469], [0.565707207], [5.227910e-01]], [[0.00731075834], [0.557578623], [0.707973182], [0.959501981], [0.358185142], [0.699126303], [0.466926485], [0.413297445], [0.0725673139], [0.921178698]], [[0.30189532], [0.731407762], [0.370444775], [0.848782122], [0.871007978], [0.748299241], [0.102783702], [0.551843107], [0.816904246], [0.0332700573]], [[0.343397468], [0.982405781], [0.0607045554], [0.541470528], [0.726823389], [0.808600127], [0.256532729], [0.252898484], [0.110048898], [0.478431761]], [[0.979851245], [0.111159958], [0.724086046], [0.982171118], [0.211429045], [0.678284585], [0.0563224852], [0.837513744], [0.657312452], [0.536515653]], [[0.505726278], [0.648696065], [0.999981284], [0.00183360698], [0.745425224], [0.40943867], [0.333478332], [0.644900322], [0.19442305], [0.178031594]], [[0.681895732], [0.0306128804], [0.390298098], [0.893880724], [0.859292924], [0.445982367], [0.303117335], [0.769601822], [0.510563135], [0.387194842]], [[0.426417202], [0.49324289], [0.107445173], [0.790808141], [0.891636371], [0.0934373661], [0.42853874], [0.835353791], [0.0698968098], [0.611316978]], [[0.999428808], [0.938870192], [0.778351902], [0.973116576], [0.0702748299], [0.479627848], [0.455716878], [0.779465734], [0.614541173], [0.207848176]], [[0.847904145], [0.130607277], [0.678844749], [0.03101651], [0.964369654], [0.600809455], [0.0234705787], [0.356526107], [0.636354804], [0.281212419]]]]> : tensor<1x10x10x1xf32>} : () -> tensor<1x10x10x1xf32>
    %0 = "tf.Relu"(%cst) {device = ""} : (tensor<1x10x10x1xf32>) -> tensor<*xf32>
    %1 = "tf.Relu"(%0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32>
    %2 = "tf.Relu"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32>
    %cst_0 = "tf.Const"() {device = "", value = dense<[[[[8.405740e-01]], [[0.107482761]], [[0.885160744]]], [[[0.879221558]], [[0.272046864]], [[0.219075441]]], [[[0.853731691]], [[0.786423742]], [[0.132054776]]]]> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32>
    %3 = "tf.Conv2D"(%cst, %cst_0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x10x10x1xf32>, tensor<3x3x1x1xf32>) -> tensor<*xf32>
    %4 = "tf.Conv2D"(%2, %cst_0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<3x3x1x1xf32>) -> tensor<*xf32>
    %5 = "tf.Relu"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32>
    return %5 : tensor<*xf32>
  }
}

Using tf-opt to get HLO from TF:

tensorflow/bazel-bin/tensorflow/compiler/mlir/tf-opt -print-ir-before-all -print-ir-after-all --tf-shape-inference --xla-legalize-tf --allow-unregistered-dialect

module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 561 : i32}}  {
  func @main() -> tensor<1x8x8x1xf32> {
    %0 = mhlo.constant dense<[[[[0.61719793], [0.793435096], [0.121170968], [0.573920846], [0.445487499], [0.226183072], [0.973448157], [0.851443469], [0.565707207], [5.227910e-01]], [[0.00731075834], [0.557578623], [0.707973182], [0.959501981], [0.358185142], [0.699126303], [0.466926485], [0.413297445], [0.0725673139], [0.921178698]], [[0.30189532], [0.731407762], [0.370444775], [0.848782122], [0.871007978], [0.748299241], [0.102783702], [0.551843107], [0.816904246], [0.0332700573]], [[0.343397468], [0.982405781], [0.0607045554], [0.541470528], [0.726823389], [0.808600127], [0.256532729], [0.252898484], [0.110048898], [0.478431761]], [[0.979851245], [0.111159958], [0.724086046], [0.982171118], [0.211429045], [0.678284585], [0.0563224852], [0.837513744], [0.657312452], [0.536515653]], [[0.505726278], [0.648696065], [0.999981284], [0.00183360698], [0.745425224], [0.40943867], [0.333478332], [0.644900322], [0.19442305], [0.178031594]], [[0.681895732], [0.0306128804], [0.390298098], [0.893880724], [0.859292924], [0.445982367], [0.303117335], [0.769601822], [0.510563135], [0.387194842]], [[0.426417202], [0.49324289], [0.107445173], [0.790808141], [0.891636371], [0.0934373661], [0.42853874], [0.835353791], [0.0698968098], [0.611316978]], [[0.999428808], [0.938870192], [0.778351902], [0.973116576], [0.0702748299], [0.479627848], [0.455716878], [0.779465734], [0.614541173], [0.207848176]], [[0.847904145], [0.130607277], [0.678844749], [0.03101651], [0.964369654], [0.600809455], [0.0234705787], [0.356526107], [0.636354804], [0.281212419]]]]> : tensor<1x10x10x1xf32>
    %1 = mhlo.constant dense<[[[[0.61719793], [0.793435096], [0.121170968], [0.573920846], [0.445487499], [0.226183072], [0.973448157], [0.851443469], [0.565707207], [5.227910e-01]], [[0.00731075834], [0.557578623], [0.707973182], [0.959501981], [0.358185142], [0.699126303], [0.466926485], [0.413297445], [0.0725673139], [0.921178698]], [[0.30189532], [0.731407762], [0.370444775], [0.848782122], [0.871007978], [0.748299241], [0.102783702], [0.551843107], [0.816904246], [0.0332700573]], [[0.343397468], [0.982405781], [0.0607045554], [0.541470528], [0.726823389], [0.808600127], [0.256532729], [0.252898484], [0.110048898], [0.478431761]], [[0.979851245], [0.111159958], [0.724086046], [0.982171118], [0.211429045], [0.678284585], [0.0563224852], [0.837513744], [0.657312452], [0.536515653]], [[0.505726278], [0.648696065], [0.999981284], [0.00183360698], [0.745425224], [0.40943867], [0.333478332], [0.644900322], [0.19442305], [0.178031594]], [[0.681895732], [0.0306128804], [0.390298098], [0.893880724], [0.859292924], [0.445982367], [0.303117335], [0.769601822], [0.510563135], [0.387194842]], [[0.426417202], [0.49324289], [0.107445173], [0.790808141], [0.891636371], [0.0934373661], [0.42853874], [0.835353791], [0.0698968098], [0.611316978]], [[0.999428808], [0.938870192], [0.778351902], [0.973116576], [0.0702748299], [0.479627848], [0.455716878], [0.779465734], [0.614541173], [0.207848176]], [[0.847904145], [0.130607277], [0.678844749], [0.03101651], [0.964369654], [0.600809455], [0.0234705787], [0.356526107], [0.636354804], [0.281212419]]]]> : tensor<1x10x10x1xf32>
    %2 = mhlo.constant dense<[[[[8.405740e-01]], [[0.107482761]], [[0.885160744]]], [[[0.879221558]], [[0.272046864]], [[0.219075441]]], [[[0.853731691]], [[0.786423742]], [[0.132054776]]]]> : tensor<3x3x1x1xf32>
    %3 = mhlo.constant dense<[[[[1.90640473], [3.10885549], [2.61861587], [3.33316278], [3.21350598], [2.67322469], [2.57898664], [2.94214082]], [[2.31208372], [3.28184414], [2.33634424], [3.75166821], [3.071720e+00], [2.73593426], [1.3530308], [2.25094724]], [[2.26233196], [3.19885707], [2.952003], [3.41474891], [2.54013252], [2.69861364], [1.98078716], [2.241010e+00]], [[2.57230949], [3.16195059], [2.65632319], [2.96052194], [2.31011176], [2.42324853], [1.57913232], [2.40759516]], [[2.97448158], [2.33455133], [3.09402394], [3.24060106], [2.26468229], [2.63028479], [2.16186285], [3.01813293]], [[2.8394711], [1.59332752], [3.10654688], [2.9501524], [2.80030727], [2.12112331], [2.14148068], [2.46980643]], [[3.1489687], [3.03721976], [3.12877035], [3.14621377], [2.43915606], [2.3412149], [2.4919939], [3.10810637]], [[2.72745037], [3.026016], [2.660321], [2.68727589], [2.72986293], [2.158930e+00], [1.64368474], [2.99080229]]]]> : tensor<1x8x8x1xf32>
    %4 = mhlo.convolution(%1, %2) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x10x10x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x8x8x1xf32>
    %5 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %6 = shape.shape_of %5 : tensor<f32> -> tensor<0xindex>
    %7 = shape.shape_of %4 : tensor<1x8x8x1xf32> -> tensor<4xindex>
    %8 = shape.cstr_broadcastable %6, %7 : tensor<0xindex>, tensor<4xindex>
    %9 = shape.assuming %8 -> (tensor<1x8x8x1xf32>) {
      %10 = shape.const_shape [] : tensor<0xindex>
      %11 = shape.const_shape [1, 8, 8, 1] : tensor<4xindex>
      %12 = shape.const_shape [1, 8, 8, 1] : tensor<4xindex>
      %13 = "mhlo.dynamic_broadcast_in_dim"(%5, %12) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xindex>) -> tensor<1x8x8x1xf32>
      %14 = "mhlo.dynamic_broadcast_in_dim"(%4, %12) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x8x8x1xf32>, tensor<4xindex>) -> tensor<1x8x8x1xf32>
      %15 = mhlo.maximum %13, %14 : tensor<1x8x8x1xf32>
      shape.assuming_yield %15 : tensor<1x8x8x1xf32>
    }
    return %9 : tensor<1x8x8x1xf32>
  }
}

Using mlir-hlo-opt to get Linalg:
/mlir-hlo/build/tools/mlir-hlo-opt/mlir-hlo-opt --hlo-legalize-to-linalg --canonicalize -cse --linalg-fusion-for-tensor-ops

#map0 = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 561 : i32}}  {
  func @main() -> tensor<1x8x8x1xf32> {
    %cst = constant 0x7FC00000 : f32
    %cst_0 = constant dense<[[[8.405740e-01], [0.107482761], [0.885160744]], [[0.879221558], [0.272046864], [0.219075441]], [[0.853731691], [0.786423742], [0.132054776]]]> : tensor<3x3x1xf32>
    %cst_1 = constant dense<[[[[0.61719793], [0.793435096], [0.121170968], [0.573920846], [0.445487499], [0.226183072], [0.973448157], [0.851443469], [0.565707207], [5.227910e-01]], [[0.00731075834], [0.557578623], [0.707973182], [0.959501981], [0.358185142], [0.699126303], [0.466926485], [0.413297445], [0.0725673139], [0.921178698]], [[0.30189532], [0.731407762], [0.370444775], [0.848782122], [0.871007978], [0.748299241], [0.102783702], [0.551843107], [0.816904246], [0.0332700573]], [[0.343397468], [0.982405781], [0.0607045554], [0.541470528], [0.726823389], [0.808600127], [0.256532729], [0.252898484], [0.110048898], [0.478431761]], [[0.979851245], [0.111159958], [0.724086046], [0.982171118], [0.211429045], [0.678284585], [0.0563224852], [0.837513744], [0.657312452], [0.536515653]], [[0.505726278], [0.648696065], [0.999981284], [0.00183360698], [0.745425224], [0.40943867], [0.333478332], [0.644900322], [0.19442305], [0.178031594]], [[0.681895732], [0.0306128804], [0.390298098], [0.893880724], [0.859292924], [0.445982367], [0.303117335], [0.769601822], [0.510563135], [0.387194842]], [[0.426417202], [0.49324289], [0.107445173], [0.790808141], [0.891636371], [0.0934373661], [0.42853874], [0.835353791], [0.0698968098], [0.611316978]], [[0.999428808], [0.938870192], [0.778351902], [0.973116576], [0.0702748299], [0.479627848], [0.455716878], [0.779465734], [0.614541173], [0.207848176]], [[0.847904145], [0.130607277], [0.678844749], [0.03101651], [0.964369654], [0.600809455], [0.0234705787], [0.356526107], [0.636354804], [0.281212419]]]]> : tensor<1x10x10x1xf32>
    %cst_2 = constant 0.000000e+00 : f32
    %0 = linalg.init_tensor [1, 8, 8, 1] : tensor<1x8x8x1xf32>
    %1 = linalg.fill(%cst_2, %0) : f32, tensor<1x8x8x1xf32> -> tensor<1x8x8x1xf32> 
    %2 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%cst_1, %cst_0 : tensor<1x10x10x1xf32>, tensor<3x3x1xf32>) outs(%1 : tensor<1x8x8x1xf32>) -> tensor<1x8x8x1xf32>
    %3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<1x8x8x1xf32>) outs(%0 : tensor<1x8x8x1xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
      %4 = cmpf ogt, %cst_2, %arg0 : f32
      %5 = select %4, %cst_2, %arg0 : f32
      %6 = cmpf uno, %cst_2, %arg0 : f32
      %7 = select %6, %cst, %5 : f32
      linalg.yield %7 : f32
    } -> tensor<1x8x8x1xf32>
    return %3 : tensor<1x8x8x1xf32>
  }
}

What needs to be done to fuse this conv and Relu pair?
%2 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
%3 = linalg.generic {cmp, select}

TensorFlow seems to fuse Conv+BiasAdd+Relu into a single tf._FusedConv2D:

// CHECK-LABEL: conv2DBiasAdd_reluActivation
func @conv2DBiasAdd_reluActivation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
 // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu"], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
 // CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
 // CHECK: return %[[VAL_1]]
 %0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
 %1 = "tf.BiasAdd"(%0, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
 %2 = "tf.Relu"(%1) : (tensor<*xf32>) -> tensor<*xf32>
 %3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
 return %3 : tensor<*xf32>
}

Using tensorflow/bazel-bin/tensorflow/compiler/mlir/tf-opt -tf-fused-kernel-matcher, this becomes:

module  {
  func @conv2DBiasAdd_reluActivation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> tensor<*xf32> {
    %0 = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu"], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
    %1 = "tf.Identity"(%0) : (tensor<*xf32>) -> tensor<*xf32>
    return %1 : tensor<*xf32>
  }
}

Is it possible to have fusion similar to what the TF dialect provides, but at the Linalg level? What else is needed to get it done?

I can only talk about fusing a convolution with a generic op at the Linalg level.

In general you would fuse such operations using tile and fuse, i.e. tile the consumer (in this case the linalg.generic) and generate the producer tile (in this case the linalg.*conv*) needed for executing a tile of the consumer in place. There are examples of this on linalg.matmul operations on memref types here. The pass that implements the tile and fuse is here. This pass uses the tile + fuse pattern that is defined here.

The test pass above is only for testing purposes, but the pattern itself is generic and works with all named ops. I changed the pass above to add this pattern:

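 // The pattern is anchored on the linalg.generic consumer: tile it with tile
 // sizes {16, 32, 64} and fuse the producer of operand 0 (the linalg.*conv* op)
 // into the tiled loop nest. The first filter matches ops carrying the
 // "conv2d_fusion" marker and relabels them "after_conv2d_fusion"; the other
 // two filters just mark the fused producer and the original op, respectively.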
 patterns.add<LinalgTileAndFusePattern<GenericOp>>(
      context, dependenceGraph,
      LinalgTilingOptions().setTileSizes({16, 32, 64}).setLoopType(LoopType),
      LinalgFusionOptions().setIndicesToFuse({0}),
      LinalgTransformationFilter(
          Identifier::get("conv2d_fusion", context),
          Identifier::get("after_conv2d_fusion", context)),
      LinalgTransformationFilter(
          ArrayRef<Identifier>(),
          Identifier::get("after_conv2d_fusion_producer", context)),
      LinalgTransformationFilter(
          ArrayRef<Identifier>(),
          Identifier::get("after_conv2d_fusion_original", context)));

Note that this pattern is specified to apply on the linalg.generic operation. The way the pattern works is that you tile the consumer first (so that's the anchor for the pattern), and then the producer at operand 0 (the linalg.*_conv op) gets fused with it. For this pattern to trigger, it requires the marker conv2d_fusion on the linalg.generic. With this pattern, the following code

$ cat conv_fusion.mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func @main(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?xf32>,
    %arg2 : tensor<?x?x?x?xf32>, %arg3 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c3 = constant 3 : index
  %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
      ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  %1 = tensor.dim %arg3, %c0 : tensor<?x?x?x?xf32>
  %2 = tensor.dim %arg3, %c1 : tensor<?x?x?x?xf32>
  %3 = tensor.dim %arg3, %c2 : tensor<?x?x?x?xf32>
  %4 = tensor.dim %arg3, %c3 : tensor<?x?x?x?xf32>
  %5 = linalg.init_tensor [%1, %2, %3, %4] : tensor<?x?x?x?xf32>
  %6 = linalg.generic
      {__internal_linalg_transform__ = "conv2d_fusion",
       indexing_maps = [#map, #map, #map],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%0, %arg3 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
      outs(%5 : tensor<?x?x?x?xf32>) {
      ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32):
        %7 = addf %arg4, %arg5 : f32
        linalg.yield %7 : f32
      } -> tensor<?x?x?x?xf32>
  return %6 : tensor<?x?x?x?xf32>
}

with this command

$ mlir-opt -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file ~/iree_public/scratchspace/conv_fusion.mlir -canonicalize -cse

gets tiled and fused to

#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)>
#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)>
#map3 = affine_map<(d0, d1) -> (16, d0 - d1)>
#map4 = affine_map<(d0, d1) -> (32, d0 - d1)>
#map5 = affine_map<(d0, d1) -> (64, d0 - d1)>
#map6 = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)>
#map7 = affine_map<(d0, d1)[s0, s1] -> (d1 + s0 - 1, -d0 + s1)>
#map8 = affine_map<()[s0, s1] -> (s0, s1)>
#map9 = affine_map<(d0)[s0, s1] -> (-d0 + s0, 32, -d0 + s1)>
#map10 = affine_map<(d0)[s0, s1] -> (-d0 + s0, 64, -d0 + s1)>
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module  {
  func @main(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>, %arg3: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
    %c3 = constant 3 : index
    %c2 = constant 2 : index
    %c1 = constant 1 : index
    %c0 = constant 0 : index
    %c64 = constant 64 : index
    %c32 = constant 32 : index
    %c16 = constant 16 : index
    %0 = tensor.dim %arg3, %c0 : tensor<?x?x?x?xf32>
    %1 = tensor.dim %arg3, %c1 : tensor<?x?x?x?xf32>
    %2 = tensor.dim %arg3, %c2 : tensor<?x?x?x?xf32>
    %3 = tensor.dim %arg3, %c3 : tensor<?x?x?x?xf32>
    %4 = linalg.init_tensor [%0, %1, %2, %3] : tensor<?x?x?x?xf32>
    %5 = tensor.dim %arg0, %c0 : tensor<?x?x?x?xf32>
    %6 = tensor.dim %arg2, %c1 : tensor<?x?x?x?xf32>
    %7 = tensor.dim %arg2, %c2 : tensor<?x?x?x?xf32>
    %8 = scf.for %arg4 = %c0 to %5 step %c16 iter_args(%arg5 = %4) -> (tensor<?x?x?x?xf32>) {
      %9 = scf.for %arg6 = %c0 to %6 step %c32 iter_args(%arg7 = %arg5) -> (tensor<?x?x?x?xf32>) {
        %10 = scf.for %arg8 = %c0 to %7 step %c64 iter_args(%arg9 = %arg7) -> (tensor<?x?x?x?xf32>) {
          %11 = affine.min #map0(%arg6)[%6]
          %12 = affine.min #map1(%arg8)[%7]
          %13 = tensor.dim %arg0, %c3 : tensor<?x?x?x?xf32>
          %14 = affine.min #map2(%arg4)[%0]
          %15 = affine.min #map0(%arg6)[%1]
          %16 = affine.min #map1(%arg8)[%2]
          %17 = tensor.extract_slice %arg3[%arg4, %arg6, %arg8, 0] [%14, %15, %16, %3] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
          %18 = tensor.dim %arg9, %c0 : tensor<?x?x?x?xf32>
          %19 = affine.min #map3(%18, %arg4)
          %20 = tensor.dim %arg9, %c1 : tensor<?x?x?x?xf32>
          %21 = affine.min #map4(%20, %arg6)
          %22 = tensor.dim %arg9, %c2 : tensor<?x?x?x?xf32>
          %23 = affine.min #map5(%22, %arg8)
          %24 = tensor.dim %arg9, %c3 : tensor<?x?x?x?xf32>
          %25 = tensor.extract_slice %arg9[%arg4, %arg6, %arg8, 0] [%19, %21, %23, %24] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
          %26 = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
          %27 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
          %28 = affine.min #map6(%arg4)[%5, %5]
          %29 = tensor.dim %arg0, %c1 : tensor<?x?x?x?xf32>
          %30 = affine.min #map7(%arg6, %11)[%26, %29]
          %31 = tensor.dim %arg0, %c2 : tensor<?x?x?x?xf32>
          %32 = affine.min #map7(%arg8, %12)[%27, %31]
          %33 = affine.min #map8()[%13, %13]
          %34 = tensor.extract_slice %arg0[%arg4, %arg6, %arg8, 0] [%28, %30, %32, %33] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
          %35 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
          %36 = affine.min #map8()[%13, %35]
          %37 = tensor.extract_slice %arg1[0, 0, 0] [%26, %27, %36] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
          %38 = tensor.dim %arg2, %c0 : tensor<?x?x?x?xf32>
          %39 = affine.min #map6(%arg4)[%38, %5]
          %40 = affine.min #map9(%arg6)[%6, %6]
          %41 = affine.min #map10(%arg8)[%7, %7]
          %42 = tensor.dim %arg2, %c3 : tensor<?x?x?x?xf32>
          %43 = affine.min #map8()[%13, %42]
          %44 = tensor.extract_slice %arg2[%arg4, %arg6, %arg8, 0] [%39, %40, %41, %43] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
          %45 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {__internal_linalg_transform__ = "after_conv2d_fusion_producer", dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%34, %37 : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%44 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
          %46 = linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%45, %17 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%25 : tensor<?x?x?x?xf32>) attrs = {__internal_linalg_transform__ = "after_conv2d_fusion"} {
          ^bb0(%arg10: f32, %arg11: f32, %arg12: f32):  // no predecessors
            %48 = addf %arg10, %arg11 : f32
            linalg.yield %48 : f32
          } -> tensor<?x?x?x?xf32>
          %47 = tensor.insert_slice %46 into %arg9[%arg4, %arg6, %arg8, 0] [%19, %21, %23, %24] [1, 1, 1, 1] : tensor<?x?x?x?xf32> into tensor<?x?x?x?xf32>
          scf.yield %47 : tensor<?x?x?x?xf32>
        }
        scf.yield %10 : tensor<?x?x?x?xf32>
      }
      scf.yield %9 : tensor<?x?x?x?xf32>
    }
    return %8 : tensor<?x?x?x?xf32>
  }
}

The output is verbose because this is using dynamic shapes (I tried the static shape case, and that didn't work as I expected it to; something minor is off there, I will look into it). But what's important is that a tile of the convolution is produced, the elementwise operation is performed on it, and the tile is returned.

The above approach uses the pattern. But the core method used in the pattern, which tiles + fuses a sequence of operations, is defined here. You can use it directly with this command:

$ mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file conv_fusion.mlir

This produces the same output. The second pass does not use the pattern but calls the core method directly, so it doesn't require any marker to be set, etc. But it is also less configurable.
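
To connect this back to the original conv + relu IR at the top: with the pattern-based approach, the only change needed in that Linalg IR is to attach the marker to the linalg.generic that implements the Relu. A sketch, reusing the names from the Linalg dump above (and keeping in mind the static-shape caveat I mentioned):

%3 = linalg.generic {__internal_linalg_transform__ = "conv2d_fusion",
                     indexing_maps = [#map0, #map1],
                     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%2 : tensor<1x8x8x1xf32>) outs(%0 : tensor<1x8x8x1xf32>) {
  ^bb0(%arg0: f32, %arg1: f32):
    %4 = cmpf ogt, %cst_2, %arg0 : f32
    %5 = select %4, %cst_2, %arg0 : f32
    %6 = cmpf uno, %cst_2, %arg0 : f32
    %7 = select %6, %cst, %5 : f32
    linalg.yield %7 : f32
  } -> tensor<1x8x8x1xf32>

(In a real pipeline the marker would typically be set by whatever pass drives the fusion, rather than written by hand in the IR.)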
