[RFC] Changing the loop pipeliner prologue/epilogue generation

Context

The loop pipeliner in scf currently allows a user to provide a schedule for an scf::for loop and automatically generates the pipelined loop by emitting the prologue + new loop + epilogue.
This greatly simplifies writing a software pipelining transformation, as it separates finding the schedule from transforming the loop.

Current implementation

The current interface allows the user to decide whether the epilogue should be peeled out. When the epilogue is not peeled out, the pipeliner masks operations so that not all of them are executed in the last iterations of the loop. The masking is done by a lambda that users must provide, since MLIR doesn’t have a standard way to mask.

Problem

  1. We cannot control how much of the epilogue we peel, and we cannot control whether we peel the prologue. Finer-grained control over those parameters can be very useful. For instance, we may want to peel only one iteration of the epilogue if there are very few ops in the last stage.
  2. The pipeliner logic is more or less duplicated between the prologue and epilogue emission, which makes bugs more likely.

Proposal

Instead of emitting the prologue, then the kernel, then the epilogue, we could emit the pipelined loop directly and then peel out loop iterations for the prologue and epilogue. We still want to filter out dead operations in the prologue and epilogue. To make that easy, we add a new transient mask operation (pipeline.predicate.mask in the example below) that controls when the wrapped op should or shouldn’t be executed. Then, when we peel the prologue/epilogue, we can just “fold” the mask operation, which either becomes dead or becomes the original operation.

Example:

    scf.for %arg2 = %c0 to %c8 step %c1 {
      %0 = memref.load %arg0[%arg2]  : memref<?xf32>
      %1 = arith.addf %0, %cst : f32
      memref.store %1, %arg1[%arg2] : memref<?xf32>
    }

currently generates the following directly:

    %0 = memref.load %arg0[%c0] : memref<?xf32>
    %1 = memref.load %arg0[%c1] : memref<?xf32>
    %2 = arith.addf %0, %cst : f32
    %3:2 = scf.for %arg2 = %c0 to %c6 step %c1 iter_args(%arg3 = %1, %arg4 = %2) -> (f32, f32) {
      %5 = arith.addi %arg2, %c2 : index
      %6 = memref.load %arg0[%5] : memref<?xf32>
      %7 = arith.addf %arg3, %cst : f32
      memref.store %arg4, %arg1[%arg2] : memref<?xf32>
      scf.yield %6, %7 : f32, f32
    }
    %4 = arith.addf %3#0, %cst : f32
    memref.store %3#1, %arg1[%c6] : memref<?xf32>
    memref.store %4, %arg1[%c7] : memref<?xf32>
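To make the three-part structure concrete, here is a minimal Python model of the original loop and of the prologue/kernel/epilogue above. It is not MLIR and not the pipeliner's code. It is hard-coded for this 3-stage schedule (load = stage 0, addf = stage 1, store = stage 2). CST stands in for %cst, and the function names are made up for illustration.

```python
CST = 1.0  # stands in for %cst

def original(a):
    # scf.for %arg2 = 0 to n: load a[i], add CST, store into out[i]
    out = [None] * len(a)
    for i in range(len(a)):
        out[i] = a[i] + CST
    return out

def current_pipelined(a):
    # Prologue / kernel / epilogue for the 3-stage schedule shown above.
    # Assumes len(a) >= 2 (the example uses 8).
    n = len(a)
    out = [None] * n
    # Prologue: %0, %1, %2
    v0 = a[0]           # stage 0 of iteration 0
    v1 = a[1]           # stage 0 of iteration 1
    v2 = v0 + CST       # stage 1 of iteration 0
    arg3, arg4 = v1, v2
    # Kernel: scf.for %arg2 = 0 to n - 2
    for i in range(n - 2):
        v6 = a[i + 2]       # stage 0 of iteration i + 2
        v7 = arg3 + CST     # stage 1 of iteration i + 1
        out[i] = arg4       # stage 2 of iteration i
        arg3, arg4 = v6, v7
    # Epilogue: %4 and the two trailing stores
    v4 = arg3 + CST         # stage 1 of iteration n - 1
    out[n - 2] = arg4       # stage 2 of iteration n - 2
    out[n - 1] = v4         # stage 2 of iteration n - 1
    return out
```

Each kernel iteration runs three different original iterations at once. This is why the prologue and epilogue have to fill and drain the stages.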

Instead, the first step would be to generate:

    %ini0 = ub.poison : f32
    %ini1 = ub.poison : f32
    %3:2 = scf.for %arg2 = %c-2 to %c8 step %c1 iter_args(%arg3 = %ini0, %arg4 = %ini1) -> (f32, f32) {
      %5 = arith.addi %arg2, %c2 : index
      %p0 = pipeline.predicate 0, %arg2, %c0, %c2, %c1
      %p1 = pipeline.predicate 1, %arg2, %c0, %c2, %c1
      %p2 = pipeline.predicate 2, %arg2, %c0, %c2, %c1
      %m6 = pipeline.predicate.mask %p0 { %6 = memref.load %arg0[%5] : memref<?xf32> }
      %m7 = pipeline.predicate.mask %p1 { %7 = arith.addf %arg3, %cst : f32 }
      pipeline.predicate.mask %p2 { memref.store %arg4, %arg1[%arg2] : memref<?xf32> }
      scf.yield %m6, %m7 : f32, f32
    }
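A small Python model of the predicated kernel above can help check its semantics. It assumes that the kernel IV tracks the last stage, so stage s at IV iv works on original iteration iv + (maxStage - s) * step. It also assumes that a stage is alive when that iteration lies inside [lb, ub). The exact operand list of pipeline.predicate in this proposal may differ. POISON stands in for ub.poison, and all names are hypothetical.

```python
POISON = object()  # stands in for ub.poison
CST = 1.0          # stands in for %cst
LB, UB, STEP, MAX_STAGE = 0, 8, 1, 2

def stage_alive(stage, iv):
    # Assumption: stage `stage` at kernel IV `iv` works on original iteration
    # iv + (MAX_STAGE - stage) * STEP, and is alive if that is in [LB, UB).
    orig_iv = iv + (MAX_STAGE - stage) * STEP
    return LB <= orig_iv < UB

def predicated_kernel(a):
    out = [None] * UB
    arg3, arg4 = POISON, POISON                        # %ini0, %ini1
    for iv in range(LB - MAX_STAGE * STEP, UB, STEP):  # %c-2 to %c8
        p0, p1, p2 = (stage_alive(s, iv) for s in range(MAX_STAGE + 1))
        m6 = a[iv + 2 * STEP] if p0 else POISON        # masked memref.load
        m7 = arg3 + CST if p1 else POISON              # masked arith.addf
        if p2:                                         # masked memref.store
            out[iv] = arg4
        arg3, arg4 = m6, m7                            # scf.yield
    return out
```

Masking the load also keeps it from reading out of bounds in the last two iterations. Poison values only ever flow into ops whose masks are false.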

At this stage, we can either lower pipeline.predicate and pipeline.predicate.mask directly, which gives a loop with neither the prologue nor the epilogue peeled, or apply loop peeling and then fold the pipeline ops. For instance, if we peel one iteration at the end:

    %ini0 = ub.poison : f32
    %ini1 = ub.poison : f32
    %3:2 = scf.for %arg2 = %c-2 to %c7 step %c1 iter_args(%arg3 = %ini0, %arg4 = %ini1) -> (f32, f32) {
      %5 = arith.addi %arg2, %c2 : index
      %p0 = pipeline.predicate 0, %arg2, %c0, %c2, %c1
      %p1 = pipeline.predicate 1, %arg2, %c0, %c2, %c1
      %p2 = pipeline.predicate 2, %arg2, %c0, %c2, %c1
      %m6 = pipeline.predicate.mask %p0 { %6 = memref.load %arg0[%5] : memref<?xf32> }
      %m7 = pipeline.predicate.mask %p1 { %7 = arith.addf %arg3, %cst : f32 }
      pipeline.predicate.mask %p2 { memref.store %arg4, %arg1[%arg2] : memref<?xf32> }
      scf.yield %m6, %m7 : f32, f32
    }
    %5 = arith.addi %c7, %c2 : index
    %p0 = %false
    %p1 = %false
    %p2 = %true
    %m6 = pipeline.predicate.mask %p0 { %6 = memref.load %arg0[%5] : memref<?xf32> } // dead
    %m7 = pipeline.predicate.mask %p1 { %7 = arith.addf %3#0, %cst : f32 } // dead
    // pipeline.predicate.mask %p2 { memref.store %3#1, %arg1[%c7] : memref<?xf32> }
    memref.store %3#1, %arg1[%c7] : memref<?xf32>
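In a peeled iteration the IV is a known constant, so the surviving masks can be computed statically. This Python sketch uses hypothetical names and the same assumption that the kernel IV tracks the last stage. For each peeled iteration, it lists the stages whose mask folds to the original op. Every other masked op in that iteration folds away.

```python
LB, UB, STEP, MAX_STAGE = 0, 8, 1, 2

def stage_alive(stage, iv):
    # Stage `stage` at kernel IV `iv` handles original iteration
    # iv + (MAX_STAGE - stage) * STEP; alive if that is in [LB, UB).
    orig_iv = iv + (MAX_STAGE - stage) * STEP
    return LB <= orig_iv < UB

def fold_peeled(num_prologue, num_epilogue):
    # Peeled iterations are taken from the ends of the predicated kernel,
    # whose range is [LB - MAX_STAGE * STEP, UB).
    kernel_lb = LB - MAX_STAGE * STEP
    prologue = [kernel_lb + k * STEP for k in range(num_prologue)]
    epilogue = [UB - (num_epilogue - k) * STEP for k in range(num_epilogue)]
    # For each peeled IV, the stages whose masks fold to "true".
    return {iv: [s for s in range(MAX_STAGE + 1) if stage_alive(s, iv)]
            for iv in prologue + epilogue}
```

With one epilogue iteration peeled, this reports {7: [2]}, matching the example above where only the store survives. Peeling maxStage = 2 iterations at each end reports {-2: [0], -1: [0, 1], 6: [1, 2], 7: [2]}. That corresponds to the prologue and epilogue the current pipeliner emits.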

This requires adding two transient ops, which would only be used within the pipelining transformation.
One question I have is whether there is a precedent for transient ops, and whether adding them to the SCF dialect would make sense.

Implementation

This change wouldn’t require breaking the current interface; we could just add extra parameters to control how much of the prologue/epilogue to peel. We can then change the kernel lowering and implement a peeling helper. There are probably no peeling helpers that can be used out of the box, so we might need a specific one, but this is very little code.
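The bounds arithmetic for such a peeling helper is small. Below is a hypothetical Python sketch, not an existing MLIR helper. It assumes a positive step and a static trip count. A dynamic trip count would also need a guard for loops that run fewer iterations than are peeled.

```python
def split_iteration_space(lb, ub, step, num_prologue, num_epilogue):
    # Hypothetical helper: split scf.for-style bounds [lb, ub) into peeled
    # prologue IVs, the remaining kernel bounds, and peeled epilogue IVs.
    assert step > 0
    trip_count = max(0, -(-(ub - lb) // step))  # ceil((ub - lb) / step)
    if num_prologue + num_epilogue > trip_count:
        raise ValueError("cannot peel more iterations than the loop runs")
    prologue = [lb + k * step for k in range(num_prologue)]
    kernel_lb = lb + num_prologue * step
    kernel_ub = lb + (trip_count - num_epilogue) * step
    epilogue = [lb + (trip_count - num_epilogue + k) * step
                for k in range(num_epilogue)]
    return prologue, (kernel_lb, kernel_ub), epilogue
```

For the predicated kernel (-2 to 8), peeling one epilogue iteration gives a kernel from -2 to 7 plus a peeled iteration at 7, as in the example above. Peeling two iterations at each end gives the 0 to 6 kernel of the current output.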

I’m not sure whether there are many downstream users of the loop pipeliner. It was originally written for IREE, but I’m not sure it is still used there. It is used heavily in Triton right now. Having fine-grained control over the peeling has become quite important, but I wonder whether other users have found other missing features.

Also, I’m not sure when I’ll get to implementing this; if anybody is interested in contributing, I’m very open to collaboration here. If there is low interest in the community, I may start this downstream first.

cc: @pawel.szczerbuk @Mogball @mehdi_amini @antiagainst


I like the idea; I agree the mega-transform is hard to control. But honestly, the new IR is quite intimidating :slight_smile:

A few questions to clarify my understanding:

  1. The examples show %p0 = pipeline.predicate 0, %arg2, %c0, %c2, %c1

IIUC, it checks whether this stage is alive for this iteration. If so, why do we need three of them? Wouldn’t one op per iteration be enough?

    %p0 = pipeline.predicate 0, %arg2, %c0, %c2, %c1
    %p1 = pipeline.predicate 1, %arg2, %c0, %c2, %c1
    %p2 = pipeline.predicate 2, %arg2, %c0, %c2, %c1

Can you elaborate on its operands? What is 0 here?

  2. What does pipeline.predicate.mask return?

Does it return the predicate? What does it return when the op is masked? Poison?

  3. How do the transient ops vanish? I guess this is a separate pass.
    What happens if we have mixed ops in the loop, like:

    scf.for {
        ...
        %m6 = pipeline.predicate.mask %p0 { %6 = memref.load %arg0[%5] : memref<?xf32> }
        %7 = arith.addf %arg3, %cst : f32
        ...
    }

Thanks for the response.

Interesting :slight_smile: I wasn’t expecting it to be; let’s see if we can make sense of it.

Do you mean having one op that returns three i1 values? Yes, that would work as well.

The operands are the loop bounds; the 0 means it is for stage 0. It gets lowered to the arithmetic the pipeliner currently uses when predicating ops: llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp at commit 0edc8b59ab82c868cb76b5b7339916c21d0a35ee.
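One way to spell out that arithmetic, in the single-op form that returns one i1 per stage, is sketched below in Python. This is an illustrative assumption, not the code in LoopPipelining.cpp. It checks both bounds because the kernel in this proposal starts at lb - maxStage * step, while the current kernel starts at lb. The function name is hypothetical.

```python
def stage_predicates(iv, lb, ub, step, max_stage):
    # One boolean per stage, as a multi-result predicate op would produce.
    # Stage s at kernel IV `iv` handles original iteration iv + shift.
    preds = []
    for s in range(max_stage + 1):
        shift = (max_stage - s) * step   # loop-invariant multiply
        lower_ok = iv >= lb - shift      # like arith.cmpi sge
        upper_ok = iv < ub - shift       # like arith.cmpi slt
        preds.append(lower_ok and upper_ok)  # like arith.andi
    return tuple(preds)
```

Inside the kernel, the lower check for stage 0 and the upper check for the last stage always hold, so a lowering could drop those two comparisons.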

It returns either the yielded value or poison, depending on the mask value.

I was thinking of a second phase in the same pass; in general, the intermediate IR cannot be used for much besides peeling loops.

That should be completely fine; the transient ops should be well defined and have a correct lowering.

What do you think? Does that make sense to you?

Aren’t all ops transient? :slight_smile: I think it could make sense as part of SCF; it could also be in an scf.pipeline dialect, to avoid the more generic name. (That reminded me to go look at the CIRCT one and see what their pipeline dialect looks like.) Do these ops (3 of them, right?) make sense outside of an scf.for?

Sorry for the slow response. I guess adding a dialect for this sounds like overkill, which is why I was considering adding it to SCF, but I’m not well calibrated on whether this is the right thing to do.
It would be two ops. One of them would be quite specific to the pipeliner: it would basically map to a sequence of arithmetic ops that are kept unexpanded to make pattern matching easier.
The mask op is strictly equivalent to a predicated op with the following semantics:

    scf.if %mask {
      scf.yield %value
    } else {
      scf.yield poison
    }

Technically, this is very generic and independent of the pipeliner, but since the goal is to have an op that gets lowered after peeling the prologue/epilogue, it would be nice to make it specific. Does that make sense?
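For reference, here is a tiny Python model of those semantics. The names are hypothetical, and POISON stands in for ub.poison. The body, including any side effect or trapping behavior, only runs when the mask is true.

```python
POISON = object()  # stands in for ub.poison

def predicate_mask(mask, body):
    # Models: scf.if %mask { scf.yield <body result> } else { scf.yield poison }
    # `body` is a zero-argument callable, so it is only evaluated when mask is true.
    return body() if mask else POISON
```

Because a false mask never evaluates the body, a masked store has no effect and a masked out-of-bounds load never executes.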

Thanks for mentioning the CIRCT pipeline dialect; I’ll look at it and see if I can steal some ideas.

What do you intend to lower the predicate.mask to? From my understanding of the initial description and our offline conversation, I thought the predicate.mask ops would be folded or DCE’d after peeling. I think region simplification on scf.if would effectively do that:

    %mask = %true
    %res = scf.if %mask {
      %value = arith.addf %3#0, %cst : f32
      scf.yield %value
    } else {
      scf.yield poison
    }

→

    %mask = %true
    %res = arith.addf %3#0, %cst : f32

and vice versa for %false.
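The folding step can be modeled on a toy representation of the masked ops. This Python sketch uses hypothetical types and is not MLIR's rewrite API. It unwraps ops whose mask is constant true. It drops ops whose mask is constant false: in the peeled example their results are unused, and in general their uses would be replaced with poison. Non-constant masks are left for a later lowering.

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class MaskOp:
    mask: Union[bool, str]  # constant True/False after peeling, else an SSA name
    op: str                 # textual stand-in for the wrapped op

def fold_masks(ops):
    folded = []
    for m in ops:
        if m.mask is True:
            folded.append(m.op)   # mask folds away: keep the original op
        elif m.mask is False:
            continue              # op is dead: drop it (uses would become poison)
        else:
            folded.append(m)      # dynamic mask: left for the later lowering
    return folded
```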

There is an existing callback in the interface that lets the user decide how to generate predicated ops, since many ops may have a better representation than scf.if. It also lets the user execute ops speculatively without this code imposing an opinion. I would expect the code to use that callback to lower predicate.mask.
Some of the predicated ops would be folded after peeling, but some will remain and need to be lowered.
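Here is a Python model of that split. Masks that survive peeling are lowered through an optional user callback; upstream, the pipelining options expose such a hook as predicateFn. When no callback is given, lowering falls back to an scf.if-style guard. The speculating callback below is a hypothetical example, not existing code.

```python
POISON = object()  # stands in for ub.poison

def default_predicate(op_kind, pred, compute):
    # Fallback: guard the op like scf.if %pred { yield op } else { yield poison }.
    return compute() if pred else POISON

def speculate_pure_ops(op_kind, pred, compute):
    # Hypothetical user callback: ops assumed side-effect free and non-trapping
    # run unconditionally and the result is selected (like arith.select);
    # everything else (loads, stores) stays guarded.
    if op_kind in {"arith.addf", "arith.mulf"}:
        value = compute()
        return value if pred else POISON
    return default_predicate(op_kind, pred, compute)

def lower_mask(op_kind, pred, compute, predicate_fn=None):
    # Lower a pipeline.predicate.mask that survived peeling, delegating to the
    # user callback when one is provided.
    fn = predicate_fn if predicate_fn is not None else default_predicate
    return fn(op_kind, pred, compute)
```

The callback keeps this choice with the user: speculating an addf avoids a branch, while a load that may trap stays guarded.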

Yes, the folding would give exactly that.