[RFC] ShapedType Type Interface Cleanup

Background

All tensor/memref/vector types currently have a shared type interface ShapedType. This interface provides the following interface methods:

  • getElementType
  • getShape: Asserts that the type is ranked and returns the shape.
  • cloneWith

There are also a few helper functions around those interface methods such as getRank, getDimSize, etc. Those also assert that the type is ranked.
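To make the current situation concrete, here is a minimal standalone sketch. It does not use MLIR. `MockShapedType` is a hypothetical stand-in that only models the behavior described above: every shaped type exposes `getShape()` and its helpers, and calling them on an unranked type can only be caught by a runtime assertion.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for today's ShapedType interface (not the
// MLIR class). Every shaped type exposes the shape accessors, but they are
// only valid for ranked types, so misuse is caught at run time.
class MockShapedType {
public:
  MockShapedType(std::string elementType, bool ranked,
                 std::vector<int64_t> shape = {})
      : elementType(std::move(elementType)), ranked(ranked),
        shape(std::move(shape)) {}

  const std::string &getElementType() const { return elementType; }
  bool hasRank() const { return ranked; }

  // Mirrors the current getShape(): asserts instead of reporting failure.
  const std::vector<int64_t> &getShape() const {
    assert(hasRank() && "getShape() called on an unranked type");
    return shape;
  }

  // Helpers built on getShape() inherit the same runtime precondition.
  int64_t getRank() const { return static_cast<int64_t>(getShape().size()); }
  int64_t getDimSize(unsigned idx) const {
    assert(idx < getShape().size() && "dimension index out of range");
    return getShape()[idx];
  }

private:
  std::string elementType;
  bool ranked;
  std::vector<int64_t> shape;
};
```

Nothing in the type of `MockShapedType` tells the compiler whether `getShape()` is safe to call; that knowledge lives only in the `hasRank()` check at run time.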

Proposal

Add a new RankedShapedType interface that is a sub-interface of ShapedType, utilizing interface inheritance (D140198 [mlir] Add support for interface inheritance).

All previously ranked ShapedTypes (ranked tensor types, ranked memref types, vector types) implement RankedShapedType instead of ShapedType. (They are still considered ShapedTypes because of the interface inheritance.)

Interface methods and helper functions/members that only make sense for ranked types (e.g., getShape(), getDimSize(), kDynamicSize) are moved to RankedShapedType.
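The proposed split can be sketched with plain C++ inheritance. This is only an analogy under stated assumptions: MLIR interfaces are not C++ base classes, and `MockShapedType`, `MockRankedShapedType` and `asRanked` are hypothetical names. `asRanked` plays the role of `dyn_cast<RankedShapedType>(t)`, and the value chosen for `kDynamic` is arbitrary for the mock.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <vector>

// Hypothetical base: only what every shaped type supports (ranked or not).
class MockShapedType {
public:
  explicit MockShapedType(std::string elementType)
      : elementType(std::move(elementType)) {}
  virtual ~MockShapedType() = default;
  const std::string &getElementType() const { return elementType; }

private:
  std::string elementType;
};

// Hypothetical ranked sub-interface: shape queries live only here.
class MockRankedShapedType : public MockShapedType {
public:
  // Sentinel for a dynamic dimension, analogous to kDynamicSize in the RFC.
  static constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

  MockRankedShapedType(std::string elementType, std::vector<int64_t> shape)
      : MockShapedType(std::move(elementType)), shape(std::move(shape)) {}

  // No rank assertion needed: the static type guarantees a rank.
  const std::vector<int64_t> &getShape() const { return shape; }
  int64_t getRank() const { return static_cast<int64_t>(shape.size()); }
  int64_t getDimSize(unsigned idx) const { return shape.at(idx); }
  bool isDynamicDim(unsigned idx) const { return getDimSize(idx) == kDynamic; }

private:
  std::vector<int64_t> shape;
};

// Models dyn_cast<RankedShapedType>(t): nullptr for unranked types.
const MockRankedShapedType *asRanked(const MockShapedType &t) {
  return dynamic_cast<const MockRankedShapedType *>(&t);
}
```

A ranked type is still usable wherever a `MockShapedType` is expected, mirroring how types implementing RankedShapedType remain ShapedTypes through interface inheritance.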


[Diagram not included. Dashed lines indicate interface implementation.]

Prototype: D148453 [mlir] Add RankedShapedType interface. Most changes are mechanical, but because so many parts of the code base are affected, I decided to send this as an RFC.

Benefits

Errors that were previously runtime assertion failures (e.g., calling getShape() on an unranked tensor through TensorType or ShapedType) are now compile-time errors: TensorType, BaseMemRefType and ShapedType no longer have functions such as getShape().
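One way to see the benefit is to check, at compile time, which types even offer `getShape()`. The sketch below uses the C++17 detection idiom on hypothetical mock structs (not the MLIR classes); a `static_assert` documents that the base no longer exposes the ranked-only API.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical mocks of the proposed split (not the MLIR classes).
struct MockShapedType {
  std::string elementType;
  const std::string &getElementType() const { return elementType; }
};

struct MockRankedShapedType : MockShapedType {
  std::vector<int64_t> shape;
  const std::vector<int64_t> &getShape() const { return shape; }
};

// C++17 detection idiom: true iff `t.getShape()` compiles for a const T.
template <typename T, typename = void>
struct HasGetShape : std::false_type {};

template <typename T>
struct HasGetShape<
    T, std::void_t<decltype(std::declval<const T &>().getShape())>>
    : std::true_type {};

// What used to be a runtime assertion is now a compile-time property.
static_assert(!HasGetShape<MockShapedType>::value,
              "the base interface must not expose getShape()");
static_assert(HasGetShape<MockRankedShapedType>::value,
              "the ranked sub-interface provides getShape()");
```

With this layout, writing `someShapedType.getShape()` simply fails to compile, instead of compiling and then asserting on an unranked input.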

Potential Next Step

RankedShapedType and ShapedType could be implemented via external models in the tensor/memref/vector dialects. Then, interface methods such as createDimOp and createRankOp could be added, cleaning up various parts of the code base that currently use hard-coded switch-case statements to support both memref and tensor types.
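The idea of replacing a switch-case with an interface method can be sketched without MLIR. Everything below is hypothetical: `DimOpInterface`, the two model classes, `InterfaceRegistry` and `createDimFor` are illustrative names, and the registry keyed by a type-kind string is only a stand-in for attaching external models to types. Ops are represented by a small record instead of real IR.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Stand-in for a created operation: records the op name and dimension index.
struct MockOp {
  std::string name;
  unsigned dimIndex;
};

// Hypothetical interface method from the RFC idea: "create a dim op for me".
struct DimOpInterface {
  virtual ~DimOpInterface() = default;
  virtual MockOp createDimOp(unsigned index) const = 0;
};

// Models each dialect could provide, analogous to external models.
struct TensorDimModel : DimOpInterface {
  MockOp createDimOp(unsigned index) const override {
    return {"tensor.dim", index};
  }
};
struct MemRefDimModel : DimOpInterface {
  MockOp createDimOp(unsigned index) const override {
    return {"memref.dim", index};
  }
};

// Models are attached to a type kind after the fact, outside the type itself.
class InterfaceRegistry {
public:
  void attach(const std::string &typeKind,
              std::unique_ptr<DimOpInterface> model) {
    models[typeKind] = std::move(model);
  }
  const DimOpInterface *lookup(const std::string &typeKind) const {
    auto it = models.find(typeKind);
    return it == models.end() ? nullptr : it->second.get();
  }

private:
  std::map<std::string, std::unique_ptr<DimOpInterface>> models;
};

// Generic code asks the interface instead of switching on tensor vs. memref.
MockOp createDimFor(const InterfaceRegistry &registry,
                    const std::string &typeKind, unsigned index) {
  const DimOpInterface *iface = registry.lookup(typeKind);
  assert(iface && "no dim-op model attached to this type kind");
  return iface->createDimOp(index);
}
```

Because the generic helper only depends on `DimOpInterface`, a different model (for example one producing a downstream dialect's own dim op) could be attached without touching `createDimFor`.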


Seems like a reasonable change to me.

Thanks for taking the time to write this up!

Big +1 here, I’ve been wanting this for a while. I don’t find it great that we have to resort to “external interfaces” for this; I would rather have the types live in a dialect together with their “utilities”, but I suspect that is too much of a change right now.

Thanks for this, the runtime failures had bothered me a lot!

It seems a bit odd that an interface called ShapedType wouldn’t provide any way to query the shape. Have you considered renaming it (e.g. something like ShapedType -> MaybeShapedType + RankedShapedType -> ShapedType)? Or keeping the query methods, but making fallibility explicit (e.g. ShapedType would have std::optional<ArrayRef<int64_t>> tryGetShape(), while RankedShapedType would have ArrayRef<int64_t> getShape())?
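The "explicit fallibility" alternative might look like this in a standalone mock. `MockShapedType`, `tryGetShape` and `tryGetRank` are hypothetical; the suggestion above uses `ArrayRef<int64_t>` (a non-owning view), while this sketch returns a copied `std::vector` to stay self-contained.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical mock (not MLIR code): shape queries stay on the base type,
// but failure is explicit in the return type instead of an assertion.
class MockShapedType {
public:
  static MockShapedType ranked(std::vector<int64_t> shape) {
    return MockShapedType(std::move(shape));
  }
  static MockShapedType unranked() { return MockShapedType(std::nullopt); }

  // std::nullopt for unranked types; the mock copies the shape for simplicity.
  std::optional<std::vector<int64_t>> tryGetShape() const { return shape; }

  std::optional<int64_t> tryGetRank() const {
    if (!shape)
      return std::nullopt;
    return static_cast<int64_t>(shape->size());
  }

private:
  explicit MockShapedType(std::optional<std::vector<int64_t>> shape)
      : shape(std::move(shape)) {}
  std::optional<std::vector<int64_t>> shape;
};
```

Callers are forced to unwrap the optional, so the unranked case cannot be silently ignored, at the cost of wrapping every rank-dependent result.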

The types have a shape, but it may not be a shape known at compile time. So “maybe” does not seem right to me from this point of view.
We need ways to interact with the shape that allow handling the unranked case in the top-level class.

Yes, that sounds similar to the dynamic dim case. E.g., an unranked shape is still a shape, just as a dynamic dim is still a dim; we just currently assert for the former and return a sentinel value for the latter. Rahul’s suggestion seems closer to consistent, but probably the most consistent option is “DynamicOr<DynamicOr<int64_t>>”, so that there is no reliance on a sentinel in either spot and it’s explicit in the type. (Note: we could also remove the assert and return a sentinel ArrayRef<int64_t>, which would be closer to how dims are handled, but it sounds like there is an objection concerning how self-documenting that would be compared to using the type.)
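Here is one possible reading of the “DynamicOr” idea, in which the outer wrapper holds the whole dimension list and each dimension is wrapped individually. `DynamicOr` is only a name floated in this thread, not an MLIR type; the sketch models it with `std::optional`, where `std::nullopt` means “not known statically”. `staticNumElements` is a hypothetical helper showing that neither the unranked case nor dynamic dims need a sentinel.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical: DynamicOr<T> modeled as std::optional<T>.
template <typename T>
using DynamicOr = std::optional<T>;

using Dim = DynamicOr<int64_t>;             // dynamic dimension -> std::nullopt
using Shape = DynamicOr<std::vector<Dim>>;  // unranked shape    -> std::nullopt

// Number of elements if the shape is fully static; no sentinel values used.
DynamicOr<int64_t> staticNumElements(const Shape &shape) {
  if (!shape)
    return std::nullopt;  // unranked
  int64_t count = 1;
  for (const Dim &d : *shape) {
    if (!d)
      return std::nullopt;  // at least one dynamic dimension
    count *= *d;
  }
  return count;
}
```

Both levels of “unknown” are encoded in the type, so a caller cannot mistake a sentinel for a real extent.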


The structural change to ensure type safety makes sense to me here, though it could also have been done by making the rank-assuming methods on ShapedType fallible and having a simple C++ wrapper class that checked the existing hasRank in its classof. One effect of this RFC is that downstream users won’t be able to have a type that can handle both unranked and ranked shapes; we’re forcing a divide in the type system. Not that we have to support such a case, but it is something that would be broken by this RFC as presented.
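The wrapper-class alternative can be sketched in the style of LLVM RTTI, where a static `classof` decides whether a cast succeeds. All names here are hypothetical mocks: `dynCast` imitates the spirit of `dyn_cast` but returns `std::optional` because the mock is a plain value type. The point is that no new type kind is needed; ranked-ness is just a runtime property checked once in `classof`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical single shaped type whose rank-dependent queries are fallible.
class MockShapedType {
public:
  explicit MockShapedType(std::optional<std::vector<int64_t>> shape)
      : shape(std::move(shape)) {}
  bool hasRank() const { return shape.has_value(); }
  std::optional<int64_t> getRank() const {
    if (!shape)
      return std::nullopt;
    return static_cast<int64_t>(shape->size());
  }

protected:
  std::optional<std::vector<int64_t>> shape;
};

// Thin wrapper: "is ranked" is decided by classof() using the existing hasRank.
class MockRankedShapedType : public MockShapedType {
public:
  static bool classof(const MockShapedType &t) { return t.hasRank(); }

  // Precondition checked once here; the accessors below cannot fail.
  explicit MockRankedShapedType(const MockShapedType &t) : MockShapedType(t) {
    assert(classof(t) && "wrapping an unranked type");
  }
  const std::vector<int64_t> &getShape() const { return *shape; }
  int64_t getRank() const { return static_cast<int64_t>(shape->size()); }
};

// dyn_cast-like helper driven by To::classof.
template <typename To>
std::optional<To> dynCast(const MockShapedType &t) {
  if (!To::classof(t))
    return std::nullopt;
  return To(t);
}
```

Unlike the RFC’s separate interface, this keeps a single underlying type that can hold either ranked or unranked shapes.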

Very -1 to this IMO. A current artifact of the placement of these types is that they get used by various downstream dialects. With this, we’d be strongly tying a “builtin” construct to a non-builtin dialect, which feels strongly like broken layering to me.

I have no strong preference for either one, but a separate RankedShapedType seems a bit cleaner and simpler (in terms of API) to me because we don’t need to wrap the return values of getShape/getDimSize/getRank/etc. in FailureOr/std::optional.

In the upstream dialects, we usually have ranked types. E.g., all vector dialect ops operate on ranked shaped types. I have seen unranked types only with cast ops or around function calls. (Not sure about TOSA…)

ShapedType works for both ranked and unranked container types. But it would require an explicit downcast to retrieve the shape. When preparing the revision, there was one pattern that I saw multiple times:

ShapedType t;
if (t.hasRank()) {
  do_something(t.getShape());
}

This would now be written as:

ShapedType t;
if (auto r = dyn_cast<RankedShapedType>(t)) {
  do_something(r.getShape());
}

Note that we already have this divide in the type system for tensor/memref types: there are UnrankedTensorType and RankedTensorType, and TensorType covers both ranked and unranked tensors. The proposed change replicates the same hierarchy for interfaces: ShapedType is for both unranked and ranked container types. (But there is no UnrankedShapedType interface at the moment.)


(The original diagram did not show unranked types.)

Do you have any thoughts about Mehdi’s suggestion to move TensorType/MemRefType/VectorType to the respective dialects? (The switch-case statements for generating dim ops (createOrFoldDimOp) are the main thing that I want to clean up.)

Actually, I don’t believe external interfaces create the coupling you’re describing: on the contrary, you can create your own mydialect.tensor_dim operation and inject the interface to use it instead of the upstream tensor.dim.
The way I see it is that the decoupling here comes from the external interface saying “someone must provide a dim op of their choice by injecting an interface” and “one such dim op is available by loading the external interface provided by the tensor dialect”.

Now, while this solution matches the current separation of builtin types and dialects, as mentioned before I’d really prefer that the types live in the same dialect as their supporting operations. That is: if you want to use the MLIR-provided tensor type, you always get with it the minimum set of operations to manipulate it (like “getRank”, “getDim”, etc.).
This could be done by moving these types to their own dialects. (It can’t really be the tensor dialect, as we wouldn’t want all the arithmetic tensor operations there; they aren’t core to the tensor type.)

I don’t really see it that way in this case. That does work when the types are used in pure isolation and the different dialects are not generally intended to be used together, but that isn’t the case here AFAIK. Using external interfaces in this way effectively breaks any composability in the system, because you can’t have different external interfaces for the different abstraction layers. You would no longer be able to use these types within your own dialect and the in-tree dialects in the same pipeline. To me, that seems like an indication of broken layering with external interfaces in this case, given how the types and in-tree dialects have been designed.

This is honestly the path I’d rather us explore. The placement of these as “builtin” is largely historical, and I’d love to move them out to a better home that didn’t burden both what we could do with the design (the types being in builtin makes it hard to support with attributes/operations/etc.) and the ecosystem (why is “foo” builtin, but not “bar”; etc.)

Does it actually have to be a new dialect? We would have to add 3 new dialects (for TensorType, MemRefType and VectorType).

The same goes for other types: should IndexType eventually move to the IndexDialect or to a new dialect? Same for IntegerType/FloatType and the ArithDialect. Whatever approach we choose now sets a precedent.

And what ops would the new dialect contain? Based on the current discussion, it sounds like it would contain ops that deal directly with the respective type. E.g., tensor.cast could also be in that dialect. (But not things such as tensor.extract_slice.) It could be called TensorTypeDialect.

TensorDialect, MemRefDialect and VectorDialect have dependencies on other dialects (Affine, Arith, Complex) and a bunch of interfaces. That could probably be cleaned up to some degree. But not entirely, unless we always implement all interfaces via external models. That would be another reason for putting the shaped types in new, lightweight dialects.