Chains of unrealized casts

I’m trying to modify a project to reuse the existing memref dialect, but I’m encountering some design issues during type conversion. I will try to explain them through some reduced examples that replicate the problems I’m facing. Sorry for the long post, but I want to be as clear as possible, also with respect to possible doubts about my implementation.

Consider the following IR, based on a custom dialect called mydialect, in which an array-like type has been defined together with some operations that work on it.

%0 = mydialect.alloc : !mydialect.array<3xi64>
%1 = mydialect.process %0 : !mydialect.array<3xi64> -> !mydialect.array<6xi64>
// rest of code, with other uses of %1

Now, suppose that I have a conversion pass converting just a subset of its operations (i.e., only process).
Such a conversion pass uses a type converter stating that the mydialect.array type should become a memref. While converting the process operation, a mydialect.alloc op has to be created in order to store the result. So, after the conversion, we are in the following situation:

%0 = mydialect.alloc : !mydialect.array<3xi64>
%1 = builtin.unrealized_conversion_cast %0 : !mydialect.array<3xi64> to memref<3xi64>
%2 = mydialect.alloc : !mydialect.array<6xi64>
%3 = builtin.unrealized_conversion_cast %2 : !mydialect.array<6xi64> to memref<6xi64>
// Code to populate %3
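
For context, the type converter driving this pass is nothing special; a minimal sketch of what it could look like follows (mydialect::ArrayType and its getShape()/getElementType() accessors are placeholders for my actual implementation, and the exact materialization callback signature varies slightly across MLIR versions). The materialization hooks are what produce the unrealized casts above:

    class MyDialectToMemRefTypeConverter : public mlir::TypeConverter {
    public:
      MyDialectToMemRefTypeConverter() {
        // Leave every type we don't care about untouched.
        addConversion([](mlir::Type type) { return type; });
        // !mydialect.array<NxT> becomes memref<NxT>.
        addConversion([](mydialect::ArrayType type) -> mlir::Type {
          return mlir::MemRefType::get(type.getShape(), type.getElementType());
        });

        // Bridge converted and not-yet-converted values with
        // builtin.unrealized_conversion_cast ops.
        auto addUnrealizedCast =
            [](mlir::OpBuilder &builder, mlir::Type resultType,
               mlir::ValueRange inputs,
               mlir::Location loc) -> std::optional<mlir::Value> {
          return builder
              .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
              .getResult(0);
        };
        addSourceMaterialization(addUnrealizedCast);
        addTargetMaterialization(addUnrealizedCast);
      }
    };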

Now I want to convert my array-allocating operations, and to achieve this I map the mydialect.alloc operation to memref’s alloc. A new unrealized cast is introduced automatically, because the old result is still used by the already existing cast (%1 in the previous IR).

%0 = memref.alloc() : memref<3xi64>
%1 = builtin.unrealized_conversion_cast %0 : memref<3xi64> to !mydialect.array<3xi64>
%2 = builtin.unrealized_conversion_cast %1 : !mydialect.array<3xi64> to memref<3xi64>
%3 = memref.alloc() : memref<6xi64>
%4 = builtin.unrealized_conversion_cast %3 : memref<6xi64> to !mydialect.array<6xi64>
%5 = builtin.unrealized_conversion_cast %4 : !mydialect.array<6xi64> to memref<6xi64>
// Code to populate %5
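
The pattern doing this conversion is, roughly, the usual OpConversionPattern; a sketch (mydialect::AllocOp is a placeholder for my actual op). Note that the cast back to !mydialect.array is not created by the pattern itself: the conversion driver materializes it because the old result still has a user expecting the !mydialect.array type, namely the cast left behind by the previous pass.

    struct AllocOpLowering : public mlir::OpConversionPattern<mydialect::AllocOp> {
      using OpConversionPattern::OpConversionPattern;

      mlir::LogicalResult
      matchAndRewrite(mydialect::AllocOp op, OpAdaptor adaptor,
                      mlir::ConversionPatternRewriter &rewriter) const override {
        // !mydialect.array<NxT> -> memref<NxT>, according to the pass converter.
        auto memrefType = mlir::dyn_cast_or_null<mlir::MemRefType>(
            getTypeConverter()->convertType(op.getResult().getType()));
        if (!memrefType)
          return mlir::failure();
        rewriter.replaceOpWithNewOp<mlir::memref::AllocOp>(op, memrefType);
        return mlir::success();
      }
    };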

Finally, I convert the memrefs into LLVM structs:

%0 = llvm.mlir.undef : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// Code to populate %0. Skipping as not useful.
%1 = builtin.unrealized_conversion_cast %0 : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> to memref<3xi64>
%2 = builtin.unrealized_conversion_cast %1 : memref<3xi64> to !mydialect.array<3xi64>
%3 = builtin.unrealized_conversion_cast %2 : !mydialect.array<3xi64> to memref<3xi64>
%4 = llvm.mlir.undef : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
%5 = builtin.unrealized_conversion_cast %4 : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> to memref<6xi64>
%6 = builtin.unrealized_conversion_cast %5 : memref<6xi64> to !mydialect.array<6xi64>
%7 = builtin.unrealized_conversion_cast %6 : !mydialect.array<6xi64> to memref<6xi64>
// Code to populate %7
// Somewhere in this code there will be unrealized casts converting %3 and %7 into LLVM structs, because the memref.load and memref.store operations have been converted, but the memref values they were using were the ones coming out of the unrealized casts.

So, in the end, I get a chain of casts that is semantically valid (struct → memref → mydialect.array → memref → struct) but can’t be folded, because the cast reconciliation pass only considers them in pairs.
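
For reference, my understanding of the pairwise reconciliation is roughly the following; this is a simplified sketch of the idea, not the actual in-tree implementation:

    struct FoldCastOfCast
        : public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
      using OpRewritePattern::OpRewritePattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::UnrealizedConversionCastOp castOp,
                      mlir::PatternRewriter &rewriter) const override {
        // Match %c = cast %b : B to A  where  %b = cast %a : A to B,
        // and replace %c with %a. Anything that does not expose such a
        // directly mirrored producer/consumer pair is left alone.
        if (castOp.getInputs().size() != 1)
          return mlir::failure();
        auto producer = castOp.getInputs()
                            .front()
                            .getDefiningOp<mlir::UnrealizedConversionCastOp>();
        if (!producer || !llvm::equal(producer.getInputs().getTypes(),
                                      castOp.getOutputs().getTypes()))
          return mlir::failure();
        rewriter.replaceOp(castOp, producer.getInputs());
        return mlir::success();
      }
    };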

Some possible objections:

  1. Q: The mydialect.process operation should not generate the second mydialect.alloc while being converted.
    A: Why not? The semantics of the operation are to create a new array containing the computed values, and the way to create such an array is provided by the mydialect.alloc operation.
  2. Q: Why are you not converting the mydialect.alloc operation together with the mydialect.process operation? This way the unrealized casts would not be introduced and you would be fine!
    A: Because keeping its conversion separate allows me to implement different lowering strategies, without having to copy-paste the alloc-conversion logic into each pass. From how I understand the partial conversion infrastructure, this should be not only allowed but also encouraged.
  3. Q: You should run the cast reconciliation pass after the conversion to the memref dialect.
    A: The answer to this requires a modification of the previous example, so bear with me for two more minutes and please tell me if the explanation is not clear.

Suppose that the mydialect.process operation must be lowered straight to the LLVM dialect. This implies that the conversion pass is using a type converter able to obtain the LLVM representation of the mydialect.array type, which in our case consists of chaining the type conversions we have seen so far. In other words, the conversion of mydialect.array yields an LLVM struct.
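
Concretely, the type converter of such a pass would just delegate to the LLVM type converter, something along these lines (again a sketch, with mydialect::ArrayType and its accessors as placeholders):

    // Convert !mydialect.array directly to its final LLVM representation by
    // going through the corresponding memref type.
    void populateArrayToLLVMTypeConversion(mlir::LLVMTypeConverter &converter) {
      converter.addConversion([&](mydialect::ArrayType type) -> mlir::Type {
        auto memrefType =
            mlir::MemRefType::get(type.getShape(), type.getElementType());
        // Yields the memref descriptor struct, e.g.
        // !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>.
        return converter.convertType(memrefType);
      });
    }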

After the first conversion we get the following IR:

%0 = mydialect.alloc : !mydialect.array<3xi64>
%1 = builtin.unrealized_conversion_cast %0 : !mydialect.array<3xi64> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
%2 = mydialect.alloc : !mydialect.array<6xi64>
%3 = builtin.unrealized_conversion_cast %2 : !mydialect.array<6xi64> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// Code to populate %3 (i.e. an llvm.call to an external function)

After the allocs conversion, we obtain:

%0 = memref.alloc() : memref<3xi64>
%1 = builtin.unrealized_conversion_cast %0 : memref<3xi64> to !mydialect.array<3xi64>
%2 = builtin.unrealized_conversion_cast %1 : !mydialect.array<3xi64> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
%3 = memref.alloc() : memref<6xi64>
%4 = builtin.unrealized_conversion_cast %3 : memref<6xi64> to !mydialect.array<6xi64>
%5 = builtin.unrealized_conversion_cast %4 : !mydialect.array<6xi64> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// Code to populate %5 (i.e. an llvm.call to an external function)

And finally, after the memref conversion:

%0 = llvm.mlir.undef : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// Code to populate %0. Skipping as not useful.
%1 = builtin.unrealized_conversion_cast %0 : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> to memref<3xi64>
%2 = builtin.unrealized_conversion_cast %1 : memref<3xi64> to !mydialect.array<3xi64>
%3 = builtin.unrealized_conversion_cast %2 : !mydialect.array<3xi64> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
%4 = llvm.mlir.undef : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
%5 = builtin.unrealized_conversion_cast %4 : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> to memref<6xi64>
%6 = builtin.unrealized_conversion_cast %5 : memref<6xi64> to !mydialect.array<6xi64>
%7 = builtin.unrealized_conversion_cast %6 : !mydialect.array<6xi64> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// Code to populate %7 (i.e. an llvm.call to an external function)

You can see how there is no way an intra-pipeline cast reconciliation pass could help: at no intermediate point do mirrored pairs of casts appear. Still, the chain of casts is valid (struct → memref → mydialect.array → struct).

  4. Q: Related to (3): you should not go straight to the LLVM dialect.
    A: Again, why not? Even though the conversion to LLVM is often seen as the final step, I don’t see any reason why some operations shouldn’t require it during their conversion, even before the conversion pipeline is finished (e.g. the conversion of an operation requires passing an opaque pointer to an external function, and I see no reason to populate my dialect with an opaque-pointer-like type).

Thanks for reading. I would really like to hear your thoughts about this problem, in order to understand whether the reconciliation pass itself should be modified to handle this situation.

Feel free to modify the pass.

As is often the case, nobody needed this before, so why have unused and more complex code when simpler, shorter code covers all relevant cases? Note that, on the conceptual level, I would advise against having a state of the IR with three different partially converted type systems simultaneously unless absolutely necessary. This adds a lot of complexity that I would have a hard time justifying with only “why not?”.

Feel free to modify the pass.

Sure, I’ve already started looking into it and I will notify you when the patch is ready.

As is often the case, nobody needed this before, so why have unused and more complex code when simpler, shorter code covers all relevant cases?

Yes sure, I agree on this. I was just wondering if there was some specific consideration behind limiting the folding to pairs of casts.

Note that, on the conceptual level, I would advise against having a state of the IR with three different partially converted type systems simultaneously unless absolutely necessary.

Well, I may be wrong about this, but I feel like this problem never arose because conversions within the MLIR codebase are designed (intentionally or not) to be “partial” with respect to the whole IR, but not with respect to a single dialect. To be honest, I know this is not entirely true, because if we look at the Math conversions we can go to llvm but also to libm at the same time; still, there is just one type system involved, the built-in one.

If you have another solution to this problem I would be happy to read about it, but in the last few days I’ve tried different designs and they always ended in chains of casts. The only solution I’ve found that works is to populate the “mydialect → llvm” conversion (which is the only one that can handle the process operation, as it needs llvm-specific things) with patterns from other passes (i.e. mydialect → memref and memref → llvm), but this goes against the whole point of having partial conversions imho, not to mention the need to make those additional patterns use a type converter that is not necessarily the same as the one of the “mydialect → llvm” pass (the “memref → llvm” conversion can reuse the same type converter, but still the code is quite cluttered).
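
To make the clutter concrete, the working-but-ugly version looks roughly like this (a sketch: the mydialect populate functions and converters are placeholders, and the upstream function names vary across MLIR versions):

    // One big "mydialect -> llvm" pass: every pattern set has to be pulled in,
    // and the mydialect -> memref patterns need their own type converter
    // (array -> memref), different from the LLVM one (array -> struct).
    void populateEverythingToLLVM(mlir::RewritePatternSet &patterns,
                                  mlir::LLVMTypeConverter &llvmConverter,
                                  MyDialectToMemRefTypeConverter &memrefConverter) {
      // mydialect patterns that can only go straight to LLVM (e.g. process).
      populateMyDialectToLLVMConversionPatterns(llvmConverter, patterns);
      // mydialect patterns that would normally live in their own pass.
      populateMyDialectToMemRefConversionPatterns(memrefConverter, patterns);
      // Upstream memref -> llvm patterns (exact name depends on the MLIR version).
      mlir::populateMemRefToLLVMConversionPatterns(llvmConverter, patterns);
    }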

It’s just that the separation of ops into dialects is based on some commonality criteria, like “ops that serve a similar purpose” or “ops that work on the same abstraction”, and “can be lowered together to another dialect” is a rather poor criterion to define the scope of a dialect. There is also often more than one way to convert, as with the math conversions you cite (math-to-libm, math-approximate). So the conversions are partial in the sense that they convert a subset of ops that are deemed illegal at the next stage and leave the other ops alone. This is also why, internally, we have a “full” conversion, in which only the ops explicitly listed are deemed legal.

I think the real conversion boundary is the type system in some rather loose definition because it is not necessarily defined in a single dialect. There are several occurrences of such boundaries:

  • set-of-dialects-formerly-known-as-standard to the LLVM dialect (builtin buffer-oriented type system to LLVM+builtin primitives type system);
  • set-of-dialects-formerly-known-as-standard to SPIR-V (builtin buffer-oriented type system to SPIR-V type system + some builtins);
  • tensor dialects (frontend dialects, linalg-on-tensors) to buffer dialects (linalg-on-buffers plus loops) aka bufferization (tensors to memrefs).

My approach would be to finish one of these before starting the other. The infra may support a more complex case, but I find it just saner from a complexity-management perspective.

FWIW, our flows in both IREE and XLA each have a big “convert everything we know about to LLVM” pass that does exactly this. This is why the populate* functions are exposed to the user in the first place. We still need individual passes for testing and debugging though; a bag of a thousand patterns applied by a benefit-based algorithm is largely undebuggable when something goes wrong.

Yes, I agree with that. What I meant to say is that it’s not about how dialects are defined, but rather how conversions are defined. As we both stated, there are conversions that deal with just a subset of a dialect, but such conversions often deal with the same types (considering the math conversions, i64 values are there both before and after the conversion) and thus the chained casts do not appear. From my understanding this is one of the benefits of having a shared built-in type system (even though I think the memref and tensor stuff should not belong to it, but that’s a personal opinion and out of the scope of this topic).

I’m not sure I understand how this intersects with the problem I’m facing. Are you saying that the conversion of the mydialect.process op (which, again, can be converted only when lowering to llvm) should not generate a new mydialect.alloc op? Even if we imagine that process did not require such an allocation, its operand (%0 in the very first code snippet) would still be subject to a chained cast: the “mydialect → memref” conversion introduces the mydialect.array -> memref cast, while the “memref → llvm” conversion introduces the memref -> struct cast.
The only way to avoid this would be to convert, in a single pass, all the operations returning or using values of mydialect.array type (and this is what I was saying at the beginning of this post). If the mydialect type system disappears completely within a single pass, then no cast is inserted. At that point I would be in the first of the three options you listed, but this implies that the work done up to that point can’t leverage the possibility of converting only a subset of a dialect.

This is very useful to know; if you could give some links, I would surely look into them.

In my example, the mydialect -> memref pass would exist just for debugging. Its patterns would be included in the mydialect -> llvm conversion, together with the memref -> llvm ones.
In doing so, the patterns of the mydialect -> memref conversion would require an independent type converter, one that converts mydialect.array into memref (while the mydialect -> llvm patterns would convert it to a struct).
The same would have to be done for other types entering the pipeline: for example, I may have an earlier conversion pass for mydialect which just handles its math operations and, in doing so, introduces some memref -> vector conversions. This can get even worse if I then insert some other fancy stuff such as the vector -> gpu pass.
I would be happy to be wrong, but to be honest I don’t think this is the path I want to take. The mydialect -> llvm conversion would become huge in terms of loaded patterns. In some sense I already faced this when unrealized casts were not yet widespread in the codebase (we even had this exchange, even though I knew only a little at that time (not that I know a lot now :joy:)), and I’m glad MLIR moved towards the unrealized cast usage and reconciliation.

What I am saying is that I would have considered the entire lowering pipeline holistically here, to identify “check points” with a well-defined and small subset of dialects/types in use. How you achieve that is a separate issue and there are multiple possibilities. You seem to have ruled out some of them by construction, based on a “why-not” argument. The answer to this “why not” is “because it makes reasoning and lowering harder”. Some of the well-established core dialects exist mainly because they make lowering simpler and more progressive. For example, we don’t do anything very useful on SCF that we couldn’t have done elsewhere; one of the main reasons for it to exist is to factor out the complexity of going from regions to blocks. So introducing more abstractions, such as the opaque pointers that you have ruled out, is a reasonable solution for me.

The premise that the conversion of mydialect.process emits a combination of mydialect.alloc and LLVM dialect ops smells like a missing abstraction. I would have done it in two stages; jumping abstraction gaps that are too wide often leads to problems in the longer term. I put it in the same bucket as trying to do bufferization and vectorization at the same time: it’s not impossible, but why torture oneself with avoidable complexity?

Another possibility is to remember that there is no requirement for one call to the conversion infrastructure to be equivalent to a pass. You may call it multiple times and even have other IR modifications in between these calls. So having some cleanup / reconciliation / folding of cast(cast) is perfectly fine there.
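
Something along these lines, as a rough sketch (the mydialect converters, populate functions and ops are placeholders; upstream names vary across versions):

    void MyLoweringPass::runOnOperation() {
      mlir::ModuleOp module = getOperation();

      // Stage 1: mydialect allocations -> memref.
      {
        MyDialectToMemRefTypeConverter converter;
        mlir::RewritePatternSet patterns(&getContext());
        populateMyDialectToMemRefConversionPatterns(converter, patterns);
        mlir::ConversionTarget target(getContext());
        target.addLegalDialect<mlir::memref::MemRefDialect>();
        target.addIllegalOp<mydialect::AllocOp>();
        if (mlir::failed(
                mlir::applyPartialConversion(module, target, std::move(patterns))))
          return signalPassFailure();
      }

      // Any cleanup / reconciliation / folding of cast(cast) can happen here,
      // before the next conversion is applied.

      // Stage 2: the rest of mydialect and memref -> llvm.
      {
        mlir::LLVMTypeConverter converter(&getContext());
        mlir::RewritePatternSet patterns(&getContext());
        populateMyDialectToLLVMConversionPatterns(converter, patterns);
        mlir::populateMemRefToLLVMConversionPatterns(converter, patterns);
        mlir::LLVMConversionTarget target(getContext());
        if (mlir::failed(
                mlir::applyPartialConversion(module, target, std::move(patterns))))
          return signalPassFailure();
      }
    }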

IREE
TF KernelGen

This increasingly sounds like a bad idea complexity-wise. Trading this off for a couple of extra IR ops/types is a good deal IMO.

On the other hand, this would result in the introduction of a pointer type into mydialect, which models a language that does not deal with pointers. I know that a dialect is not necessarily a 1-to-1 mapping of the original language, but I think that introducing types and operations just for support reasons can easily get out of hand. However, I understand what you mean, and I don’t exclude that such operations will be introduced anyway at some point, for the reasons you explained.

At some moments I was thinking the same. However, imagine that I introduce a bufferization pass in which the conversion of mydialect.process creates the mydialect.alloc operation for its result: what should I then convert the mydialect.process operation to? The only option I can think of is to convert it to another operation, mydialect.process_no_result, that only takes input values (among which there would be the result array) and produces no results, but that seems like a pointless duplication.

This, however, does not prevent the chains in the second part of the example (which is actually my real scenario).

Maybe I didn’t explain my view correctly, but in the second link you posted there is this:

    type_converter.addConversion([&](tf_framework::OpKernelContextType type) {
      return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
    });
    type_converter.addConversion([&](tf_framework::JITCallableType type) {
      return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
    });

This is still acceptable, because all the other types are within the built-in type system and already handled by the LLVMTypeConverter. But in general this is not the case when dealing with out-of-tree dialects, and, as you are saying, it can get more complex.

In the meantime, I have uploaded the patch for the reconciliation pass: ⚙ D130711 [MLIR] Reconciliation of chains of unrealized casts

Nothing requires this type to be in the same dialect though: you can have a mydialect_on_pointers (like PDL and PDLInterp), or you can have a subset of operations that are not intended for frontend users (like Async). The separation between dialects is just sugaring from the infrastructure’s point of view.

Ultimately, this is a design exercise and you may have different design criteria and complexity tolerance than I do. To some extent, one can think of most “intermediate” dialects in MLIR as being there for support reasons: some support optimizing transformations (e.g., Linalg), others just factor out common pieces (e.g., GPU).