Can the affine dialect fuse the element-wise pattern (add-concat-add) into a single loop?

I am new to the affine dialect and have a question about the affine-loop-fusion pass.

It seems that the affine dialect is a general mechanism for loop transformations. Can it be extended to fuse element-wise ops that have a shape-changing op, like slice or concat, between them? How? Any suggestion is appreciated.

The example is below:

    module {
      func @static_add_concat_add(%arg0: memref<5x10xf32>, %arg1: memref<5x10xf32>) -> memref<10x10xf32> {
        %0 = memref.alloc() : memref<1x1xf32>
        %1 = memref.alloc() : memref<10x10xf32>
        affine.for %arg2 = 0 to 5 {
          affine.for %arg3 = 0 to 10 {
            %3 = affine.load %arg0[%arg2, %arg3] : memref<5x10xf32>
            %4 = affine.load %arg1[%arg2, %arg3] : memref<5x10xf32>
            %5 = arith.addf %3, %4 : f32
            affine.store %5, %0[0, 0] : memref<1x1xf32>
            %6 = affine.load %0[0, 0] : memref<1x1xf32>
            affine.store %6, %1[%arg2, %arg3] : memref<10x10xf32>
          }
        }
        affine.for %arg2 = 0 to 5 {
          affine.for %arg3 = 0 to 10 {
            %3 = affine.load %arg1[%arg2, %arg3] : memref<5x10xf32>
            affine.store %3, %1[%arg2 + 5, %arg3] : memref<10x10xf32>
          }
        }
        %2 = memref.alloc() : memref<10x10xf32>
        affine.for %arg2 = 0 to 10 {
          affine.for %arg3 = 0 to 10 {
            %3 = affine.load %1[%arg2, %arg3] : memref<10x10xf32>
            %4 = arith.addf %3, %3 : f32
            affine.store %4, %2[%arg2, %arg3] : memref<10x10xf32>
          }
        }
        return %2 : memref<10x10xf32>
      }
    }

I think this is a very interesting topic for affine-fusion considerations in current MLIR; this pattern is very common in popular ML networks.
In the IR of the example above, the first add op and the first half of the concat op have already been fused together, because they have a RAW (read-after-write) dependence:

    affine.for %arg2 = 0 to 5 {
      affine.for %arg3 = 0 to 10 {
        %3 = affine.load %arg0[%arg2, %arg3] : memref<5x10xf32>
        %4 = affine.load %arg1[%arg2, %arg3] : memref<5x10xf32>
        %5 = arith.addf %3, %4 : f32
        affine.store %5, %0[0, 0] : memref<1x1xf32>
        %6 = affine.load %0[0, 0] : memref<1x1xf32>
        affine.store %6, %1[%arg2, %arg3] : memref<10x10xf32>
      }
    }

The second part of the concat IR has no dependence on the first part. Whether to fuse it with the loop above depends on the hardware target: for some DSA targets, not fusing may be the right choice, while for a GPU target fusing is preferable because it saves a kernel launch. The current affine-loop-fusion pass can fuse these two parts together, since no shape change is involved:

    affine.for %arg2 = 0 to 5 {
      affine.for %arg3 = 0 to 10 {
        %3 = affine.load %arg1[%arg2, %arg3] : memref<5x10xf32>
        affine.store %3, %1[%arg2 + 5, %arg3] : memref<10x10xf32>
      }
    }
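As an aside, assuming the IR above lives in a file such as `input.mlir` (a made-up name), the upstream pass can be driven with `mlir-opt`; the `fusion-maximal` option relaxes the profitability check so that fusion is attempted whenever it is legal, though exact option names can vary across MLIR versions:

    # Hypothetical invocation; pass-option spelling may differ by MLIR version.
    mlir-opt input.mlir -affine-loop-fusion="fusion-maximal=true" -o fused.mlir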

The last add loop has a dependence on the concat op. For a GPU target it would be better to fuse everything to get a single kernel, but the current affine-loop-fusion pass cannot do this because of its conservative fusion strategy. Some schedulers, such as the isl scheduler, can fuse all three ops together into one kernel by applying a more aggressive fusion strategy.

    affine.for %arg2 = 0 to 10 {
      affine.for %arg3 = 0 to 10 {
      ... ...
      }
    }
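For reference, here is a hand-written sketch of what a fully fused loop nest could look like; this is illustrative only, not the output of any existing pass, and the function name and integer-set name are made up. The concat is folded into an `affine.if` that selects which input feeds each row:

    // Rows 0..4 of the 10x10 result, i.e. 4 - d0 >= 0 (sketch, not pass output).
    #upper = affine_set<(d0) : (4 - d0 >= 0)>

    func @fused_add_concat_add(%arg0: memref<5x10xf32>, %arg1: memref<5x10xf32>) -> memref<10x10xf32> {
      %out = memref.alloc() : memref<10x10xf32>
      affine.for %i = 0 to 10 {
        affine.for %j = 0 to 10 {
          %v = affine.if #upper(%i) -> f32 {
            // First half of the concat: add(%arg0, %arg1).
            %a = affine.load %arg0[%i, %j] : memref<5x10xf32>
            %b = affine.load %arg1[%i, %j] : memref<5x10xf32>
            %s = arith.addf %a, %b : f32
            affine.yield %s : f32
          } else {
            // Second half of the concat: %arg1, shifted by the concat offset.
            %b = affine.load %arg1[%i - 5, %j] : memref<5x10xf32>
            affine.yield %b : f32
          }
          // The trailing element-wise add (%3 + %3 in the original IR).
          %d = arith.addf %v, %v : f32
          affine.store %d, %out[%i, %j] : memref<10x10xf32>
        }
      }
      return %out : memref<10x10xf32>
    }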

Maybe the MLIR community already has plans to enhance the affine-loop-fusion pass and expose more fusion options to users, or to develop more fusion passes for special fusion requirements. :slight_smile: