Partial lowering with type conversions

Hi all,

I’m working on some experiments where I need to lower from dialect A to dialect B, and am using the (very nice) MLIR conversion framework. The ops in dialect A use a dialect-specific type system, while the ops in dialect B use standard dialect types. There is an “N to 1” mapping here, so things are pretty simple.

Because I have a partial lowering though, I need casts to be inserted as part of the result. For example if I have:

  %x = .... ; dialectA<int32>
  %y = dialectA.add %x, %other: (dialectA<int32>, dialectA<int32>) -> dialectA<int32>
  dialectA.use %y : dialectA<int32>

Then when I lower this I need to get something like this:

  %x = .... ; dialectA<int32>
  %tmp1 = dialectA.cast %x : (dialectA<int32>) -> i32
  %tmp2 = dialectA.cast %other : (dialectA<int32>) -> i32

  %y = std.add %tmp1, %tmp2: i32

  %tmp3 = dialectA.cast %y : (i32) -> dialectA<int32>
  dialectA.use %tmp3 : dialectA<int32>

I thought that this would be handled by implementing the type converter infra, but it doesn’t seem to get used in this case.

What is the best way to handle this? If I just implement simple operation lowering hooks, I get invalid types, producing something like this:

  %x = .... ; dialectA<int32>
  %y = std.add %x, %other : (dialectA<int32>, dialectA<int32>) -> i32  ; invalid!
  dialectA.use %y : i32

Which doesn’t work so well :)

Is the right thing to proactively generate casts in each of the operation lowering implementations? That seems like it would work, but it would be pretty wasteful in terms of compile time when lowering large blocks of code.

-Chris


The “Extending type conversion infrastructure” thread touches on some of these issues, but it does not get to the level of automatically materializing casts around arbitrary ops.

As River says, the type conversion infra is much less advanced than the op conversion infra, and it would be good to extend it.

I’ve seen this pattern, where we could synthesize bidirectional casts along a conversion edge, come up enough times that I wish we had a good way to do it. Imo, the current approach, which does direct conversion by way of invalid IR that must legalize within the context of a single apply*Conversion, may be good for some things, but for a lot of the cases I run across, it would work better with my brain (and hopefully assert less) to have a mode where we convert through legal IR by way of casts.

I don’t want to minimize what is there: these are hard problems, and I’ve been using/experimenting with it for months to try to settle in my mind what a better approach would look like at the implementation level.

With what is there today, if you implement a dialectAToStd type converter, it will take care of the std.adds, I believe, but not the remaining dialectA.use. Even that, though, is fiddly to get right, and we spend a lot of time debugging where we got it wrong.
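For reference, the kind of type converter I mean is tiny. A minimal sketch, assuming a hypothetical DialectAIntegerType with a getWidth() accessor (both names made up for illustration):

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/Transforms/DialectConversion.h"

  // Sketch of a dialectA-to-std type converter. DialectAIntegerType is a
  // stand-in for whatever wrapper type dialectA actually defines.
  struct DialectAToStdTypeConverter : public mlir::TypeConverter {
    DialectAToStdTypeConverter() {
      // Fallback: treat every other type as already legal.
      addConversion([](mlir::Type type) { return type; });
      // dialectA<intN> -> builtin iN; registered last, so tried first.
      addConversion([](DialectAIntegerType type) -> mlir::Type {
        return mlir::IntegerType::get(type.getContext(), type.getWidth());
      });
    }
  };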

I feel like the tooling could be taught to materialize casts in cases like this, while still doing direct type conversion in the cases that naturally resolve to be legal (i.e. dialectA.add fully converts to an op that is legal for the converted types). If that could be worked out, it would maybe split the difference between the overhead of materializing casts for everything vs. inserting them as an exception when needed? No idea how to implement that, though.

Yeah, this aspect of the type conversion infra is one that I run into a lot.

The key issue that I find hard to deal with using the current type conversion infrastructure is that it does not create known-valid-via-local-reasoning IR. That is, patterns can end up randomly creating invalid IR depending on the setup of the type converter and other patterns.

From having implemented this by hand a couple times, my general feeling is that there should be two phases:

  1. An “expansion” phase that applies local transformations, always ensuring that it “casts” to/from the original program values (a sketch follows after this list)
    a. body values
    b. function signatures/returns (block signatures?) (this is a logically different problem along ABI boundaries)
  2. A “resolution” phase that globally eliminates the original program values, and emits an error if that is not possible.
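To make phase 1 concrete, here is roughly what an expansion pattern could look like. This is only a sketch: dialectA::AddOp, dialectA::CastOp, and their accessors are hypothetical, and I’m using the std dialect’s AddIOp for the legal op.

  // "Expansion" rewrite: do the local transformation, but cast to/from the
  // original program values so the surrounding IR stays valid throughout.
  struct ExpandAddOp : public mlir::OpRewritePattern<dialectA::AddOp> {
    using OpRewritePattern::OpRewritePattern;

    mlir::LogicalResult
    matchAndRewrite(dialectA::AddOp op,
                    mlir::PatternRewriter &rewriter) const override {
      mlir::Location loc = op.getLoc();
      mlir::Type i32 = rewriter.getI32Type();
      // Cast the operands out of the dialectA type system.
      mlir::Value lhs =
          rewriter.create<dialectA::CastOp>(loc, i32, op.getLhs());
      mlir::Value rhs =
          rewriter.create<dialectA::CastOp>(loc, i32, op.getRhs());
      // Compute on the legal type.
      mlir::Value sum = rewriter.create<mlir::AddIOp>(loc, lhs, rhs);
      // Cast back, so every original use still sees dialectA<int32>.
      rewriter.replaceOpWithNewOp<dialectA::CastOp>(op, op.getType(), sum);
      return mlir::success();
    }
  };

The resolution phase then only has to reason about the cast ops, not about every rewritten op.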

I’m hoping to share some thoughts about this at the ODM this week (along with a bunch of other feedback about the MLIR batteries that I’ve learned from working on npcomp).

Some spoilers:

  1. How I lower tensor to memref in npcomp
    a. The pass pipeline description describing the invariants: the intermediate stage, with a bunch of materialized “casts” (it’s actually a bit more complicated than a simple cast op), actually has useful invariants, and I do a pass (createResolveShapeOfOpsPass) on that form.
    b. The LowerToHybridTensorMemRef pipeline. I’m fairly happy with the interface of this pipeline (described in the pass pipeline description), but the implementation needs some more infra to get right. It’s currently very placeholder-y; just enough for elementwise adds with broadcasting.
    c. (Maybe I’m old school, but I tend to like separate passes/pipelines more than uber-conversions; they are easier to debug.)
  2. Lowering shapes

Hi Chris,

this question has popped up several times. So far, we have managed to work around the problem. The first time was around my first implementation of dialect conversion for std->llvm (which was the only one that needed type changes). We considered having and inserting an explicit cast operation, but decided against it. The prevalent view at that point was that we would throw all patterns into one big A->B->C lowering machine (which did not exist back then) and it would figure out the conversions without needing the casts. Now we know it is impractical for a variety of reasons, from debugging to “pass” ordering to being able to assume IR validity in patterns.

Currently, I think the most general way to achieve partial lowering with type conversions is to wrap each A->B dialect op conversion into A->B and B->A casts on both sides, and run a canonicalizer that removes A->B->A and B->A->B sequences as a cleanup. This has an obvious downside of requiring one cast per operand and per result, so the IR will become ~5x larger.
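The cleanup itself is a simple pattern; a sketch, with dialectA::CastOp standing in for whatever cast op the conversion defines:

  // Fold a round-trip cast pair: cast(cast(x : A -> B) : B -> A) == x.
  // This removes the A->B->A and B->A->B sequences mentioned above.
  struct FoldRoundTripCast : public mlir::OpRewritePattern<dialectA::CastOp> {
    using OpRewritePattern::OpRewritePattern;

    mlir::LogicalResult
    matchAndRewrite(dialectA::CastOp op,
                    mlir::PatternRewriter &rewriter) const override {
      auto producer = op.getOperand().getDefiningOp<dialectA::CastOp>();
      // Only fold if the inner cast started from the type we end up with.
      if (!producer || producer.getOperand().getType() != op.getType())
        return mlir::failure();
      rewriter.replaceOp(op, producer.getOperand());
      return mlir::success();
    }
  };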

I suppose it should be possible to insert cast operations only for block arguments, using TypeConverter, and for operands with mismatching types. Because ops defining the operands must be converted before their uses, it is possible to check for the type mismatch in the conversion of the user op. It feels like this should be left for each pattern to decide, because the infrastructure doesn’t know which types the rewritten op supports (one may even have different flows in the pattern based on the operand type, like we do for calling conventions in std->llvm now). This could be made simpler by having a special kind of ConversionPattern that calls into TypeConverter::materializeConversion. There is still an open question about users of the converted op that might expect different types but don’t get converted. For those, we may need a separate “default” pattern that materializes B->A conversions, which TypeConverter::materializeConversion doesn’t expect today. Then again, some “unconverted” operations may just support the new type without the casts, so the B->A behavior also needs some configuration. We also cannot directly leverage the verifier subsystem, because verifier failures do not necessarily indicate type mismatches.
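A sketch of what such a pattern could look like, with dialectA::UseOp and dialectB::UseOp as hypothetical ops, and the materialization call written in the style of the current hooks (the exact entry point and the pattern signature have both changed over time):

  struct UseOpLowering : public mlir::OpConversionPattern<dialectA::UseOp> {
    using OpConversionPattern::OpConversionPattern;

    mlir::LogicalResult
    matchAndRewrite(dialectA::UseOp op, llvm::ArrayRef<mlir::Value> operands,
                    mlir::ConversionPatternRewriter &rewriter) const override {
      llvm::SmallVector<mlir::Value, 4> fixed;
      for (mlir::Value operand : operands) {
        mlir::Type expected =
            getTypeConverter()->convertType(operand.getType());
        if (!expected)
          return mlir::failure();
        // Materialize a cast only when the type actually mismatches.
        if (operand.getType() != expected)
          operand = getTypeConverter()->materializeTargetConversion(
              rewriter, op.getLoc(), expected, operand);
        fixed.push_back(operand);
      }
      rewriter.replaceOpWithNewOp<dialectB::UseOp>(op, fixed);
      return mlir::success();
    }
  };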

In any case, this mechanism should be opt-in to avoid requiring all conversions to define a “cast” operation.

Unfortunately, the “resolution” step of removing “A->B->A and B->A->B sequences” is quite complicated in the presence of control flow and context-dependent type conversions. I’ve had to implement it multiple times at this point, and some form of global information is always needed. Some code links:

  • for lowering tf dialect ops that operate on tensor lists to tf_tensorlist dialect (code link, this is the main comment that describes the algorithm)
    • The main difficulty is the fact that in the tf dialect, tensorlists are represented as tensor<tf.variant>, but not all variants are tensorlists. So it’s not obvious which Values (such as block arguments, function arguments) actually need to be converted. I do a global resolution phase which adjusts block arguments and handles tf.Cast ops that cannot handle the new type but become identities after conversion (TBD: handling other control flow constructs like scf, and function arguments).
  • lowering !shape.shape to IREE’s !shapex.ranked_shape (code link to the main comment describing the core issue). The basic issue is that !shape.shape turns into !shapex.ranked_shape<…>, where the contents of the ... cannot in general be easily inferred separately from the conversion itself without duplicating substantial parts of the conversion logic (but maybe we need to just bite the bullet and do that).
  • the 2 cases I mentioned above in npcomp.

While writing the tf to tf_tensorlist conversion and the shape->shapex lowering, I always felt like I would have loved to have the conversion infra be more worklist-driven. That is, I would provide a set of root places (such as certain known ops) that need to be converted; those would be converted, the legality of the remaining program would be evaluated, and further conversions would be applied from there, naturally propagating the new types only where they were needed. There are issues with that approach related to handling control flow (e.g. how to handle multiple predecessors, especially backedges: if one predecessor converts to tensor<?xf32> and the other to tensor<7xf32>, then the block argument type needs to be chosen to be compatible with both, with appropriate casts inserted).

I’d be super happy to brainstorm on this topic with folks! Feel free to PM me if you’d like to video chat.

I also ran into this issue and have been using a similar approach as a workaround. One awkward thing about it (probably because I am using a full conversion instead of a partial conversion, or maybe because my input is a mix of A and B dialects) is that, in order to have an op conversion that inserts casts for “A->B”, I somehow need to mark the “B” in “A->B” as illegal (although conceptually “B” is the legal op we want to convert into), and also mark the “B” in “A->B” after casts as legal…
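Concretely, the awkward part of the setup looks something like this (all names hypothetical):

  // "B" (dialectB::AddOp here) is illegal while any operand still carries a
  // dialectA type, and only becomes legal once the A->B casts are in place,
  // even though conceptually it is the op we are converting into.
  void configureTarget(mlir::ConversionTarget &target) {
    target.addDynamicallyLegalOp<dialectB::AddOp>([](dialectB::AddOp op) {
      return llvm::none_of(op->getOperandTypes(), [](mlir::Type t) {
        return t.isa<DialectAIntegerType>();
      });
    });
  }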

Yes, we have a number of such awkward dynamically legal checks too for ops based on the types we know we are converting from/to.

I’ve thought about adding some kind of helper to add a default pattern/legality for such things, but I’ve always stopped short because it feels like that might just be digging the hole deeper and there may be a more principled solution.

It feels like context-dependency is the problem here, more than partial lowering. I initially wrote TypeConverter specifically for cases where one doesn’t have an Op as context, i.e. Block arguments. This was before we had first-class regions. Maybe we should reconsider and, e.g., say that the parent op of the region that contains the blocks is in charge of converting the block arguments to get some context.

We currently maintain the invariant that defs are converted before uses, and patterns actually rely on that, sometimes implicitly by ::build functions assuming operand Values have specific types. We could certainly collect upward slices of the “root” operations and still convert them in order, but it’s unclear how it would connect to A->B->C conversion. Maybe you just want a different conversion driver that does this, but doesn’t allow A->B->C.

I was thinking that the A->B pattern would insert the relevant casts?

Another thing I considered, but don’t see how to implement without intrusive changes, is a type reconciliation hook that would follow use-def chains after the conversion and insert casts where necessary. This hook would need to somehow query the ops as to whether they accept the given type, and I don’t see how to do that.

Sorry I’m just getting back to this; here is a sketch of an approach. In the text below, I assume that a conversion is happening that lowers a type SrcType to DstType.

Observation: the conversion framework knows about operation conversions, and it also has the ability to register casts using the newly added addMaterialization hook. It is currently reasonable for patterns to detect that their input is a SrcType (e.g. maybe it came from a block argument) and insert a cast themselves, but it isn’t very reasonable for a pattern rewrite to go look at all the uses of its results: as folks point out above, you can defensively insert all the casts, but that is inefficient and painful. The framework should do better.
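For reference, registering those casts looks something like this. This is a sketch using the dialectA.cast op from my earlier example; I’m writing it against the current shape of the hook (which has since been split into source/target variants), so treat the exact names as approximate.

  // Teach the TypeConverter to materialize dialectA.cast in both directions.
  void registerCastMaterializations(mlir::TypeConverter &converter) {
    auto buildCast = [](mlir::OpBuilder &builder, mlir::Type resultType,
                        mlir::ValueRange inputs,
                        mlir::Location loc) -> mlir::Value {
      assert(inputs.size() == 1 && "expected a 1:1 cast");
      return builder.create<dialectA::CastOp>(loc, resultType, inputs[0]);
    };
    converter.addSourceMaterialization(buildCast); // DstType -> SrcType
    converter.addTargetMaterialization(buildCast); // SrcType -> DstType
  }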

Focusing on just the result side of this problem, I think the type converter framework could impose a simple policy, and take a simple approach on this:

  1. In its forward pass applying pattern rewrites, it can keep track of when the result of a pattern has a different type than the source, e.g. when a result changed from SrcType to DstType. In this case, before doing the RAUW, it could add the using operations to an “operand changed type” set (or map).

  2. When applying other patterns, it can check to see if an operation is in that set. If so, and if it gets rewritten/replaced, then it gets removed from the set.

  3. After all rewrites have converged, look at the set. For each operation that had an operand that changed type but did not get rewritten, insert a cast before the user operation. A more sophisticated approach would be to try to insert fewer casts (e.g. by looking for a nearest common dominator of the uses), but even the simplest approach would be progress here; a rough sketch follows below.
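Here is a pseudocode sketch of the bookkeeping in those three steps. None of this is existing conversion-framework API; insertCastsBefore is a hypothetical helper that asks the type converter to materialize the needed casts.

  llvm::DenseSet<mlir::Operation *> operandChangedType;

  // (1) Before the RAUW of a replacement whose types changed, record the
  //     users of the old results.
  void noteReplacement(mlir::Operation *op, mlir::ValueRange newValues) {
    for (auto [oldResult, newValue] :
         llvm::zip(op->getResults(), newValues))
      if (oldResult.getType() != newValue.getType())
        for (mlir::Operation *user : oldResult.getUsers())
          operandChangedType.insert(user);
  }

  // (2) A user that is itself rewritten/replaced no longer needs repair.
  void noteRewritten(mlir::Operation *op) { operandChangedType.erase(op); }

  // (3) After convergence, insert casts only for the surviving users.
  void reconcile(mlir::TypeConverter &converter) {
    for (mlir::Operation *op : operandChangedType)
      insertCastsBefore(op, converter); // hypothetical helper
  }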

I think this would have some nice properties - it means that full conversions would never get casts. It means that operands would never implicitly change type (which is generally a very dubious thing to do!). It means that we’d get a lot fewer casts than the “cast everywhere” approach.

What do people think about this general approach? I’m not familiar with the internals of the conversion framework. It also isn’t clear how to handle cases where you’re doing a 1->N type legalization, but I’m not sure that can reasonably be handled…

-Chris

Forgot to post, but I started fixing this yesterday.

This is more or less what I had in mind.

IMO, we should never insert casts unless provably necessary.

I am ;)

1->N is a non-problem for now, because it isn’t possible to do without casts in current pattern rewrites anyway. The TypeConverter supports it, but there is no way in a call to replaceOp to have one value expanded into N. That would need to be changed first, but it isn’t really a type conversion problem.

– River

Oh wow, awesome, thank you!

Right - that makes total sense! Thanks River,

-Chris