[RFC] Elementwise ops in TCP

Here is our proposed design for elementwise ops in TCP: [Public] TCP Design - Elementwise

Please feel free to add your feedback in the document or in this thread.

This only describes the representation of elementwise ops in TCP. We will start a separate RFC for broadcasting in TCP.

Thanks,
Raghavan
(on behalf of the ML Compiler Team at Cruise)


Thanks for putting out a clear proposal.

I personally prefer option #1b, and I think we can make it work, but two overall community arguments on previous threads need to be taken into account:

  1. As mentioned in the document, the non-trivial costs of handling regions cannot be ignored. I’m not sure if this can (or should) be fixed, but right now, it’s a constraint.
  2. There is a sentiment to move tensor support out of arith and math into something else, and TCP seems like a good fit for both.

The second point is, to me, weaker, because we can also do that with option #1b. Exactly like linalg, we wrap element-wise ops in another op. So the main argument against #1 is region cost.
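For reference, this is roughly the kind of wrapping that linalg already does for a single element-wise add (a sketch; `%a`, `%b`, and `%init` are assumed to be defined elsewhere):

```mlir
#map = affine_map<(d0) -> (d0)>
// A single scalar addf, wrapped in a region with its own block and terminator.
%add = linalg.generic
         {indexing_maps = [#map, #map, #map],
          iterator_types = ["parallel"]}
         ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
         outs(%init : tensor<8xf32>) {
       ^bb0(%x: f32, %y: f32, %out: f32):
         %s = arith.addf %x, %y : f32
         linalg.yield %s : f32
       } -> tensor<8xf32>
```

Option #1b would use the same wrapping idea at the TCP level.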

It seems this was explained in some meeting, but it would be nice to have it written somewhere (here would work), so that we can make informed decisions, including perhaps trying to fix the region problem.

I’m not sure if the costs are related to hierarchy (region of blocks of ops), but perhaps if we made the region similar to a graph region, the costs could be reduced enough?

In the end, if we’re fusing element-wise ops in a single tcp.elementwise, it’s because we think we can fuse them during code generation, so there shouldn’t be any complex control flow in there anyway, just use-def dependencies at most.

But without knowing what the cost @mehdi_amini was referring to, I can’t formulate a strong enough opinion to disagree with your recommendation.

One additional drawback w.r.t. #1b is the region overhead before we get to the fused form.

Lowering from frontend would initially create tcp.elementwise ops with only one payload in them (assuming fusion during lowering from frontend is not acceptable). So, the region overhead will be significant even in this case, until we perform elementwise fusion.
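To make the single-payload case concrete, a freshly lowered add might look roughly like this (the `tcp.elementwise` syntax here is hypothetical, sketched from the proposal, not a finalized op definition):

```mlir
// One region, one block, and one terminator -- all wrapping a single scalar op.
%0 = tcp.elementwise ins(%a, %b : tensor<?xf32>, tensor<?xf32>) {
^bb0(%x: f32, %y: f32):
  %s = arith.addf %x, %y : f32
  tcp.yield %s : f32
} -> tensor<?xf32>
```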

This was not included in the design doc earlier, so, just wanted to highlight that here.

Do you have an estimate of the overhead here? Today, through the Linalg path, we are already lowering these ops first to an op with a region and a single statement, and then use elementwise fusion. So we are already paying this cost in IREE. We haven't measured the cost or tried to reduce it, but it would be useful to know if you have more data on this.

I do not have an estimate of this overhead. @mehdi_amini mentioned this overhead during the ODM when we presented the initial TCP spec.

The overhead is at least two-fold to me:

  • every operation, instead of standing alone, now defines an added region, an added block, an added wrapper operation, and an added block terminator (plus the extra SSA edges to wire things together).
  • everything that manipulates a tree of expressions (pre-fusion) has to walk this nesting: you can no longer match and rewrite something like mul(add(x), y) → fma(x, y) as easily as before. An argument was made that we can just rewrite things post-fusion, but that does not account for the complexity of expression trees where intermediate results have users. This also means that fusion is on the critical path for anything related to usual arithmetic expression matching.

So the overhead is quite significant compared to regular operations: we may or may not want to pay it, but I’m raising it so we can make a conscious decision on this and nothing is overlooked.
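The matching concern can be sketched side by side (the first-class `tcp.add`/`tcp.mul` op names and the `tcp.elementwise` syntax are hypothetical): in the flat form, the mul(add(x, y), z) → fma pattern is a plain use-def match; in the region form, the same rewrite has to look through region boundaries.

```mlir
// Flat form: a peephole pattern can match the use-def chain directly.
%sum = tcp.add %x, %y : tensor<8xf32>
%res = tcp.mul %sum, %z : tensor<8xf32>

// Region form: the same expression tree is split across two wrapper ops,
// each with its own region, block, and terminator.
%sum = tcp.elementwise ins(%x, %y : tensor<8xf32>, tensor<8xf32>) {
^bb0(%a: f32, %b: f32):
  %0 = arith.addf %a, %b : f32
  tcp.yield %0 : f32
} -> tensor<8xf32>
%res = tcp.elementwise ins(%sum, %z : tensor<8xf32>, tensor<8xf32>) {
^bb0(%a: f32, %b: f32):
  %0 = arith.mulf %a, %b : f32
  tcp.yield %0 : f32
} -> tensor<8xf32>
```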

Also, I think it is perfectly fine to generate tcp.elementwise with a single scalar operation in the body as part of codegen. I would see this in the same vein as “codegen prepare” in LLVM backends: put the IR in a form that can be consumed by the backend/codegen.
We can still independently decide that the optimizer’s main canonical form pre-fusion uses first-class element-wise operations, though.


Please let us know if there are any objections to proceeding with our proposed design for elementwise ops in TCP.


I think I prefer 1b. I’ve worked with a bunch of these forms, and it is ultimately just a design choice, and I struggle to say which is better in the abstract. For this level, it is very convenient to be able to directly emit “simple fusions” like this from various frontend contexts, especially when you get into mixed precision and various sign/type representations.

I expect you will quickly need passes to both fuse and unfuse these simple elementwise fusions as different parts of the pipeline reason at different granularities.

As for overhead, at this coarse level, I believe we are talking about on the order of hundreds or low thousands of these ops for even “large” tensor programs. This is still orders of magnitude smaller than fully unrolled low-level instructions. I don’t think that the cost at this scale will be a very significant factor.


Can you elaborate with some examples? I don’t quite get when it matters. Or is this something that appears when looking at programs more holistically: when things that aren’t element-wise start to mix in alongside, and thus you need different compositional properties?

You’re only addressing one part of the “cost” here (memory consumption) and not the constant “infra” cost that comes with manipulating these (any matching or mutation across these regions is much harder).

There could also be an engineering argument that if we have to be able to work with these fusions and mutate them, then it is better to handle this uniformly across the stages of the pipeline rather than having the pre/post-fusion stages operate with different abstractions.

Given that tcp.elementwise will have explicit inputs, isn’t it going to involve more steps to fuse / unfuse ops in that design?

I expect it will. And it isn’t so much a design as an observation that we always go here eventually – so might as well plan for it.

Maybe I should clarify: I am quite weakly supportive of 1b, and the reasons that would tip me to the alternative of a discrete list of in-dialect elementwise ops are precisely the “collateral damage” in terms of these being harder to manipulate, match, etc. That cost is hard to quantify but quite real.

That line of reasoning is what is causing me to weakly bias towards 1b. What makes me uncomfortable is whether this very limited form of fusion is worth having (i.e. it doesn’t support projection, etc) – the actual cases where it makes sense vs a more general solution may be vanishingly small in practice (or worse: disjoint enough that anything that reasons about them needs to be mindful of both these simple fusions and more general groupings).

I’m not sure it does matter in the end. I don’t have a concrete example handy, but the case I’m thinking about involved frontend expansions of various RNG algorithms. When expressed at the tensor level, this results in a lot of redundant information (i.e. fully qualified tensor types, etc) for something that is fundamentally a single scalar “subroutine”. In this kind of situation, I found myself wanting to collapse this entire “subroutine” to a scalar block that I could hoist into its own function, CSE, etc. When you grind that sort of thing through a normal tensor-codegen pipeline, it all works out in the end but ends up generating a lot of redundant code that either bloats the end result or just requires a lot of work on the compiler side to clean up.

I’ve seen this pattern enough to have wished for a way to opt such elementwise-scalar map functions out of the tensor-codegen pipeline early and just avoid the brain damage of making a tensor-codegen pipeline do things it wasn’t designed for. Having the elementwise op support an arbitrary scalar mapping expression would make that case compose easily. Not sure if it is important enough to solve in a bespoke way, though.


Thanks for the clarifications @stellaraccident.

To summarize, the following seem to be the tradeoffs now:

  • Go with option 1b and pay the region cost as well as the cost of manipulating and matching ops in that form.
  • Go with option 2 and pay the cost of broadcasting scalars and relying on pattern matching to get rid of them.
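Under option 2, the same element-wise add from before would be a first-class op, with scalar operands broadcast explicitly (again a hypothetical sketch; the `tcp.broadcast_to` and `tcp.add` op names are placeholders, not finalized spellings):

```mlir
// First-class element-wise op: a scalar constant must be materialized and
// broadcast to the operand shape before use, then cleaned up by patterns.
%c  = arith.constant dense<2.0> : tensor<f32>
%cb = tcp.broadcast_to %c : tensor<f32> -> tensor<8xf32>
%0  = tcp.add %a, %cb : tensor<8xf32>
```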

Please correct me if I’m missing something here.

Since the tradeoff costs are difficult to quantify at this point, I assume we will have to choose one of them to move forward and evaluate later.

How about we go with 2) for now? Once we have the design for tcp.group (which should include a way to represent elementwise fusion as well), if we still need 1b), we can revisit this at that point.

On a related note, we are starting to work on the design of tcp.group. We will start an RFC for that once it is ready.

I still think that (1b) is interesting as it is a fusion+rank reduction to scalar. But (2) is a more traditional representation for this level and more aligned with the infra/expected use. Biasing towards (2) introduces fewer concepts, and I think that is better overall for a dialect like this.

(I do like 1b but suspect that 2 might be the right answer)


+1 I think (2) is also what I expected at the TCP level.
