Consider the following MLIR implementation of the out = numpy.where(cnd, in1, in2) operator:
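// Assumption: #map is the identity map; its definition was not part of the snippet as posted.
#map = affine_map<(d0) -> (d0)>
func.func @where(%cond: memref<?xi1>, %arg1: memref<?xf64>, %arg2: memref<?xf64>, %arg3: memref<?xf64>) attributes {llvm.emit_c_interface} {
  %c0 = arith.constant 0 : index
  // Loop trip count: the length of the condition array.
  %0 = memref.dim %cond, %c0 : memref<?xi1>
  affine.for %arg8 = #map(%c0) to #map(%0) {
    // The condition is loaded from %cond (the snippet referred to an undefined %arg0 here).
    %1 = affine.load %cond[%arg8] : memref<?xi1>
    // Pick in1[i] or in2[i] depending on the loaded condition.
    %2 = scf.if %1 -> (f64) {
      %3 = affine.load %arg1[%arg8] : memref<?xf64>
      scf.yield %3 : f64
    } else {
      %3 = affine.load %arg2[%arg8] : memref<?xf64>
      scf.yield %3 : f64
    }
    affine.store %2, %arg3[%arg8] : memref<?xf64>
  }
  return
}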
First question: is this the best way to write the above computation? As far as I can see, the affine dialect's if operation only tests integer set containment over loop indices and symbols, so it cannot branch on a value loaded from memory. That makes it seem to me like scf.if is the only other option.
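That said, a branch-free form may also be possible. The following is only a sketch, assuming that it is acceptable to load from both inputs on every iteration (i.e. all arrays have the same length and the extra load is harmless). It replaces the scf.if with arith.select, so the loop body contains no region-carrying non-affine ops. The function and value names (@where_select, %in1, %in2, %out) are made up for illustration, and I have not checked how the fusion pass treats this form.

```mlir
func.func @where_select(%cond: memref<?xi1>, %in1: memref<?xf64>,
                        %in2: memref<?xf64>, %out: memref<?xf64>) {
  %c0 = arith.constant 0 : index
  %n = memref.dim %cond, %c0 : memref<?xi1>
  affine.for %i = 0 to %n {
    %c = affine.load %cond[%i] : memref<?xi1>
    // Both operands are loaded unconditionally (assumption: always in bounds).
    %a = affine.load %in1[%i] : memref<?xf64>
    %b = affine.load %in2[%i] : memref<?xf64>
    // out[i] = cond[i] ? in1[i] : in2[i]
    %r = arith.select %c, %a, %b : f64
    affine.store %r, %out[%i] : memref<?xf64>
  }
  return
}
```

The trade-off is one extra load per element in exchange for a loop body made up only of affine loads and stores plus side-effect-free arith ops.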
Given this implementation, the affine loop fusion pass is not able to fuse this loop, which contains an scf.if, with other affine loops that produce the input arrays in1 or in2. Before I decide to spend time trying to add support for this to the pass, I wanted to ask whether this is a fundamental limitation of the affine loop fusion pass, since supporting it seems to require reasoning about dialects other than affine. I could see that as being an argument, but on the other hand, the way that the scf.if is being used is easy to reason about: centered accesses, affine accesses in both branches, etc.
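For concreteness, the kind of input I have in mind looks roughly like the sketch below. The producer loop, the function name @producer_then_where, and the value names (%src, %in1, %in2, %out) are hypothetical and only there to illustrate the producer/consumer structure. I would run it through something like mlir-opt --affine-loop-fusion.

```mlir
func.func @producer_then_where(%cond: memref<?xi1>, %src: memref<?xf64>,
                               %in2: memref<?xf64>, %out: memref<?xf64>) {
  %c0 = arith.constant 0 : index
  %n = memref.dim %cond, %c0 : memref<?xi1>
  %in1 = memref.alloc(%n) : memref<?xf64>
  // Producer (hypothetical): in1[i] = src[i] + src[i]
  affine.for %i = 0 to %n {
    %v = affine.load %src[%i] : memref<?xf64>
    %s = arith.addf %v, %v : f64
    affine.store %s, %in1[%i] : memref<?xf64>
  }
  // Consumer: out[i] = cond[i] ? in1[i] : in2[i]
  affine.for %i = 0 to %n {
    %c = affine.load %cond[%i] : memref<?xi1>
    %r = scf.if %c -> (f64) {
      %a = affine.load %in1[%i] : memref<?xf64>
      scf.yield %a : f64
    } else {
      %b = affine.load %in2[%i] : memref<?xf64>
      scf.yield %b : f64
    }
    affine.store %r, %out[%i] : memref<?xf64>
  }
  memref.dealloc %in1 : memref<?xf64>
  return
}
```

Both loops have the same bounds and every access uses the plain loop index, so the only thing separating this from an ordinary producer/consumer pair is that the read of %in1 sits inside the scf.if region.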