This RFC discusses the current state of what “tensor” and “memref” are in MLIR and their strong coupling to the ranked / unranked builtin types, and proposes a change to the base types so that a user-defined tensor / memref could also be considered a “tensor” / “memref” by MLIR’s upstream infrastructure (unless the infrastructure explicitly works with one specific type).
Current state
There are two primary base classes, one for “tensor” and one for “memref”:
- `TensorType` - the base class of `RankedTensorType` / `UnrankedTensorType`
- `BaseMemRefType` - the base class of `MemRefType` / `UnrankedMemRefType`
They themselves are fairly simple:
- Inherit `mlir::Type`
- Attach the `ShapedType` interface to derived classes
- Provide a couple of functions (`::classof()` and static `isValidElementType()`)
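For reference, here is a condensed sketch of what the tensor base class roughly looks like (abridged from `mlir/include/mlir/IR/BuiltinTypes.h`; `BaseMemRefType` is analogous):

```cpp
// Abridged: the base class adds no state of its own - it inherits Type,
// attaches the ShapedType interface trait, and exposes a couple of statics.
class TensorType : public Type, public ShapedType::Trait<TensorType> {
public:
  using Type::Type;

  /// Return true if the specified element type is ok in a tensor.
  static bool isValidElementType(Type type);

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type);
};
```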
In the general sense, these are base classes - they declare something derived respectively as “tensor” or “memref”. The class hierarchy looks roughly as follows:
`MyTensorType` → `TensorType` → `Type`, `ShapedType` (interface)
`MyMemrefType` → `BaseMemRefType` → `Type`, `ShapedType` (interface)
Conveniently, one can avoid dealing with “my” tensor / memref directly by using the respective base class.
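For example, generic code can be written once against the base class and, in principle, work for any “tensor” (a minimal sketch; `processTensor` and `example` are hypothetical helpers invented for illustration):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper: by accepting the base class, the code does not care
// which concrete tensor type it is given.
static void processTensor(mlir::TensorType tensorTy) {
  if (tensorTy.hasRank())
    llvm::errs() << "tensor of rank " << tensorTy.getRank() << "\n";
  else
    llvm::errs() << "unranked tensor of " << tensorTy.getElementType() << "\n";
}

static void example(mlir::Type type) {
  // dyn_cast from a generic Type is supposed to pick up "any" tensor.
  if (auto tensorTy = mlir::dyn_cast<mlir::TensorType>(type))
    processTensor(tensorTy);
}
```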
The problem
In reality, the hand-wavy class hierarchy above is not possible. This is because `TensorType` and `BaseMemRefType` are “intrusive” base classes: they hard-code exactly which types they can hold. That is,
```cpp
inline bool BaseMemRefType::classof(Type type) {
  // thus: mlir::isa<BaseMemRefType> == mlir::isa<MemRefType, UnrankedMemRefType>
  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
}

inline bool TensorType::classof(Type type) {
  // thus: mlir::isa<TensorType> == mlir::isa<RankedTensorType, UnrankedTensorType>
  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
}
```
(comments are mine, source code is roughly here)
Thus, if one has a perfectly compatible tensor type that happens to be neither `RankedTensorType` nor `UnrankedTensorType`, that type is not considered a “tensor” - at least as far as `mlir::isa<TensorType>(myTensor)` is concerned. The same reasoning applies to memrefs.
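To illustrate (a sketch; the downstream type and its storage class are invented for the example):

```cpp
// A downstream "tensor" that models shape and element type perfectly well,
// yet is not one of the two builtin tensor types.
class MyTensorType : public Type::TypeBase<MyTensorType, Type, MyTensorStorage,
                                           ShapedType::Trait> {
  // ... shape / element type accessors ...
};

static void check(MyTensorType myTensor) {
  // false today: TensorType::classof() only admits the builtin tensors.
  bool isTensor = mlir::isa<TensorType>(myTensor);
  (void)isTensor;
}
```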
It is worth mentioning that there is an “extension” mechanism for tensors called encoding that allows one to store arbitrary information in the builtin tensor in a structured way; the memref counterpart to this is layout. Yet, when one needs a custom tensor altogether (e.g. at least for the sake of parsing / printing / internal interactions), unexpected issues may arise. For instance, this is what we faced in our downstream project when trying to make one-shot bufferization work with our own tensors / memrefs.
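Both extension points are visible directly in the builder APIs of the builtin types (a sketch; `myEncoding` and `myLayout` are assumed to be suitable attributes defined elsewhere):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

void buildTypes(mlir::OpBuilder &builder, mlir::Attribute myEncoding,
                mlir::MemRefLayoutAttrInterface myLayout) {
  // A builtin tensor carrying extra information via its `encoding` attribute.
  auto tensorTy = mlir::RankedTensorType::get({4, 8}, builder.getF32Type(),
                                              /*encoding=*/myEncoding);

  // The memref counterpart: a custom `layout`.
  auto memrefTy = mlir::MemRefType::get({4, 8}, builder.getF32Type(),
                                        /*layout=*/myLayout);
  (void)tensorTy;
  (void)memrefTy;
}
```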
The proposal
The problem seems somewhat artificial: what exists are base classes that are way too knowledgeable about their derived counterparts. Luckily, present-day MLIR has a solution for this - type interfaces! The overall motivation for the change is (or seems to be) similar to the motivation behind the ShapedType changes.
Converting (base) `TensorType` and `BaseMemRefType` to type interfaces is almost an NFC, in fact, since the classes themselves do not implement any new APIs. Yet, being type interfaces allows users to write custom tensor / memref types against these interfaces and ensures that generic code relying on “tensor” / “memref” still functions.
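With the change in place, a downstream type could then opt in roughly as follows (a hypothetical sketch, assuming the interface trait is attached the way type interfaces usually are; see the PR below for the exact spelling):

```cpp
// Hypothetical: once TensorType is a type interface, a custom type attaches
// its trait (next to ShapedType's) and thereby becomes a "tensor" for all
// generic code, i.e. mlir::isa<TensorType>(myTensor) now returns true.
class MyTensorType
    : public Type::TypeBase<MyTensorType, Type, MyTensorStorage,
                            ShapedType::Trait, TensorType::Trait> {
  // ... interface methods (shape, element type, cloning, etc.) ...
};
```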
The implementation PR: [mlir] Convert TensorType and BaseMemRefType to interfaces by andrey-golubev · Pull Request #133053 · llvm/llvm-project · GitHub
Potential problems
With the base tensor / memref types becoming interfaces, any generic logic in MLIR is subject to potential new bugs and issues when non-builtin tensors / memrefs are involved.
Yet, in my opinion, this is fine: if a downstream project struggles with custom tensor / memref type support, it is free to suggest improvements (which is kind of what I am doing here now!).
I think the general infrastructure would also evolve to become better suited to the various demands of downstream projects.