"Native" support for custom tensor / memref types

Hello,

In our downstream project we have a couple of custom tensor / memref types that we would like to integrate better with the core MLIR infrastructure. The state of things as of writing:

  • Our custom tensors / memrefs live in dedicated dialects and are fairly project-specific
  • Only a very limited number of operations can operate on such tensors / memrefs; most operations (say, 95 out of 100) do not even accept them as inputs
  • The usage of our special types is local: we do not (intend to) pass them across function boundaries

Originally, this seemed to work fine: an “aware” operation would know what to do with such tensors / memrefs, transforming them under the hood into mlir::RankedTensorType / mlir::MemRefType respectively for other operations to work on.

But we’ve started seeing many complications while transitioning to One-Shot Bufferization: afaict most of the logic there assumes it is given mlir::TensorType / mlir::BaseMemRefType objects. And this is not really a flaw in that functionality: I don’t think it is reasonable, or even possible, for core MLIR to cover everything happening in downstream projects.

Internally, we’ve come up with the following ideas (for simplicity, let’s consider only tensors now):

  • Derive custom tensor types from mlir::TensorType
    • Turns out to be exceptionally problematic, since mlir::TensorType assumes that it is either a mlir::RankedTensorType or a mlir::UnrankedTensorType (e.g. see the definition of TensorType::getElementType)
  • (Given the above) derive custom tensor types from mlir::RankedTensorType - and hide special member fields behind mlir::Attribute “encoding”
    • This gives us the ability to pretend our custom tensor type is a mlir::TensorType (essential for one-shot bufferization)
    • But: we can no longer distinguish between “CustomTensor” and mlir::RankedTensorType. This means that the majority of operations would now accidentally work with the custom types and we would have to explicitly prohibit this; places that had differentiating logic (e.g. Case<RankedTensorType>(doX).Case<MyCustomTensorType>(doY)) would have to be revised as well. This is generally a significant, error-prone effort
  • (Most promising) Change mlir::TensorType to be a type interface
    • Under the hood, the tensor type just assumes it is a {Ranked, Unranked}TensorType. To me, this seems to be essentially a TypeInterface candidate, with the “overrides” for {Ranked, Unranked}TensorType provided in a proper manner
    • This allows our custom tensors to use TensorType like any other type interface, yet we are still able to distinguish them from mlir::RankedTensorType (see the sketch after this list)
    • One-shot bufferization should just work in core MLIR, dealing with TensorType and BaseMemRefType (now interfaces). Our downstream would provide the necessary extensions in places where we wish to create our own custom types
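
To make the last option a bit more concrete, here is a rough sketch of what the downstream side might look like if TensorType were a type interface. Everything below (MyCustomTensorType, its storage class, the exact set of interface methods) is hypothetical; the point is only that the custom type attaches the interface like any other and stays distinguishable from the builtin tensor types:

// Hypothetical sketch: MyCustomTensorType / MyCustomTensorTypeStorage are
// downstream inventions; mlir::TensorType::Trait assumes TensorType has
// become a type interface.
class MyCustomTensorType
    : public mlir::Type::TypeBase<MyCustomTensorType, mlir::Type,
                                  MyCustomTensorTypeStorage,
                                  mlir::TensorType::Trait> {
public:
  using Base::Base;

  // Interface methods the custom type would implement (illustrative set).
  mlir::Type getElementType() const;
  bool hasRank() const;
  llvm::ArrayRef<int64_t> getShape() const;
};

// For a MyCustomTensorType value `t` we would then get:
//   mlir::isa<mlir::TensorType>(t)        -> true  (generic code keeps working)
//   mlir::isa<mlir::RankedTensorType>(t)  -> false (builtin-only code rejects it)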

We are planning to experiment with making the TensorType / BaseMemRefType to be type interfaces (and then submitting a patch) but wanted to collect some feedback from someone more experienced in the matter: could it be that we missed something?

From what I could tell, this sparse tensor discussion could also benefit from TensorType becoming an interface.


This I don’t know how you avoid even with the type interfaces below. It seems that the moment you are in the subsetting or “is a”/inheritance part of the world, anything meant to operate on the interface should be safe for your type. So we get an implicit set of requirements on the base that everything implementing it, and every pass operating on the interface, would need to support, and downstream users may be the first to detect violations of those contracts.

Even with this usage of encoding, you can simply do cast<MyCustomTensorType>(t) (llvm-project/mlir/examples/SimpleDepType/ti-opt.cpp at c342b81870e6d3cfbeee4a8b9e8232643364381a · jpienaar/llvm-project · GitHub shows such wrapping; it’s a ~2 year old prototype, so it’s rather rough and predates the ability to do a freestanding cast).

Is this all related to one-shot bufferization? E.g., this is what you want enabled for your custom type; you don’t want arbitrary transforms or ops on tensor types to support your type, and in fact they should not trigger or should fail to verify.

(I’m not against making the interface you propose, I’m mostly collecting info)

(Still figuring out how to work with this system so sorry for poor replies, etc.)

There are two things:

  1. Anything operating on the interface level must assume any underlying type (custom, MLIR’s, etc.)
  2. Anything that specifically requests mlir::RankedTensorType would automatically reject custom types (and vice versa) - this is a very neat property to have when you want to limit the API of operations (see the snippet below). If our custom types derive from the MLIR ones, things like mlir::isa<Base>(derived) are likely to succeed (which is what we’d want in the current TensorType model, since it would at best assert whenever the internal impl pointer is not one of {Ranked, Unranked}TensorType)
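
To illustrate point 2 with a purely illustrative snippet (MyCustomTensorType stands in for our downstream type):

// With a separate custom type (not deriving from RankedTensorType):
if (auto ranked = mlir::dyn_cast<mlir::RankedTensorType>(type)) {
  // Never entered for MyCustomTensorType: code that only accepts builtin
  // ranked tensors naturally rejects the custom type.
}
// If MyCustomTensorType instead derived from mlir::RankedTensorType, the
// dyn_cast above would succeed for it, and such checks would silently start
// accepting the custom type.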

What I meant is that the order matters:

auto type = getType(); // CustomTensorType that derives from mlir::RankedTensorType

llvm::TypeSwitch<mlir::Type>(type)
    .Case<RankedTensorType>(doX) // this is called: the custom type also passes isa<RankedTensorType>
    .Case<MyCustomTensorType>(doY);

llvm::TypeSwitch<mlir::Type>(type)
    .Case<MyCustomTensorType>(doY) // this is called
    .Case<RankedTensorType>(doX);

Similarly, any place doing if (mlir::isa<X>(type)) { ... } else if (mlir::isa<Y>(type)) { ... } is affected. Basically, by “inheriting” from mlir::RankedTensorType we’d introduce very subtle behavior changes all over the place.

Quite. We discovered this problem with One-Shot Bufferization specifically, but since the topic of making TensorType / BaseMemRefType (and potentially a bunch of other types like that) into interfaces is rather general, I figured labeling the thread as bufferization-specific would be somewhat misleading.

Generally, I don’t have enough experience within our downstream to give other examples where we’ve cut corners due to this, but my gut feeling is that this is not the last time we’d have to deal with “external tensors are second-class citizens” topics (I also referenced the sparse tensor discussion in the OP, which in fact feels very relatable).

The general problem with types and type interfaces is that while they help with analyzing the IR, they do not directly help with transformations. For example: how would all the canonicalizations in upstream dialects work? You often need to create a new type, and that seems impossible in general with an abstract type interface.

Do you have something particular in mind? From what I can tell, there would be no difference since:

inline bool TensorType::classof(Type type) {
  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
}

inline bool BaseMemRefType::classof(Type type) {
  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
}

(this is on not-so-outdated main).

Practically, looking at these two, I have the impression that even if someone somehow managed to create a mlir::TensorType object (which is probably impossible, as there are no ctors / get() functions?), the object would be practically unusable unless it is one of {Ranked, Unranked}TensorType. So I imagine that all the code in MLIR always creates {Ranked, Unranked}TensorType and {Unranked}MemRefType objects, which is why changing the “base classes” to interfaces should be (at least in theory) straightforward and safe.
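
For illustration (assuming an mlir::OpBuilder named builder is in scope), tensor types throughout core MLIR are created as one of the two concrete builtin types:

// Core MLIR always materializes one of the two concrete builtin tensor types:
auto ranked = mlir::RankedTensorType::get({4, 8}, builder.getF32Type());
auto unranked = mlir::UnrankedTensorType::get(builder.getF32Type());
// There is no TensorType::get(), so a “plain” TensorType value never exists
// on its own; TensorType is only ever a view over one of the two above.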

The problem would be that a canonicalization pattern operating on the arith dialect, for example, would create a RankedTensorType during replacements when your input IR had your custom tensor, likely losing the important information you were hoping to convey with your own type.
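
For instance, a replacement along these lines (SomeArithOp and the surrounding variables are purely illustrative, not an actual upstream pattern) would drop the custom type:

// A typical upstream pattern builds the result type from scratch:
auto newType = mlir::RankedTensorType::get(shape, elementType);
rewriter.replaceOpWithNewOp<SomeArithOp>(op, newType, operands);
// ...so whatever information the custom tensor type carried is silently lost.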

I see. In our specific case, we do not expose our custom tensors / memrefs to “pure” upstream MLIR dialects (because they don’t know what to do with them!). So our types (must) disappear before we lower to pure MLIR. I guess this could be a general idea. Thus, the scenario you describe (an upstream canonicalization replacing our custom tensor with a plain RankedTensorType) is a bug in our compiler.

Now, the way custom type support might generally be done is through extension points: e.g. the arith dialect could expose some type conversion function that users could override (likely, what currently happens is something along the lines of if (isa<RankedTensorType>(x)) {...} /* else */ cast<UnrankedTensorType>(x), where the latter cast asserts if a custom type is provided).
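
Roughly the pattern being referred to (illustrative only):

if (auto ranked = mlir::dyn_cast<mlir::RankedTensorType>(type)) {
  // handle the ranked case
} else {
  // Asserts at runtime if `type` is anything other than the two builtin
  // tensor types, e.g. a downstream custom tensor.
  auto unranked = mlir::cast<mlir::UnrankedTensorType>(type);
  // handle the unranked case
}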

Note that I do not really have a complete picture of custom tensor / memref support. But changing the base classes to interfaces is a start; afterwards, it will probably take many more “real-world” use cases to figure out the details. Although, to be fair, maybe custom types won’t even be needed in the pure MLIR dialects at all? The mental model is: “a custom {type} exists within the custom dialect boundaries”. One can solve the problem by lowering such stuff away, so that MyCustomDialect → “pure” MLIR involves a MyCustomType → “pure” MLIR type conversion (I believe this is how we currently manage to get rid of our custom memrefs, transforming them into “normal” ones).
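
As a minimal sketch of that lowering step (MyCustomTensorType and its accessors are placeholders for our downstream type), something along these lines with a standard mlir::TypeConverter:

mlir::TypeConverter converter;
// Fallback: leave all other types untouched.
converter.addConversion([](mlir::Type type) { return type; });
// Strip the project-specific details and produce a plain ranked tensor.
// Conversions added later take precedence, so this one wins for our type.
converter.addConversion([](MyCustomTensorType type) -> mlir::Type {
  return mlir::RankedTensorType::get(type.getShape(), type.getElementType());
});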

Thus far the problem is that it’s not really possible to introduce custom tensors without jumping through significant hoops: some things could be put behind the “encoding” (in tensor) / “layout” (in memref), but if you want a separate type then you’re out of luck.

(Btw, my guess is that MLIR doesn’t really handle user-specified encoding → layout transformations, so users are on their own anyway if they want to pass user-specific details across the bufferization boundary.)

OK, so your interface is only for generic transformations like bufferization. It seems like we could then reduce the scope to just this kind of transformation? I assume the list of applicable places is very short?
(also, as an aside, we have a ShapedTypeInterface already)

Actually, I was considering changing mlir::TensorType / mlir::BaseMemRefType themselves (in BuiltinTypes.h) to interfaces. This stems from the fact that the whole of bufferization operates on these at a high level (which is where we discovered the original problem). In theory, we could introduce a “side interface”, but then we’d need to rewrite a lot of One-Shot Bufferization (e.g. bufferization dialect ops operate on / return mlir::TensorType / mlir::BaseMemRefType; unknown type conversions rely on the mlir::TensorType (in) → mlir::BaseMemRefType (out) API).

It is a valid point that things like the arith dialect could suddenly do something nasty if one uses a custom tensor type and then forgets to convert it away before using arith ops. I wonder if we could guard against it with some verification (e.g. assert(type.isa<RankedTensorType, UnrankedTensorType>()) somewhere at the arith level).
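
A rough sketch of such a guard (purely illustrative, and only meaningful once TensorType is an interface, since today isa<TensorType> already implies one of the two builtin types):

// Hypothetical verification hook at the arith level: reject tensor-interface
// types that are not one of the two builtin tensor types.
if (mlir::isa<mlir::TensorType>(type) &&
    !mlir::isa<mlir::RankedTensorType, mlir::UnrankedTensorType>(type))
  return op->emitOpError("expected a builtin tensor type");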

Right. The problem with this is that too many things are ShapedTypes (for instance, both tensors and memrefs). So for bufferization this is too broad; for dialects dealing with tensors it is an accidental way to allow memrefs; and for dialects dealing with memrefs it is an accidental way to allow tensors. (Also, probably, any other ShapedType objects could be provided; iirc something inside mlir::ElementsAttr provides a shaped type.)