Better modelling of pipeline loops

Some recent patches have landed to implement pipelining of an scf.for operation. The current setup is that the transformation uses a callback that requires the caller to supply two things:

  1. A list of lists of instructions, where each inner list is the sequence of instructions that belong to a stage.
  2. A valid global ordering of instructions (across stages) that the generated pipelined code should use.

While this is functionally enough to implement the transformation, there also seems to be a case for better modelling of pipelined computation. @ThomasRaoux and I have been iterating on a design, and I am just posting what we have so far. There are still some unanswered questions, so we are also looking for feedback from the broader community. This is not an RFC or a final design, but more of a WIP.

Pipelined loop

Pipelining is meant to increase the dependence distance between the definition of a value and its use, to allow interspersing instructions that hide the latency of long operations (like loads). For a 1D loop, the dependence between the definition and the use can be modeled by a dependence vector {i, j}, where

  • i represents the number of iterations of the loop between the def and the use
  • j represents the number of instructions between the def and the use in the lexicographic order of the loop body.

These form an ordered pair. For the untransformed loop, the i value is 0; pipelining increases it to > 0. Also, at the MLIR level, the lexicographic ordering of instructions is meaningless: the final generated code is only guaranteed to respect the ordering implied by SSA use-def chains. So the design below explicitly represents the intended dependence distance across iterations, i, after pipelining, while the j value is represented through SSA use-def chains.
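As a small illustration (a minimal sketch, not taken from the patches; %c0, %c1, %cN and %cf are assumed to be defined constants):

scf.for %idx = %c0 to %cN step %c1 {
  // The def of %v and its use in the addf are one instruction apart within
  // the same iteration, i.e. a dependence vector of {i = 0, j = 1}.
  // Pipelining aims to hoist the load into an earlier iteration so that i > 0.
  %v = memref.load %A[%idx] : memref<?xf32>
  %w = addf %v, %cf : f32
}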

The design we have so far is to use two ops

  1. scf.pipeline_for which has
    • three operands for the bounds of the loop: lower bound, upper bound and step
    • A variadic list of Values, iter_args, for loop-carried values
    • A region with a single block that is not isolated from above and only contains scf.pipeline_stage operations and an scf.yield terminator that yields the values to be forwarded to the next iteration
    • It does not represent the induction variable of the loop explicitly.
  2. scf.pipeline_stage which has
    • A variadic list of Values that are either created by other scf.pipeline_stage operations or are elements of the iter_args of the surrounding scf.pipeline_for operation
    • A list of integer attributes, with size equal to the number of operands. Each integer represents the intended dependence distance (along the loop iteration dimension) between the producer stage and the consumer stage.
    • A region with a single block that is not isolated from above and has an scf.yield terminator that yields the values for the results of the operation.
    • The first argument of the region's block is the induction variable of the loop.

As an example, here is one of the existing tests for the transformation:

scf.for %i0 = %c0 to %c4 step %c1 {
  %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
  %A1_elem = addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
  memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : memref<?xf32>
} { __test_pipelining_loop__ }

This could be represented as

scf.pipeline_for %lb to %ub step %step {
  %0 = scf.pipeline_stage {
    ^bb0(%iv : index):
      %1 = memref.load %A[%iv] : memref<?xf32>
      scf.yield %1 : f32
  } -> f32
  scf.pipeline_stage ins(%0 : f32) dependence_distance = [1] {
    ^bb0(%iv : index, %arg0 : f32):
      %2 = addf %arg0, %cf : f32
      memref.store %2, %result[%iv] : memref<?xf32>
  }
}

When lowered to scf.for, the distance between the stage producing %0 and the stage consuming it is intended to be 1 (it is 0 to begin with). One could also specify a value greater than 1 if a longer latency is required. The dependence_distance attribute captures the intended dependence distance between the def and the use of the value along the loop iteration dimension, i.e. the i value above.
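For instance, a hypothetical variant of the second stage requesting a distance of 2 (so the load would be issued two iterations ahead of its use) would only change the attribute:

scf.pipeline_stage ins(%0 : f32) dependence_distance = [2] {
  ^bb0(%iv : index, %arg0 : f32):
    %2 = addf %arg0, %cf : f32
    memref.store %2, %result[%iv] : memref<?xf32>
}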

The following example shows the representation in cases with loop carried dependence.

%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
  %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
  %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
  %A2_elem = mulf %cf, %A1_elem { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
  scf.yield %A2_elem : f32
} { __test_pipelining_loop__ }
return %r : f32

This could be represented as

%res = scf.pipeline_for %lb to %ub step %step iter_args(%arg0 = %init) {
  %0 = scf.pipeline_stage {
    ^bb0(%iv : index):
      %1 = memref.load %A[%iv] : memref<?xf32>
      scf.yield %1 : f32
  } -> f32
  %1 = scf.pipeline_stage ins(%0, %arg0 : f32, f32) dependence_distance = [1, 0] {
    ^bb0(%iv : index, %arg1 : f32, %arg2 : f32):
      %2 = addf %arg1, %arg2 : f32
      scf.yield %2 : f32
  } -> f32
  %2 = scf.pipeline_stage ins(%1 : f32) dependence_distance = [1] {
    ^bb0(%iv : index, %arg1 : f32):
      %3 = mulf %cf, %arg1 : f32
      scf.yield %3 : f32
  } -> f32
  scf.yield %2 : f32
}

Notice that the stage producing %1 uses a value from the iter_args of the surrounding scf.pipeline_for. Since that value comes from the previous iteration of the loop, the dependence distance is 1 to begin with, but the dependence_distance attribute specifies that the intended distance is 0. This implies that the dependence distance between the def and the use is to be reduced.

Note that this representation seems to easily allow changing the intended dependence distance between stages at any point until the op is lowered to scf.for.

Lowering to scf.for

One thing that we haven't fully worked out is the exact algorithm to lower scf.pipeline_for to scf.for. The op itself, though, carries all the information needed to lower it to an scf.for, i.e.

  • The DAG of dependences between stages can be used to emit the prologue, the body of the loop and the epilogue
  • For cases where a dependence distance has to be reduced, the stages can be “rotated” such that no dependence_distance on a value coming from iter_args is > 0.
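As a rough illustration only (a hedged sketch of what such a lowering might produce for the first example above, with one prologue and one epilogue since there are two stages; %c3 is an assumed constant, and the actual generated IR may differ in ordering and naming):

%l0 = memref.load %A[%c0] : memref<?xf32>        // prologue: stage 0 of iteration 0
%last = scf.for %i = %c0 to %c3 step %c1 iter_args(%a = %l0) -> (f32) {
  %sum = addf %a, %cf : f32                      // stage 1 of iteration i
  memref.store %sum, %result[%i] : memref<?xf32>
  %inext = addi %i, %c1 : index
  %l = memref.load %A[%inext] : memref<?xf32>    // stage 0 of iteration i + 1
  scf.yield %l : f32
}
%sum3 = addf %last, %cf : f32                    // epilogue: stage 1 of iteration 3
memref.store %sum3, %result[%c3] : memref<?xf32>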

Next steps

To fully evaluate this design, we lack a good idea of what use cases it might serve. If it is intended to be just an intermediate step before lowering to scf.for, then the utility of this abstraction is limited (what is currently implemented does work); at best it would help with debugging. Specifically, it would be good to know whether there are use cases that motivate transformations after lowering to scf.pipeline_for but before lowering to scf.for. If those transformations are made easier by this representation, that would motivate the way forward.

@nicolasvasilache @herhut @mehdi_amini any thoughts?


Hi Mahesh,

This is pretty cool! I think there’s a lot of space to design such a representation that would enable both software-pipelined processor loops and HLS for pipelined circuits. I haven’t quite digested everything in your proposal, but I will point out that the ‘staticlogic’ dialect in CIRCT has an experiment describing something similar:

%3 = "staticlogic.pipeline"(%1#0, %1#1, %2) ( {
^bb0(%arg3: index, %arg4: index, %arg5: index):
  %4 = addi %arg3, %arg4 : index
  %5 = addi %arg3, %4 : index
  br ^bb1
^bb1:
  %6 = addi %arg5, %4 : index
  br ^bb2
^bb2:
  %7 = addi %5, %6 : index
  "staticlogic.return"(%7) : (index) -> ()
}) : (index, index, index) -> index

This is currently intended to compose with existing loop representations, but I could easily see that an approach with an explicit pipelined loop operation could be necessary. The other big difference is that we tried to represent pipeline stages with basic blocks rather than another operation, enabling SSA values to simply describe the feedforward paths. Feedback paths are, however, more complex and would require some other (yet-to-be-designed) mechanism.

I’ve been thinking such a structure would actually be useful after lowering from an scf.for loop as a further intermediate step to low-level code generation. In particular, Xilinx has exposed-pipeline processors where a high-level representation of pipelined loops could be useful. We also have HLS technology where such a pipelined loop could be mapped directly to a circuit. The staticlogic dialect was intended to be a step towards doing this in CIRCT.

Steve

We chatted about this in the CIRCT ODM yesterday, but I just want to chime in that I am also interested in abstractions for pipelined loops, whether for HLS or other accelerator design use-cases. I’m going to play around with the WIP before I say much more, but this is definitely an area I’m interested in, and I’ll try to bring some concrete use-cases.

Thanks Steve and @mikeurbach for the comments. Nice to see that people with a different use case see value in this. Maybe with a bit of iteration, something nice can be added to core.

Interesting. Hopefully, with explicit ops, handling dependence through back-edges becomes “easier”. For example, you can rotate the stages to bring a stage from one iteration closer to the next. I am just going off of the examples Thomas has checked in for pipelining scf.for. Wondering how complex the stages could be: could the stages themselves have multiple blocks?

Agreed. Discussing with Mehdi, one of his suggestions was that this could be used as an intermediate stage to lower an scf.for to. I don’t know much about CIRCT, but what you describe indeed seems like a nice use case that could drive this.

Nice! I don’t attend the CIRCT ODMs, but if this is discussed at a later date, I would love to attend (and probably so would Thomas). Looking forward to the use cases you have in mind.

Sounds great. Maybe once we’ve had a bit of time to play around with it, we can pick a good time to discuss further.

For scf.pipeline_stage, do you really need an explicit ins list? You could instead just access values based on dominance. Second, do you even need a new region-holding op (scf.pipeline_stage) at all? You could just use scf.execute_region to encapsulate each of the stages; its semantics match what you want. You would just need to add optional attributes to it to carry any metadata needed.

It’s a bit weird to read this at the end of the proposal: it looks like you are designing this without any transformation use cases. If it’s just for better debugging/viewing, it only provides improved readability over attaching debug attributes, and I’m not sure adding new ops for better printed output is worthwhile.

Yes it would be great to discuss this together.

To echo what Mahesh already mentioned, the current implementation is now committed and should support a large set of scenarios. One open question was whether the transformation should try to order operations across stages. The current design could restrict the input to operation stages and always order operations from the highest stage to the lowest one (this would be needed to handle operation dependencies with a distance of 1), but I wanted to keep things flexible at this stage so the user can pick an order. It is trivial to change, or to make the decision automatic, if we want.

I agree this is one of the key points, and that’s the main reason why I went with a more pragmatic design to start with. The main advantage of an scf.pipeline representation is that it saves us from generating the prologue/epilogue early and therefore doesn’t duplicate operations early. It is not clear at this point whether this helps significantly or not.

My next step is to start using this transformation to optimize some GPU kernels (for example GEMM) by pipelining copies to shared memory and computations.

I’m looking forward to hearing more about the different use cases for pipelining and iterating on the design.

I agree; hence this is not an RFC. There is not sufficient justification so far to go down this path, but I wanted to reach out to the community to see if this is a worthwhile thing to proceed on. The reason I started thinking about this was that I was trying to see if there is a better way to stage the existing pipelining transformation, which needs

  1. A description of which operation belongs to which stage
  2. A description of global “valid” ordering of the operations.

The last one in particular might be hard to maintain from a caller’s perspective. The only place where this global ordering can be reliably specified/maintained is while lowering to loops. So I tried to work out whether a different approach might make this easier. I think the above does that, but I wasn’t sure that this alone justifies implementing the op. So I am looking for more justification.

Note that the requirement for a global “valid” ordering is orthogonal to the representation. We could change the current solution to pass only stages and have the pipeliner pick a safe order by scheduling the stages in descending order. This is just a tradeoff between flexibility and usability.

The higher-level representation doesn’t include the ordering of ops across stages and will force picking an order later, but it doesn’t really carry more information than a mapping from op to stage. So the decision to go with the higher-level representation should be independent of whether we want the user to pass a global order.

+1. We have discussed this a bunch in the past, and it would be really nice (but unclear how) to have a representation that both allows mandating pipelining and composes with other transformations.

In classical polyhedral representations, one would just advance/delay the schedule of ops (represented as a list of ops) and this would compose with other analyses and transformations expressed as affine schedules. The actual materialization as imperative code is delayed until much later (and is an explicit non-goal of MLIR).

Intuitively, there seems to be a conflict between imperative form / SSA and delaying the application of transformations. I am not sure how to reconcile those in the case of pipelining. This is a shame: I’d love to have a pipelining mechanism that can be exposed all the way to the user and that works on tensors too.

This seems to hint back at a separate scheduling language but I don’t think we’re close to that yet.

Nice! I’ve been reaching for something like this at the higher level too - I’d love to use it for pipelining tensor-level operations (linalg on tensors/etc).

What confused me (coming from HLS pipelining/modulo scheduling) at first were the non-zero dependence distances between operations from the same iteration in the source loop (e.g. load – add – mul), but then I realized that the distances actually refer to the iterations of the target (=pipelined) loop. So the proposed representation encodes the pipeline’s steady state, by making the overlap of operations/stages from different iterations of the source loop explicit. Is that correct?

How would you express initiation intervals > 1?

Interesting! It would be great to get your feedback if you are able to do something with the current solution at some point (pattern: llvm-project/Transforms.h at 417e500668621e1275851ccf6e573a39482368b5 · llvm/llvm-project · GitHub).
This will create N-1 prologues and epilogues, with N being the number of stages.

I’ve been playing around with the pipelining transform a bit, and here are some of the thoughts I’ve had…

First off, the code documentation mentions

Software pipelining is usually done in two part. The first part of
pipelining is to schedule the loop […].
The second part is to take the schedule and generate the pipelined loop as
well as the prologue and epilogue. It is independent of the target.
This pattern only implement the second part.

I think this work fits very nicely with @jopperm’s scheduling work in CIRCT, at least in terms of separation of concerns. The current pipelining transform requires the user to define a schedule, and I’ve done a little experiment to show that the CIRCT tools can be used to produce such a schedule. This is pretty neat, so hopefully these efforts can continue to mesh.

That’s my understanding. The ascii-art in the code documentation has an example: https://github.com/llvm/llvm-project/blob/edaffebcb2a62b0195e23fe7d4ead005822865c3/mlir/include/mlir/Dialect/SCF/Transforms.h#L125-L134. After testing the transformation myself, I see something similar. This representation makes a lot of sense for software pipelining during codegen, but if we are talking about HLS, I’m not sure it is the best way to represent the to-be-synthesized hardware pipeline… which goes back to the following discussion:

I’m on the same page here, and I think it would be great to converge on a representation that is useful both for the software pipelining transform and the HLS folks who want to synthesize a hardware pipeline. In terms of use-cases, I’m looking at lowering from a representation like scf.pipeline_for (or the staticlogic.pipeline @stephenneuendorffer showed) into CIRCT dialects for describing hardware. So my use-case at the moment is mainly lowering, rather than transformation. But I can definitely see how having an explicit pipeline representation would make it easier to apply some interesting analyses/optimizations before lowering.

I’m also intrigued by the comments from @nicolasvasilache and @benvanik mentioning the tensor level. This goes back to an earlier discussion, but I’m still very interested in HLS flows starting at the tensor level, so I’ll just say that I’m interested if we can push this pipelining effort higher up the stack as well.

We are planning to have a discussion about SCF and the lowering we are interested in doing for HLS in the CIRCT ODM next week, 8/25. Anyone who is interested is welcome to join, and the meeting will be recorded as well. Link to join is on the meeting document.


Awesome! Thanks for the note.

Hello!

I have seen elsewhere that the discussion went on in CIRCT to build a “hardware” pipelined loop. What about the “software” version?

I can roughly understand how the current transformation works, but I am not sure how to use it. Am I supposed to manually encode the schedule in the attributes? Is there a command-line option to trigger the transformation?

The transformation doesn’t require encoding attributes; the unit tests use attributes as a simple way to manually pick the schedule and test specific cases, but in general this is not how the transformation should be used.

In order to use the transformation, it needs to be hooked up to a scheduler or a simple heuristic that decides in which stage each operation of the loop should go and also picks a relative order for operations. This is highly dependent on what you are trying to do, and the transformation doesn’t have an opinion on that.

For instance, here is an example in the llvm sandbox of how it can be used with simplified scheduling and latency modeling. This just associates a latency with the read operation and has a pseudo-scheduler decide on the operation stages based on that.

Let me know if this is unclear. If you have more details on what you are trying to achieve I can try to give more details on the potential way to use this.

Thank you for the hint! My idea was to use an external scheduling tool, and transform its output into the required stages and order of operations. So I can probably build something simpler than the code you linked.

Just out of curiosity, why does the pass select a loop starting from an anchor op? What if there is more than one in the code?

That makes sense, it sounds like the right way to do it.

This is how all the codegen strategy transformations work; it is just a way to control transformations from the command line. It is nothing specific to pipelining.
