ConvertTrivialIfToSelect canonicalization generates inefficient IR

The canonicalization pattern ConvertTrivialIfToSelect can emit inefficient code since this patch: D121943 [MLIR][SCF] Create selects from if yield results which are not defined in the body

The original logic of this pattern was to flatten an empty scf.IfOp into a select. However, the patch mentioned above now converts any IfOp whose yielded values are defined outside the region, even if the IfOp cannot later be flattened. When the IfOp is not subsequently flattened, we end up with extra select operations, making the IR strictly worse.

I think the changes in D121943 [MLIR][SCF] Create selects from if yield results which are not defined in the body mean this transformation is no longer a canonicalization, as it is not always beneficial and it is hard to reverse.

Example:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = scf.if %cond -> (index) {
    call @side_effect() : () -> ()
    ...
    scf.yield %c0 : index
  } else {
    call @side_effect_2() : () -> ()
    ...
    scf.yield %c1 : index
  }
  return %0 : index

would be converted to:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = arith.select %cond, %c0, %c1 : index
  scf.if %cond {
    call @side_effect() : () -> ()
    ...
  } else {
    call @side_effect_2() : () -> ()
    ...
  }
  return %0 : index

I would like to revert to the original behavior of this pattern. If we still want the more aggressive version we can have it as an opt-in pattern.

@wsmoses could you give more details on the use case you had for this patch?

cc: @wsmoses @mehdi_amini @Hardcode84 @ftynse

This bug was found and debugged by @pawel.szczerbuk while debugging performance on Triton.


Nit: “not always beneficial” isn’t the way I think about canonicalization; whichever side you canonicalize toward, you can very often find a target that will claim a regression!

In this particular case, I agree with you that this transformation does not seem like an obviously great canonicalization to me; unfortunately, the original patch does not motivate it well.

Why do you feel it is hard to reverse, though? It seems to me you could fairly easily write the inverse transformation to apply pre-lowering: just walk the block and collect select users per condition value.
More importantly, you likely should implement this inverse transformation even if/when we revert this canonicalization! The select operation could come from another if in the same block that was fully flattened, and you could want to “sink” it into the second if, e.g.:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = scf.if %cond -> (index) {
    scf.yield %c0 : index
  } else {
    scf.yield %c1 : index
  }
  scf.if %cond {
    call @side_effect() : () -> ()
  } else {
    call @side_effect_2() : () -> ()
  }
  return %0 : index

Canonicalization flattening the first scf.if:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = arith.select %cond, %c0, %c1 : index
  scf.if %cond {
    call @side_effect() : () -> ()
  } else {
    call @side_effect_2() : () -> ()
  }
  return %0 : index

Which you could then want to transform before lowering:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = scf.if %cond -> (index) {
    call @side_effect() : () -> ()
    scf.yield %c0 : index
  } else {
    call @side_effect_2() : () -> ()
    scf.yield %c1 : index
  }
  return %0 : index
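
To make the sinking idea concrete, here is a minimal sketch in Python over a toy dict-based IR. The op encoding and field names are invented for illustration; a real implementation would use MLIR's C++ pattern infrastructure, and would also have to check that the select's result has no uses between the select and the if.

```python
def sink_selects_into_if(block):
    """Rewrite [select(c, a, b), ..., if(c){...}] so that each select on the
    if's condition becomes an extra (then_yield, else_yield, result) of the
    if, instead of a standalone select before it.

    `block` is a list of toy ops (dicts). Only selects that appear before
    the if and share its condition are sunk; dominance of uses is assumed.
    """
    new_block = []
    for op in block:
        if op["op"] == "if":
            kept = []
            for prev in new_block:
                if prev["op"] == "select" and prev["cond"] == op["cond"]:
                    # Sink: the select's arms become per-branch yields.
                    op["then_yields"].append(prev["true_val"])
                    op["else_yields"].append(prev["false_val"])
                    op["results"].append(prev["result"])
                else:
                    kept.append(prev)
            new_block = kept
        new_block.append(op)
    return new_block
```

On the example above, the standalone `select(%cond, %c0, %c1)` disappears and `%c0`/`%c1` become yields of the following if, recreating the pre-canonicalization form.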

I don’t know the use case for this specific transformation, but I can imagine it opening opportunities for arith.select canonicalizations/foldings, which are not available for scf.if.

Yes, fair enough. It still does feel like a very intuitively canonical form, though.

I wonder what kind of transformation/folding that would be.


Here is a somewhat artificial example:

  %poison = ub.poison : i32
  %0 = scf.if %cond -> i32 {
    func.call @side_effect() : () -> ()
    scf.yield %arg : i32
  } else {
    func.call @side_effect_2() : () -> ()
    scf.yield %poison : i32
  }

Here the scf.yields canonicalize to arith.select %cond, %arg, %poison, which is then folded to just %arg.
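
As a toy illustration of the folds this enables, here is a minimal select folder in plain Python (the `POISON` sentinel and the function itself are illustrative stand-ins, not MLIR's actual folding hooks):

```python
POISON = object()  # stand-in for an ub.poison value

def fold_select(cond, true_val, false_val):
    """Fold select(cond, true_val, false_val); return the folded value, or
    None when no fold applies. Picking the non-poison arm is legal because
    poison may be refined to any value."""
    if cond is True:          # known-true condition
        return true_val
    if cond is False:         # known-false condition
        return false_val
    if true_val is POISON:    # select(c, poison, x) -> x
        return false_val
    if false_val is POISON:   # select(c, x, poison) -> x
        return true_val
    if true_val == false_val: # both arms identical
        return true_val
    return None
```

The poison case is exactly the example above: `select(%cond, %arg, %poison)` folds to `%arg`, a simplification that has no direct scf.if counterpart today.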

Yeah, this (and ironically other optimizations enabled by it) was added a few years ago when working on C++ optimizations from Polygeist.

If memory serves, the reasoning was as follows:

Without a multi-region escaping break/return/continue, etc., we represent the need to escape a scope with a boolean flag, and create scf.if’s around all operations to decide whether or not they should continue executing (depending on whether the flag has been set).

Mem2reg then results in the if statements having yields, which would be nice to simplify here (and this enabled downstream optimizations).

Similarly, this comes up in a bunch of C++-generated code where booleans are actually represented with i8’s, not i1’s. Doing this folding enables exactly the sort of later exttosi folding, etc., down the line.
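
A toy Python sketch of the kind of fold this refers to: a select of the constants 1 and 0 on an i1 condition is just an extension of the condition. The tuple encoding is an invented stand-in; in MLIR the analogous rewrite would produce something like arith.extui.

```python
def select_of_bool_constants(cond, true_val, false_val):
    """select(c, 1, 0) on an i1 condition is a zero-extension of c;
    select(c, 0, 1) is a zero-extension of the negated condition.
    Returns a toy (op, operand) tuple, or None if the fold doesn't apply."""
    if (true_val, false_val) == (1, 0):
        return ("extui", cond)
    if (true_val, false_val) == (0, 1):
        return ("extui", ("not", cond))
    return None
```

This fold only becomes visible once the yields have been rewritten into a select; as long as the constants sit in the two scf.if branches, no single op exposes the pattern.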

Happy to hop on a call at any point and talk through design decisions that can make everything work well together!

Hello @wsmoses, thanks for highlighting the background rationale. We are getting back to this after a break. It would be great to talk through the issue and possible solutions! Can you share time slots that would work for you to jump on a call? Thanks!

If you feel that a call is necessary and forums/chats don’t provide adequate bandwidth for discussion, the MLIR ODM should be the place for such synchronous discussions.


We are currently introducing a pass that reverts the canonicalization pattern. I wonder, however, whether this should be addressed upstream in some way. Would we want the reversal pass upstreamed, or perhaps have the more aggressive behavior of ConvertTrivialIfToSelect enabled on demand?

Even if reverting it is possible, I still think this is probably not the right canonical form.
It would be great to discuss it with @wsmoses and others. Maybe the next ODM works, if @wsmoses can attend.
I can add it to the list of topics if it helps.


Oh, sorry, I missed the bump. Yeah, I’m free Friday before 1pm PT if that works for you.

We can also email offline wsmoses at illinois dot edu to find a time (I’m usually faster at responding to email than discourse).


Also, putting it out there: I don’t understand why you don’t think it is canonical; it would help if you could expand on that.

E.g. part of why we did this was to enable better constant propagation down the line (like mentioned above).

But also purely from an implementation perspective, removing a potential conditional branch (somewhat implied by an scf.if) and replacing it with a select feels more canonical to me?

Removing an scf.if with a trivial region and replacing it by a select does feel like a canonicalization, but the problem is that when the scf.if cannot be removed because of other operations, you end up with extra selects in addition to the if branch. There are cases where users will want to explicitly branch, and pulling part of the code out of the branch loses this information.
For instance, if we have a program like:

if (c) {
  store a0, ptr0
  a0 = init
  a1 = init
  ...
}

this generates a large number of selects while keeping the branch operation.

The example of folding given above is interesting, and it is true that it would otherwise need to be added to IfOp. That could be an argument for aggressively converting to select.
The major downside, in my opinion, is not being able to keep the original program form, where branching is explicitly done to avoid computations.

Anyway, let’s try to chat quickly on Friday.


Feel free to say it’ll just be easier to explain on a call (and please don’t mistake my questions for resistance).

But I think that is my core question: in your use case, why is having more selects with a partially remaining if worse than having an if with more results? I need to double-check the code, but I think this PR only moves things out of yields, so it shouldn’t change any opportunities within the if. Similarly, this wouldn’t impact any side-effecting operation.

Historically, I found the former to be better for downstream optimizations. For example, one may see that both sides are constants and be able to change it to an i1-to-i8 ext, and then a later if statement with this same structure, with a condition on the i8, may even be fused into the remaining partial if! I think generally reducing the level of region nesting (as is the case here) can often be immensely helpful in MLIR, due to the complications analyses/optimizations have with nesting, in addition to removing branches/potential divergence simply being generally nice for the end code.

That said, it does mean that if you had any code specialized to ifs, you will also want to extend it to selects. If I recall correctly, around the same time as this PR I also made a bunch of select optimizations (e.g., for conditionals, which were really the bread and butter of a lot of performance/simplification/analysis downstream).

If your use case requires code to stay together for other reasons, I’m happy to also discuss more natural ways of guaranteeing that within MLIR, rather than attempting to prevent all optimizations from performing valid transformations that fall outside of your assumptions. For example, collaborators and I have played around with code-choice regions that denote potentially differing optimization choices for GPU programs, and operations which semantically ensure only valid transformations (within a spec) can occur with respect to aliasing behavior.

@wsmoses you are right that the transformation only pulls things out of yields, but it was still a problem in our case. The case we found problematic was this:

cst = 0
a' = if (c) {
  ...
  yield (cst)
} else {
  yield (a)
}

which got converted to:

a' = select(c, cst, a)
if (c) {
  ...
}
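
For reference, a toy Python sketch of what the aggressive pattern effectively does (the dict encoding and the `defined_outside` set are invented for illustration; the real pattern works on MLIR's SSA def-use structure):

```python
def convert_yields_to_selects(if_op, defined_outside):
    """For each result whose then- and else-yielded values are both defined
    outside the if's regions, emit a select on the if's condition and drop
    that result from the if. Returns (selects, remaining_result_indices).
    Note the sketch does not check whether the if can be removed entirely,
    which is exactly the problematic behavior discussed in this thread."""
    selects = []
    remaining = []
    pairs = zip(if_op["then_yields"], if_op["else_yields"])
    for i, (t, e) in enumerate(pairs):
        if t in defined_outside and e in defined_outside:
            selects.append(("select", if_op["cond"], t, e))
        else:
            remaining.append(i)
    return selects, remaining
```

On the example above, the yielded pair (`cst`, `a`) is entirely defined outside the if, so it becomes `select(c, cst, a)` even though the if itself survives.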

In our case, a spans multiple values, and its initialization was put in the if to save computation, as @ThomasRaoux mentioned. This initialization was then folded into a yield.

That code should check that the select hoisting only occurs if the yielded value is defined outside the if, IIRC. So it shouldn’t cause something to span any more regions than it already did, right?


Sorry, I should have been clearer. The issue is not that any of the arithmetic was moved out of the if. Rather, in the final codegen we end up with dozens of select instructions before the if, all of which could be skipped using a uniform jump (much cheaper on this specific target).

Right, but that alone is not necessarily instructive of whether the canonicalization is OK or not. It is expected that during lowering you have to do transformations that aren’t canonical, so the question here is whether a codegen-preparation pass could, starting from the canonical form, easily turn the IR into the optimal form for codegen on your target.

BTW, @ThomasRaoux and @pawel.szczerbuk, I didn’t see an email/Discourse post/message confirming a time for today, so I’m assuming that’s not happening.

Happy to talk more, just let me know! (feel free to post, discourse DM or email wsmoses at illinois dot edu)

Yeah, I think ++ Mehdi here. Each target hardware (or perhaps even a target which isn’t hardware) will have different cost models for what is expensive, so at the end of the day getting efficient code will require some amount of target-specific transformation.

In my mind the things that make a canonicalization useful are:

  1. Is it, by some definition, canonical (e.g., it can’t infinitely recur on itself)?
  2. Does it make things easier for other transformations to apply, and/or make the IR by some definition simpler?

I think both of these apply here, but of course the latter really depends on what downstream optimizations you have. So far, the ones I mentioned above feel sufficiently compelling to keep it, especially since your use case feels like it should be part of target lowering.