ConvertTrivialIfToSelect canonicalization generates inefficient IR

The canonicalization pattern ConvertTrivialIfToSelect has been able to emit inefficient IR since this patch: D121943 [MLIR][SCF] Create selects from if yield results which are not defined in the body

The original logic of this pattern was to flatten an empty scf.if into arith.select ops. However, the patch mentioned above now converts every result of an scf.if whose yielded values are defined outside its regions, even if the scf.if itself cannot later be flattened. When the scf.if is not flattened later, we end up with extra select operations, making the IR strictly worse.

I think the changes in D121943 [MLIR][SCF] Create selects from if yield results which are not defined in the body mean this transformation is no longer a canonicalization, as it is not always beneficial and it is hard to reverse.

Ex:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = scf.if %cond -> (index) {
    func.call @side_effect() : () -> ()
    ...
    scf.yield %c0 : index
  } else {
    func.call @side_effect_2() : () -> ()
    ...
    scf.yield %c1 : index
  }
  return %0 : index

would be converted to:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = arith.select %cond, %c0, %c1 : index
  scf.if %cond {
    func.call @side_effect() : () -> ()
    ...
  } else {
    func.call @side_effect_2() : () -> ()
    ...
  }
  return %0 : index
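To make the difference between the two behaviors concrete, here is a toy Python model of the pattern's decision. It is not MLIR's C++ implementation: `ToyIf`, `trivial_if_to_select` and the `only_empty_bodies` flag are hypothetical names. The flag models the two behaviors described in this thread: `True` flattens only ifs whose bodies are empty (the original behavior), while `False` converts any result whose two yielded values are both defined outside the regions (the behavior after D121943).

```python
from dataclasses import dataclass, field


@dataclass
class ToyIf:
    """Toy stand-in for an scf.if (hypothetical, not an MLIR API).

    Values are SSA-name strings such as "%c0"; `defined_inside` lists the
    values produced by ops inside either region.
    """
    cond: str
    then_yields: list[str]
    else_yields: list[str]
    then_ops: list[str] = field(default_factory=list)
    else_ops: list[str] = field(default_factory=list)
    defined_inside: set[str] = field(default_factory=set)


def trivial_if_to_select(op: ToyIf, only_empty_bodies: bool):
    """Return (selects, remaining_if); remaining_if is None if the if is erased."""
    bodies_empty = not op.then_ops and not op.else_ops
    if only_empty_bodies and not bodies_empty:
        return [], op  # original behavior: leave non-empty ifs untouched
    selects, kept_then, kept_else = [], [], []
    for t, e in zip(op.then_yields, op.else_yields):
        if t not in op.defined_inside and e not in op.defined_inside:
            # Both arms yield values visible outside: express the result as a select.
            selects.append(f"arith.select {op.cond}, {t}, {e}")
        else:
            kept_then.append(t)
            kept_else.append(e)
    if bodies_empty and not kept_then:
        return selects, None  # nothing left in the if: it disappears entirely
    # The if survives (side effects remain), now with fewer results.
    remaining = ToyIf(op.cond, kept_then, kept_else,
                      op.then_ops, op.else_ops, op.defined_inside)
    return selects, remaining
```

Applied to the example above with `only_empty_bodies=False`, the model produces a select and still keeps an scf.if holding both calls, which is exactly the extra work reported here; with `only_empty_bodies=True` it leaves that if alone.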

I would like to revert to the original behavior of this pattern. If we still want the more aggressive version we can have it as an opt-in pattern.

@wsmoses could you give more details on the use case you had for this patch?

cc: @wsmoses @mehdi_amini @Hardcode84 @ftynse

This bug was found and debugged by @pawel.szczerbuk while debugging performance on Triton.


Nit: “not always beneficial” isn’t the way I think about canonicalization: whichever form you canonicalize to, you will very often find a target that claims a regression!

In this particular case, I agree with you that this transformation does not seem like an obviously great canonicalization to me; unfortunately, the original patch does not motivate it well.

Why do you feel it is hard to reverse, though? It seems to me you could fairly easily write the inverse transformation and apply it before lowering: just walk the block and collect the users of each condition value?
More importantly, you likely should implement this inverse transformation even if/when we revert this canonicalization! The select could come from another scf.if in the same block that was fully flattened, and you may want to “sink” it into the second if, e.g.:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = scf.if %cond -> (index) {
    scf.yield %c0 : index
  } else {
    scf.yield %c1 : index
  }
  scf.if %cond {
    func.call @side_effect() : () -> ()
  } else {
    func.call @side_effect_2() : () -> ()
  }
  return %0 : index

After canonicalization flattens the first scf.if:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = arith.select %cond, %c0, %c1 : index
  scf.if %cond {
    func.call @side_effect() : () -> ()
  } else {
    func.call @side_effect_2() : () -> ()
  }
  return %0 : index

which you could then transform, before lowering, into:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = scf.if %cond -> (index) {
    func.call @side_effect() : () -> ()
    scf.yield %c0 : index
  } else {
    func.call @side_effect_2() : () -> ()
    scf.yield %c1 : index
  }
  return %0 : index
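The inverse ("sinking") transformation suggested above can be sketched as a toy Python pass over a single block. This is a minimal sketch under stated assumptions, not an MLIR API: `Op`, `Select`, `If` and `sink_selects_into_ifs` are hypothetical names, values are SSA-name strings, and the select's operands are assumed to be defined earlier in the same block (so they also dominate the if's regions). For each select, it looks for a later scf.if on the same condition and stops if some op uses the select's result first.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Op:
    """Generic toy op: a name, the SSA values it reads, and an optional result."""
    name: str
    operands: list[str] = field(default_factory=list)
    result: str = ""


@dataclass
class Select:
    """Toy `%result = arith.select %cond, %true_val, %false_val`."""
    result: str
    cond: str
    true_val: str
    false_val: str


@dataclass
class If:
    """Toy scf.if: region bodies plus its results and per-arm yielded values."""
    cond: str
    then_ops: list[Op]
    else_ops: list[Op]
    results: list[str] = field(default_factory=list)
    then_yields: list[str] = field(default_factory=list)
    else_yields: list[str] = field(default_factory=list)


def uses(op, value: str) -> bool:
    """True if `op` (including ops nested in its regions) reads `value`."""
    if isinstance(op, Select):
        return value in (op.cond, op.true_val, op.false_val)
    if isinstance(op, If):
        return (value == op.cond
                or value in op.then_yields + op.else_yields
                or any(uses(o, value) for o in op.then_ops + op.else_ops))
    return value in op.operands


def find_sink_target(block: list, i: int):
    """First later If on the same condition, unless a user of the select comes first."""
    sel = block[i]
    for op in block[i + 1:]:
        if isinstance(op, If) and op.cond == sel.cond and not uses(op, sel.result):
            return op
        if uses(op, sel.result):
            return None  # the select must stay above this user
    return None


def sink_selects_into_ifs(block: list) -> list:
    """Sink each select into a later If with the same condition.

    The If gains a result with the select's SSA name, so later users are unchanged.
    """
    out = copy.deepcopy(block)  # leave the caller's block untouched
    i = 0
    while i < len(out):
        target = find_sink_target(out, i) if isinstance(out[i], Select) else None
        if target is None:
            i += 1
            continue
        sel = out.pop(i)  # the next op shifts into slot i, so do not advance
        target.results.append(sel.result)
        target.then_yields.append(sel.true_val)
        target.else_yields.append(sel.false_val)
    return out
```

Reusing the select's SSA name for the new if result is a simplification that avoids rewriting later uses; a real pass would instead create a new op and replace all uses of the select.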

I don’t know the use case for this specific transformation, but I can imagine it opening opportunities for arith.select canonicalizations/foldings that are not available for scf.if.

Yes, fair enough. It still feels like a very intuitively canonical form, though.

I wonder what kind of transformation/folding that would be.


Here is a somewhat artificial example:

  %poison = ub.poison : i32
  %0 = scf.if %cond -> i32 {
    func.call @side_effect() : () -> ()
    scf.yield %arg : i32
  } else {
    func.call @side_effect_2() : () -> ()
    scf.yield %poison : i32
  }

Here the yielded values are canonicalized to arith.select %cond, %arg, %poison, which then folds to just %arg.
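The benefit is that once the yields become a select, generic select folds apply without any scf.if-specific logic. Here is a toy Python folder for a select, as a sketch under assumptions: `fold_select` and the `POISON` marker are hypothetical, constant conditions are encoded as the strings "true"/"false", and the rules are chosen for illustration rather than copied from the arith dialect. Dropping the poison arm is sound because poison may be refined to any value, including the other operand.

```python
POISON = "poison"  # toy marker for a ub.poison value (not an MLIR API)


def fold_select(cond: str, true_val: str, false_val: str):
    """Toy folder for `arith.select cond, true_val, false_val`.

    Returns the value the select folds to, or None if no rule applies.
    """
    if cond == "true":
        return true_val  # constant condition picks the true arm
    if cond == "false":
        return false_val  # constant condition picks the false arm
    if true_val == false_val:
        return true_val  # both arms agree
    if false_val == POISON:
        return true_val  # poison arm may be refined to true_val
    if true_val == POISON:
        return false_val  # poison arm may be refined to false_val
    return None
```

The same value yielded from an scf.if with side-effecting bodies would need a dedicated pattern to get this simplification.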

Yeah, this (and, ironically, other optimizations enabled by it) was added a few years ago while working on C++ optimizations in Polygeist.

If memory serves, the reasoning was as follows:

Without a multi-region escaping break/return/continue, etc., we represent the need to escape a scope with a boolean flag, and wrap all subsequent operations in scf.if ops that decide whether they should keep executing (depending on whether the flag has been set).

Running mem2reg then leaves these scf.if ops with yields, which it would be nice to simplify here (doing so also enables downstream optimizations).

Similarly, this comes up in a lot of code generated from C++, where booleans are actually represented as i8s, not i1s. Doing this folding enables exactly the sort of later exttosi folding, etc., down the line.
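The message does not spell out which later folds are meant, so the following is an assumption for illustration only: a toy bottom-up simplifier over i8-encoded booleans with two rules, rewriting a select of the i8 constants 1/0 to an arith.extui of the i1 condition, and folding an arith.trunci back to i1 of that arith.extui to the original condition. The `simplify` function and its tuple encoding are hypothetical. An scf.if yielding the same constants (modeled here as an opaque "if" node) is not matched by either rule.

```python
def simplify(expr):
    """Toy bottom-up simplifier (hypothetical encoding, not MLIR's data structures).

    Expressions are tuples such as ("select", cond, t, f), ("extui", x),
    ("trunci", x) or ("if", cond, t, f); leaves are SSA names (str) or
    integer constants (int). `cond` is an i1 value, the constants are i8,
    and "trunci" is assumed to truncate i8 back to i1.
    """
    if not isinstance(expr, tuple):
        return expr  # leaf: SSA name or constant
    op, *args = expr
    args = [simplify(a) for a in args]  # simplify operands first
    if op == "select" and args[1:] == [1, 0]:
        # select %c, 1 : i8, 0 : i8  ->  extui %c : i1 to i8
        return ("extui", args[0])
    if op == "trunci" and isinstance(args[0], tuple) and args[0][0] == "extui":
        # trunci (extui %c : i1 to i8) : i8 to i1  ->  %c
        return args[0][1]
    return (op, *args)  # no rule applies
```

With these rules, `simplify(("trunci", ("select", "%c", 1, 0)))` reduces to the original `%c`, while the same expression built on an "if" node stays as it is.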

Happy to hop on a call at any point and talk through design decisions that can make everything work well together!