[RFC] Broadcasting in TCP

Here is our proposed design for broadcasting in TCP: [Public] TCP Design - Broadcast

Please feel free to add your feedback in the document or in this thread. We plan to discuss this in the ODM this week on 10/13.

Thanks,
Raghavan
(on behalf of the ML Compiler Team at Cruise)

P.S.: We shared a draft version of this document with some folks in the community to get early feedback.

Thanks for the discussion on broadcasting during the ODM today. It was extremely helpful to understand different perspectives and the tradeoffs involved. We are glad that we were able to figure out a path forward on this topic.

Here is a summary of the designs we discussed (slides).

Proposed design for tcp.broadcast

  • Performs dim-1 broadcasting
  • No rank broadcasting
    • Changes to the rank have to be done by reshape ops prior to broadcasting
  • Constraints
    • Broadcast axes must be statically known
    • Broadcast axes in the input tensor must have a statically known size = 1
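The proposed semantics can be sketched in NumPy terms (illustrative only; `tcp_broadcast` is a hypothetical helper standing in for the op, not its actual definition): the op never changes rank, so any rank change is a separate reshape beforehand.

```python
import numpy as np

def tcp_broadcast(x, axes, sizes):
    """Hypothetical stand-in for tcp.broadcast: expand statically known
    size-1 dims of `x` at `axes` to `sizes`; rank is unchanged."""
    out_shape = list(x.shape)
    for axis, size in zip(axes, sizes):
        # constraint: broadcast axes must have a statically known size of 1
        assert x.shape[axis] == 1, "broadcast axis must have static size 1"
        out_shape[axis] = size
    return np.broadcast_to(x, out_shape)

b = np.arange(4.0)                          # shape (4,)
b2 = b.reshape(1, 4)                        # rank change via explicit reshape
c = tcp_broadcast(b2, axes=[0], sizes=[3])  # shape (3, 4)
```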

Alternative design - Option 1

  • tcp.broadcast performs both rank and dim-1 broadcasting
    • Only support numpy-style rank broadcasting in tcp.broadcast
      • i.e., rank broadcasting only on the leading dims
  • Rank broadcasting on non-leading dims requires a reshape

Alternative design - Option 2

  • tcp.broadcast performs both rank and dim-1 broadcasting
    • Support arbitrary rank broadcasting in tcp.broadcast
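For reference, NumPy's own rank broadcasting behaves like Option 1's leading-dims rule: the lower-rank operand is implicitly padded with size-1 dims on the left, and anything else needs an explicit reshape. A sketch (using NumPy purely as the reference semantics):

```python
import numpy as np

# Option 1 in NumPy terms: rank broadcasting only on leading dims, i.e.
# the lower-rank operand is padded with size-1 dims on the left.
a = np.ones((2, 3, 4))
b = np.arange(4.0)   # rank 1; NumPy treats it as shape (1, 1, 4)
r = a + b            # result shape (2, 3, 4)

# Broadcasting over a *non-leading* dim still needs an explicit reshape
# to line the operand up against dim 1 instead of the trailing dim:
c = np.arange(3.0).reshape(3, 1)
r2 = a + c           # result shape (2, 3, 4)
```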

Next Steps

The following approach has been agreed upon as the way forward:

  • Start with implementing the proposed design.
  • Evaluate during implementation if this design is too verbose, thereby hampering transformations.
  • Pivot to one of the alternative designs based on the learnings, if necessary.

Please let us know if you have any questions / suggestions.


Recording for the discussion: Open MLIR Meeting 10-13-2022: Discussion about broadcasting in TCP - YouTube

@sanjoyd @raghavanr What would a typical softmax-like pattern look like when there is elementwise(broadcast(reduce()))?

If reduce removes the reduced dimensions to get the result type, then one has to expand shape before using tcp.broadcast, right?

In that case, tiling-and-fusion has to fuse an additional reshape. Fusion of reshapes is doable here, but quite hard in the general case. Do you plan to tile and fuse at the TCP level, or later?

I always thought that broadcast should be, in some sense, the “inverse” of the reduction op, i.e., it adds dimensions instead of removing them.
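The two variants of the pattern under discussion can be sketched concretely with NumPy as a stand-in for TCP ops (illustrative only; the helper names are mine, not TCP's):

```python
import numpy as np

def softmax_keepdims(x):
    # reduce keeps the reduced dim as size 1, so the result broadcasts
    # back over the last dim without any reshape in between
    m = x.max(axis=-1, keepdims=True)          # shape (B, 1)
    e = np.exp(x - m)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_dropdims(x):
    # reduce drops the dim, so an expand_shape (here: [:, None]) is
    # needed before broadcasting -- the extra reshape in question
    m = x.max(axis=-1)                         # shape (B,)
    e = np.exp(x - m[:, None])
    return e / e.sum(axis=-1)[:, None]

x = np.array([[1.0, 2.0, 3.0]])
```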

We do plan to fuse at the TCP level, but it won’t be like linalg fusion; it will be more like a grouping that indicates that the backend (e.g. lowering to PTX via linalg or to cuDNN APIs) will figure out how to fuse the indicated region.

Not sure about tiling, I’ll let @raghavanr speak to that once he is back.

I don’t think we’ll need general reshapes, just an expand_shape that adds some degenerate dims. My understanding was that fusing these should not be a problem, but LMK if I’m missing something here.

If that symmetry is important we could define reduce as keeping the reduce dimensions, like TensorFlow and PyTorch do when keepdim/keepdims is true.
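In NumPy terms, keepdims=True gives exactly that symmetry: the reduction preserves rank, so a plain dim-1 broadcast restores the original shape with no reshape in between (illustrative sketch):

```python
import numpy as np

x = np.ones((3, 4))
s = x.sum(axis=1, keepdims=True)      # shape (3, 1), rank unchanged
back = np.broadcast_to(s, x.shape)    # shape (3, 4), no reshape needed
```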

I often refer to this as the fusion planning stage, as it does not really perform any fusion in the IR-mangling sense but derives a plan for what fusions should happen and how. I am open to a better name, but thinking about it this way helps disentangle it from the “linalg-level” fusion.

I take this as another indicator that tcp aims to be an mhlo-level dialect, given the recent discussion about whether a linalg.broadcast would conflict with a tcp broadcast.


Considering reduction as an “inverse” of broadcast seems reasonable. Given that tcp.broadcast does not change the rank, we could define tcp.reduce in the same way and retain the reduction dims. That is definitely one of the design options we want to consider.

Can you describe why fusion of reshapes is hard? This can be a good reason to keep reductions similar to broadcast in TCP.

TCP would not lower operators to loops, so tiling should happen at a lower level. Fusion at the operator level (more like grouping, as @sanjoyd pointed out) is what we foresee TCP supporting.

In this particular case, when we just add/remove size-1 dimensions, it is not hard. In the general case, reshapes are quite ugly.

Assume you have a reshape from 8x2 → 16, and you are interested in fusing the reshape into some extract_slice op that extracts 4 elements of the result. In order to fuse, you have to find all elements that correspond to the 4-element tile of the result. Unfortunately, that set cannot be expressed as a single tile of the reshape's input.
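This can be checked concretely: take the 4-element tile starting at offset 3 of the reshaped result (offset 3 is my choice for illustration) and map it back to 8x2 coordinates.

```python
import numpy as np

a = np.arange(16).reshape(8, 2)   # the 8x2 input
flat = a.reshape(16)              # the reshape 8x2 -> 16
tile = flat[3:7]                  # a 4-element extract_slice of the result

# Source coordinates of those 4 elements in the 8x2 input (row-major):
coords = [divmod(i, 2) for i in range(3, 7)]
# The elements span partial rows, so the smallest covering rectangle
# (rows 1..3 x cols 0..1) holds 6 elements, not 4: no single input tile
# maps exactly onto this result tile.
```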

I want to point out that it is very awkward to handle such “fake” 1-extent dims. It makes it very hard to disambiguate between dimensions that are truly of size 1, like say a single batch dimension, and fake unit dimensions introduced by how broadcasts are specified (and, as seen in this conversation, by reductions leaving a unit dimension around). It is better to have dimensions represent actual program intent than to add/remove unit dimensions throughout the compiler stack.

I pointed out earlier with the broadcast design that having unit dimensions is a NumPy legacy and not really needed. Mirroring that in the reduction op specification, so that reductions do not drop dimensions, is another red flag. Let's go back to plain good old C. This is how you would write a dot product:

float sum = 0.0f;
float a[n], b[n];
for (int i = 0; i < n; ++i) {
  sum += a[i] * b[i];
}

One would not write:

float sum[1] = {0.0f};
float a[n], b[n];
for (int i = 0; i < n; ++i) {
  sum[0] += a[i] * b[i];
}

Same goes for broadcast. This is what you would prefer to write:

float a[m][n], b[n], c[m][n];
for (int i = 0; i < m; ++i) {
  for (int j = 0; j < n; ++j) {
    c[i][j] = a[i][j] + b[j];
  }
}

You would not write:

float a[m][n], b[1][n], c[m][n];
for (int i = 0; i < m; ++i) {
  for (int j = 0; j < n; ++j) {
    c[i][j] = a[i][j] + b[0][j];
  }
}

The fact that we do this in ML (and other domains that were initially developed on top of NumPy) is really a legacy that we don't need to carry forward. With broadcasts, I can buy the idea that “this is what existing front ends do”. Adding fake unit dimensions to reductions, though, is a red flag for me (because front ends don't do this).

W.r.t. handling reshapes, that's a red herring IMO. The ideal way to handle reshapes is to propagate them to the boundaries, at which point they are just a metadata change.
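NumPy itself illustrates the metadata-change point: on a contiguous buffer, a reshape returns a view sharing the same storage, with no data movement (illustrative sketch).

```python
import numpy as np

a = np.arange(16.0).reshape(8, 2)
b = a.reshape(16)   # on a contiguous array this is a view, not a copy:
                    # only the shape/stride metadata differs
```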


I agree with that characterization from the PoV of codegen; this is why linalg.generic, linalg.broadcast and linalg.reduce encode such semantics explicitly. Everything that wants to tile should be designed this way (or better, if possible).

We still need something to abstract away the “fake” 1-extent dims until we get to codegen. As long as that abstraction does not aspire to be the thing on which we tile, I'd think we are fine?

I imagine at that level we also want to inject the right case disjunctions, shape asserts and all the things that are in @_sean_silva’s / Torch-MLIR area?
