[RFC] Initial draft of TCP Spec

Here is the initial draft of TCP spec: [Public] TCP Spec. Please feel free to add your feedback in the document or in this thread. We plan to discuss this in the ODM tomorrow (9/1).

Note that this draft contains specifications for only a handful of ops at this point. This is intentional since we just wanted to have enough to foster discussion and nail down some of the design decisions.

We have intentionally not included complex ops like gather / scatter, which probably warrant a separate discussion. For the same reason, this draft does not include support for quantization and sparsity either. Having said that, if you think any of the design decisions or op specs here would affect the features / ops not yet included, please do bring them up. It would be great to understand them and figure out an appropriate way forward keeping those in mind.

FYI, here is the TCP proposal document, which includes detailed discussions by several members of the community on these design points.

Thanks,
Raghavan


There was a discussion in today’s ODM regarding the representation for elementwise ops in TCP. In that regard, it was mentioned in the Zoom chat that “arith on tensors is semi-deprecated”.

Has that decision already been made? Is there a plan in place to deprecate / remove it? What about math ops on tensors?

If arith/math on tensors is deprecated, and we create the same ops in TCP, isn’t that the same as just reviving the original ops? If so, I’d rather we revive the old ops and keep them on tensors there than have the same ops in two different dialects just because they operate on different types.

Or rather as a question: why were they deprecated and how can TCP help fix whatever was broken?

It was a decision from the early days that I suspect we would not make again: the tensor type was new and fresh and there was some fascination with making it compose everywhere. In later years, we went the other way (having dedicated dialects that are specialized for these type islands), and this is one of the last holdouts of the original philosophy.

I don’t think there is a concrete plan to deprecate/remove the arith/math support for elementwise mappings. If TCP went in-tree with these ops defined over the tensor domain, that would be a strong reason to go ahead and just make arith/math scalar only. This would need to be discussed, but since it is hard to remove expressive power that already exists, I suspect the decision becomes a lot easier if tcp provides its own primitives.

As for the “why”, no one was willing to define the more advanced tensor semantics on the ops in the arith/math dialect, as it was seen that more derived dialects were the right place to take concrete opinions on broadcasting, dynamism, error cases, etc. In fact, TCP does take opinions on that in a way that would never make sense for the fundamentally scalar arith/math dialects (and other dialects may take different opinions, such as by allowing implicit broadcasting, rank dynamism, etc).

Thanks for the clarification @stellaraccident

If these ops are not being considered for deprecation, one option to represent elementwise ops in TCP is to just use the arith / math ops directly, instead of putting them as payloads in tcp.unary / tcp.binary. This would be in line with the design goal of reusing existing ops from other dialects in TCP.
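
To make that option concrete, here is a minimal sketch (function name and shapes are purely illustrative) of an elementwise add expressed with arith directly on the tensor domain, rather than as a scalar payload inside a tcp.unary / tcp.binary wrapper:

```mlir
// Sketch only: arith ops are elementwise-mappable, so they currently
// accept tensor operands directly.
func.func @eltwise_add(%a: tensor<?xf32>, %b: tensor<?xf32>) -> tensor<?xf32> {
  %0 = arith.addf %a, %b : tensor<?xf32>
  return %0 : tensor<?xf32>
}
```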

But since you mentioned that the ideal scenario would be to have arith/math operate only on scalars, I see a couple of different options here:

  1. Implement the same ops in TCP to operate on tensors (this was one of the options discussed in the ODM). As you mentioned, this might make it easier to deprecate tensor support in arith/math.
  2. Implement the same ops in the Tensor dialect to operate on tensors. TCP could then just reuse these ops directly (without defining any unary / binary ops). The Tensor dialect would then need to have concrete opinions on broadcasting, dynamism, etc., which I’m not sure would be ideal.

This probably warrants a bigger community discussion, outside the scope of TCP.

Agree.

Right, this was part of the discussion on the call. I’m not clear on all the kinks, but I understand they’re not trivial. If we take a pragmatic approach of making everything explicit, then the semantics of element-wise arith/math become easier to handle. But then we have *cast ops to worry about on the arith/tensor dialects…

This was from the era of the “standard” dialect. It was a lot easier to pork-barrel things. A not unreasonable rule of thumb, from my perspective: arith ops should have a trivial conversion to an equivalent LLVM arithmetic op. There is more license in math, but generally it should lower to a library call on scalars or to IR that generates an approximation.

(this is not decided – just where I think it should go)
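
A minimal sketch of that rule of thumb (operand names are illustrative): a scalar arith op maps one-to-one onto an LLVM dialect op.

```mlir
// Sketch only: scalar arith has a trivial, one-to-one lowering.
func.func @scalar_add(%x: f32, %y: f32) -> f32 {
  // Under -convert-arith-to-llvm this becomes llvm.fadd.
  %sum = arith.addf %x, %y : f32
  return %sum : f32
}
```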

Adding my two cents here as I lobbied for tensor support on arith to disappear.

+1 to what @stellaraccident said.

Where I would prefer to draw the line is that arith should only handle data types that are values, even in later stages of lowering. So bufferization should never need to touch them. This implies a somewhat trivial lowering to LLVM intrinsics.

As an aside, we also still have arith.select supporting memref, but that is a different story. If someone wants to clean that up, I would be all excited, but it would create a fair amount of churn purely for aesthetics.

Generally, I prefer dialects to be more single-purpose and have a somewhat clear home in the lowering stages of compilation. Partially transforming away arith, i.e. the tensor pieces of it, when going to tiling and loops, violates this nice mental model. That is a philosophical statement more than anything and I won’t argue that others should adopt this view, I merely provide it to explain where I am coming from.

Regarding the option to place these operations into the tensor dialect: I see the tensor dialect as a place where operations are grouped that operate on the tensor data structure (insert, extract, construct, etc.).

Bikeshedding a bit, but shouldn’t arith be called scalar then, to remove any notion that it’s a generic arithmetic dialect?

On the math side, do you see the same problem? I.e., should we duplicate both arith and math scalar operations in tcp as tensor operations?

Friday is perfect for bikeshedding :smiley:

scalar is descriptive of the data type it works on (but misses out on vector), yet does not describe the kind of operations. If we wanted to capture both, it should likely be called something like scalar_and_vector_arith or immutable_value_arith. That is not ergonomic.

On a more serious note, dialect naming is hard and I agree that one often cannot derive from the name what the dialect really does.

Not sure whether this question is addressed to me but I’ll give my view anyway.

I am not too worried about duplicating operations (but I am also biased coming from mhlo). The operation definition is not the expensive part, especially for operations that are fairly regular and where verification etc. all come from traits. Where real code duplication happens is in optimizations (e.g. algebraic simplification) and lowering to other dialects. I am not too concerned about these here, as lowering for tensor and scalar code is likely disjoint anyway.

Yeah, I always associate scalar with ops (vs vector ops), but I get your point. I had enough bikeshedding for this Friday. :slight_smile:

My concern is that other high-level dialects already have those ops (e.g. mhlo) and they might not be the same across dialects, for example in broadcast semantics (explicit vs implicit). And also that we’ll lower those to linalg.generic with arith ops inside, so we’d likely need a broadcast/reshape during lowering if the broadcast is implicit. Or is that a separate broadcast + linalg.generic(op)?

Of course, all of that can be encoded, but the more things we have at boundaries, the harder it’ll be for developers (fresh and not) to remember all the details without looking at the code (MLIR’s). Principle of least surprise should apply (even if different people get surprised by different things…).

Ultimately, you want to end up with a broadcast operation that has been turned purely into index arithmetic. There are a lot of ways to get there, and depending on the details (whether you want to end up with vectors, where a broadcast might be a gather-like operation, or scalars, where it is just a load with some index math), different ways are more convenient.

For the sake of lowering, what we have chosen in XLA, is to first go from implicit broadcasting (which we have in chlo) to explicit broadcasting. We currently then lower the non-broadcasting tensor math to a linalg.generic and the broadcast, depending on its kind, to either a linalg.generic or some specialized broadcast that can also handle one-expansion.

With this design, the lowering for element-wise math is quite regular.

We then use our fusion mechanics to get the final form, either by using linalg’s own fusion, which turns the broadcasts into affine maps on the inputs, or via tiling to vector/scalar level.
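
A rough sketch of the fused form this converges to (shapes, names, and exact syntax are illustrative and vary across MLIR versions): the broadcast ends up as nothing more than an indexing map on one operand of the linalg.generic.

```mlir
// Sketch only: elementwise add where %b is broadcast along the leading
// dimension; after fusion the broadcast is just the #bcast indexing map.
#id    = affine_map<(d0, d1) -> (d0, d1)>
#bcast = affine_map<(d0, d1) -> (d1)>
func.func @add_bcast(%a: tensor<4x8xf32>, %b: tensor<8xf32>,
                     %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.generic
         {indexing_maps = [#id, #bcast, #id],
          iterator_types = ["parallel", "parallel"]}
         ins(%a, %b : tensor<4x8xf32>, tensor<8xf32>)
         outs(%init : tensor<4x8xf32>) {
  ^bb0(%x: f32, %y: f32, %out: f32):
    %s = arith.addf %x, %y : f32
    linalg.yield %s : f32
  } -> tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}
```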

I see tcp at the level of mhlo if I understand the goal correctly. Other input dialects could be potentially more like chlo and one would lower the differences away.

But is that not an argument for having everything that is needed to express tensor-level computations for ML in tcp? Then newcomers only have to learn one dialect at that level, and producers won’t ever see a soup of dialects. At lower levels, things get muddy with many dialects, but that is unavoidable.

I think so.

I’m happy with anything from “a small set of ops that are common in optimisation pipelines (block/tile/fuse/reorder)” to “a comprehensive set of all ML ops with the most explicit semantics possible”.

Both can be lowered to from higher dialects, provide a stable interface for optimisation pipelines to work with, and are easy to lower to other dialects. Where in that spectrum we converge is less important from my point of view.

Transformation-wise, IMO explicit semantics is almost always easier to work with than implicit, for both computers and compiler developers. That means explicit broadcast ops and strict shape checks. How this lowers to linalg is also less important, as long as we keep semantics.

The fundamental technical thing that I found to break with arith on tensors was that when folding/canonicalizing, it is hard to materialize tensor constants with dynamic shapes. Consider a fold “x - x => 0”. With vector/scalar, you just create a “0” with the right type. With tensors and dynamic shapes, you would need to insert tensor.dim ops on something to create the “0”, which makes it a totally different class of transformation. Similarly, if you are approximating math.exp, creating the constants requires special handling of dynamic dimension sizes.
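
A sketch of what that means in IR (names are illustrative): for scalars the folded “0” is just an arith.constant, but for a dynamically shaped tensor the zero has to be materialized from shape queries on an existing value.

```mlir
// Sketch only: materializing a zero tensor with the runtime shape of %x.
func.func @zero_like(%x: tensor<?xf32>) -> tensor<?xf32> {
  %c0  = arith.constant 0 : index
  %dim = tensor.dim %x, %c0 : tensor<?xf32>   // runtime extent of %x
  %cst = arith.constant 0.0 : f32
  %zero = tensor.generate %dim {
  ^bb0(%i: index):
    tensor.yield %cst : f32
  } : tensor<?xf32>
  return %zero : tensor<?xf32>
}
```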

Also, if going for a “broadcast is a separate op” approach (which TCP is?), the broadcasts super easily get in the way of every arith pattern match you would want to do.

Lower in the stack, linalg.generic pulls all the elementwise ops into the scalar payload anyway so all the folds “just work” there, but not folding earlier may impact cost modeling decisions or prevent discovering optimization opportunities.


Also, a possibly controversial opinion: personally, I would be happy if TCP were just a copy of an appropriate MHLO subset, but with some of the dynamic shape behavior regularized, consistent use of index instead of hardcoded i32/i64, Variadic for dimensions, and a grab bag of other best practices adhered to without the legacy constraints.

Is there a lot of value-add or “innovation” to be done on the op set itself? I wonder if we can view this more as a fixing of targeted pain points on MHLO rather than a grand design effort. Given that so many of us are using MHLO as a reference point anyway it seems inevitable to happen, even if subconsciously.


I very much like this approach of fixing things in MHLO that would be inconsistent with a “native” MLIR design. The op set itself can be largely borrowed/seeded from MHLO.