Affine loop fusion with structured branching

Consider the following MLIR implementation of the out = numpy.where(cond, in1, in2) operator:

  func.func @where(%cond: memref<?xi1>, %arg1: memref<?xf64>, %arg2: memref<?xf64>, %arg3: memref<?xf64>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %0 = memref.dim %cond, %c0 : memref<?xi1>
    affine.for %arg8 = #map(%c0) to #map(%0) {
      %1 = affine.load %cond[%arg8] : memref<?xi1>
      %2 = scf.if %1 -> (f64) {
        %3 = affine.load %arg1[%arg8] : memref<?xf64>
        scf.yield %3 : f64
      } else {
        %3 = affine.load %arg2[%arg8] : memref<?xf64>
        scf.yield %3 : f64
      }
      affine.store %2, %arg3[%arg8] : memref<?xf64>
    }
    return
  }

First question: is this the best way to write the above computation? I see that the affine dialect’s if operation (affine.if) only tests integer set containment, so it seems to me that scf.if is the only option for a branch on a loaded value.

Given this implementation, the affine loop fusion pass is not able to fuse this loop, which contains an scf.if, into other affine loops that produce the input arrays in1 or in2. Before I spend time trying to add support for this to the pass, I wanted to ask whether this is a fundamental limitation of the affine loop fusion pass, since it seems to require reasoning about dialects other than affine. I could see that as an argument, but on the other hand, the way the scf.if is used here is easy to reason about: centered accesses, affine accesses in either branch, etc.

I suspect that this is a bug at some level. The presence of an scf.if prevents fusion even among loops that do not contain the branch:

  func.func @testing() {
    %alloc = memref.alloc() : memref<100xf64>
    %alloc_0 = memref.alloc() : memref<100xf64>
    %alloc_1 = memref.alloc() : memref<100xf64>
    %alloc_2 = memref.alloc() : memref<100xf64>
    %alloc_3 = memref.alloc() : memref<100xi1>
    affine.for %arg0 = 0 to 100 {
      %0 = affine.load %alloc[%arg0] : memref<100xf64>
      affine.store %0, %alloc_0[%arg0] : memref<100xf64>
    }
    affine.for %arg0 = 0 to 100 {
      %0 = affine.load %alloc_0[%arg0] : memref<100xf64>
      affine.store %0, %alloc_1[%arg0] : memref<100xf64>
    }
    affine.for %arg0 = 0 to 100 {
      %0 = affine.load %alloc_3[%arg0] : memref<100xi1>
      %1 = scf.if %0 -> (f64) {
        %2 = affine.load %alloc_1[%arg0] : memref<100xf64>
        scf.yield %2 : f64
      } else {
        %2 = affine.load %alloc_1[%arg0] : memref<100xf64>
        scf.yield %2 : f64
      }
      affine.store %1, %alloc_2[%arg0] : memref<100xf64>
    }
    return
  }

When I apply the affine-loop-fusion pass to this example, even the first two for loops don’t get fused. However, when I remove the scf.if from the final loop, the pass succeeds in fusing all three loops together.

As an answer to the first part of this question, I found the arith.select operation, which can be used as a ternary operator without short-circuiting: both operands are computed and the condition picks one of them.
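
For reference, here is a minimal sketch of how the where kernel might look with arith.select instead of scf.if. The function name @where_select and the value names (%in1, %in2, %out, %n) are my own placeholders, and I have dropped the #map bounds in favor of a plain symbol upper bound, so treat this as an illustration rather than a drop-in replacement:

  func.func @where_select(%cond: memref<?xi1>, %in1: memref<?xf64>, %in2: memref<?xf64>, %out: memref<?xf64>) {
    %c0 = arith.constant 0 : index
    // Loop trip count comes from the dynamic size of the condition array.
    %n = memref.dim %cond, %c0 : memref<?xi1>
    affine.for %i = 0 to %n {
      %c = affine.load %cond[%i] : memref<?xi1>
      // Both inputs are loaded unconditionally (no short-circuiting).
      %a = affine.load %in1[%i] : memref<?xf64>
      %b = affine.load %in2[%i] : memref<?xf64>
      // Pick %a where the condition is true, %b otherwise.
      %r = arith.select %c, %a, %b : f64
      affine.store %r, %out[%i] : memref<?xf64>
    }
    return
  }

The trade-off is that both in1 and in2 are always read, which is fine for a where-style kernel but would not be equivalent if one branch had side effects or could access out of bounds. Since the loop body no longer holds any region-carrying op, I would expect it to stay within what the fusion pass can analyze, though I have only reasoned about this and not verified it across MLIR versions.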

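For the second question, the affine loop fusion pass’s analysis exits early if it finds, inside an affine.for, any operation that holds a region and is not itself an affine.for or affine.if. An scf.if is such an operation, and because the early exit abandons the analysis as a whole, no loops get fused at all, including ones that do not contain the branch.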
