[RFC] MLIR Dialect Canonicalization

Introduction

This is an RFC to change how we populate dialect canonicalization patterns in MLIR.

Currently, “dialect canonicalization patterns” are only those declared explicitly as dialect-level rewrites and added to the canonicalization pattern list (example). Operation canonicalizations are declared individually and are not part of that bundle.

Therefore, both the canonicalize transform op and the canonicalize pass loop over all dialects for their patterns, and then over all registered operations. Other passes that apply canonicalization patterns (like Tiling) have to register them by hand.

To make it easier to compose dialect canonicalization passes, we want a single call similar to dialect->getCanonicalizationPatterns() to be enough, without having to include *.cpp.inc files from unrelated dialects or pick individual operations to get the required patterns.

We should be able to register a single dialect’s patterns, and even register its dependencies, without needing calls like other_dialect::SomeOp::getCanonicalizationPatterns(). The knowledge of which ops canonicalize in a particular dialect should belong to that dialect and not to other dialects, passes or general code.

Motivation

In addition to avoiding the pitfalls described above (mainly static knowledge of operations or catch-all inclusion), this proposal will help the following patterns:

  1. Allow canonicalization transform dialect ops to name which dialects they want to apply on (ref). Similar behaviour for the pass, which could have internal options.
  2. Allow canonicalization of one dialect to register canonicalization of dependent dialects first (e.g. linalg depends on tensor, memref, arith, scf, etc), without hard-coding which dialects they are (i.e. take them from dependency lists).
  3. Allow pipelines to select different dialects to canonicalize at different stages, reducing the complexity of canonicalization by reducing the number of rewrites to (greedily) apply at any given time.

Proposal

There are three ways we can do that:

  1. Change the core Dialect::getCanonicalizationPatterns() function to add a boolean flag (default false) to also cover the operations’ patterns in the retrieval. This is done here.
  2. Same as (1), but with a ternary flag (get dialect patterns, operation patterns, or both).
  3. Add a new method Dialect::getOpCanonicalizationPatterns() that does the second part. Users can call both or either.
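The three options can be compared on a toy model. The sketch below is only an illustration: ToyDialect, PatternList and CanonKind are hypothetical stand-ins, not the real mlir::Dialect or mlir::RewritePatternSet, and patterns are represented as plain strings. Option (1) is the two-valued special case of the selector shown for option (2).

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a pattern set; real code would use mlir::RewritePatternSet.
using PatternList = std::vector<std::string>;

// Hypothetical selector corresponding to the ternary flag in option (2).
enum class CanonKind { DialectOnly, OpsOnly, Both };

struct ToyDialect {
  PatternList dialectPatterns; // rewrites declared on the dialect itself
  PatternList opPatterns;      // rewrites declared on its operations

  // Option (3): a separate method that only adds the operation patterns.
  void getOpCanonicalizationPatterns(PatternList &results) const {
    results.insert(results.end(), opPatterns.begin(), opPatterns.end());
  }

  // Option (2): one entry point with a selector. The default value keeps
  // the current behaviour (dialect patterns only).
  void getCanonicalizationPatterns(
      PatternList &results, CanonKind kind = CanonKind::DialectOnly) const {
    if (kind != CanonKind::DialectOnly)
      getOpCanonicalizationPatterns(results);
    if (kind != CanonKind::OpsOnly)
      results.insert(results.end(), dialectPatterns.begin(),
                     dialectPatterns.end());
  }
};

// Usage: existing callers are unaffected, new callers opt in.
inline void usageExample() {
  ToyDialect linalg{{"dialect-rewrite"}, {"op-rewrite-a", "op-rewrite-b"}};
  PatternList today, all;
  linalg.getCanonicalizationPatterns(today);                // unchanged default
  linalg.getCanonicalizationPatterns(all, CanonKind::Both); // ops, then dialect
  assert(today.size() == 1);
  assert(all.size() == 3 && all.back() == "dialect-rewrite");
}
```

With option (3) alone, the selector disappears and callers invoke both methods themselves, which is the "remember to call both" cost discussed further down.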

The main effects of this change are:

  • All dialects that have operation canonicalization patterns need to implement new methods, regardless of choice above. We can reduce the scope by making it opt-in and continue using the full scan on canonicalize pass/transform, but then we end up with two ways of doing the same thing.
  • TableGen needs to generate the new declaration for all dialects. This will affect downstream dialects, if they already have dialect canonicalization patterns. You can see that on my prototype above.
  • Downstream dialects with operation patterns will need to adapt. This change will force those dialects to implement the operation pattern retrieval if they use operation canonicalization AND the upstream canonicalization pass.

Change the API (1 and 2)

The main benefit is that this can become the default behaviour. When adding canonicalization patterns for a particular dialect, it’s reasonable to expect that the operations’ patterns will also be registered, especially because only a few dialects (e.g. Linalg and Tensor) have non-operation canonicalization patterns. Flipping the flag’s default value later is much easier than changing the API again.

The main problem is the change in API. Churn and all.

Adding a new function to the API (3)

The main benefit is much smaller churn and being easier on downstreams.

The main problem is that every time we need to take a dialect’s canonicalization patterns we need to remember to call both functions.

Process

I have listed all dialects that have a non-empty GET_OP_LIST in their include files, since those need to add their ops to the dialect canonicalization method. Some dialects (Linalg, ArmSME) have more than one Ops file, so all of them need to be included.

I tried to use Claude to find those for me, but it failed consistently and repeatedly, so I ended up using grep, sort, vim and good old copy & paste.

Also moved Linalg’s CanonicalizationPatternList to a common header, which simplifies the hand-picked operation pattern list. This is currently in Passes.h but there could be a better place to put it.

Considerations

Both the canonicalization pass and the transform op can change to use the new flag instead of iterating through all registered operations (I have done that in my prototype). This should have no effect, since the same operation and dialect patterns are being registered, only in a different order. Since both are currently greedy, the order should not affect much, if at all.

Some dialects (linalg, tensor) had pre-existing dialect canonicalization patterns, so I’ve added them after the operation ones, so that we know that the operations are in their canonical forms before the dialect ones run.

The builtin dialect has only two operations (module and unrealized_cast), neither of which has canonicalization patterns, so it can be ignored in this process.

Transform ops are declared per dialect, but they also don’t seem to have canonicalization, so I ignored them. Let me know if that wasn’t correct.

There could be an option to auto-generate the canonicalization pattern function from table-gen. That would probably need a new generator, since it needs to query the dialect canonicalization property, every operation canonicalization (across multiple TD files) and the dialect specific canonicalization passes (ie. linalg & tensor). This RFC does not include such change.

Prototype

References

@jpienaar @matthias-springer @ftynse @kuhar @Groverkss @nicolasvasilache @mehdi_amini @javedabsar @banach-space @asiemien @KFAF @krzysz00 @zero9178

If the canonical form of an op in dialect X is an op in dialect Y, it feels like Y is a dependent dialect of X. I worry a bit about this mechanism: this would require each pipeline to query and possibly load. But that could be a separate discussion from the proposal here, which is primarily about making it easy to compose rewrite sets using all canonicalization patterns of a dialect.

Nit: I’d say this is simplifying the creation of new passes that can utilize a subset of the canonicalization patterns. Rather than the canonicalization pass in general, it’s new passes that use canonicalization patterns (in particular, one could have a base class or a templated class to make these easy to instantiate, though reproducers may need some thought). The general case that is supported is still that every dialect in the context needs to be considered for canonicalization patterns.

In the case of option 3, it (populateOpCanonicalizationPatterns) can be autogenerated and so no manual work. The populateDialectCanonicalizationPatterns would be empty by default, and populateAllCanonicalizationPatterns would be on base class and just call both.

One can still have a helper function (that’s fixed on the base class) that calls both populate methods. (I also prefer naming these populate, as they aren’t getters: they populate a set, which matches how patterns and their populate methods are named elsewhere.) If we start with just generating the populate op ones, then it’s a non-breaking change that is just easier to consume elsewhere.
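A minimal sketch of that shape, using toy classes rather than the real mlir::Dialect: the three method names come from the text above, while ToyDialectBase, ToyArith, ToyLinalg and the pattern strings are hypothetical. In the real scheme the op hook would be generated, not written by hand, and the order in which the helper calls the two hooks is an arbitrary choice here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a pattern set; real code would use mlir::RewritePatternSet.
using PatternList = std::vector<std::string>;

class ToyDialectBase {
public:
  virtual ~ToyDialectBase() = default;

  // Would be auto-generated from the op definitions; empty in the base class.
  virtual void populateOpCanonicalizationPatterns(PatternList &) const {}

  // Hand-written, non-op patterns; empty by default.
  virtual void populateDialectCanonicalizationPatterns(PatternList &) const {}

  // Fixed helper on the base class: not virtual, it just calls both hooks.
  void populateAllCanonicalizationPatterns(PatternList &patterns) const {
    populateOpCanonicalizationPatterns(patterns);
    populateDialectCanonicalizationPatterns(patterns);
  }
};

// A dialect that only has op patterns overrides a single hook.
class ToyArith : public ToyDialectBase {
public:
  void populateOpCanonicalizationPatterns(PatternList &patterns) const override {
    patterns.push_back("arith:op-pattern");
  }
};

// A dialect with both kinds of patterns overrides both hooks.
class ToyLinalg : public ToyDialectBase {
public:
  void populateOpCanonicalizationPatterns(PatternList &patterns) const override {
    patterns.push_back("linalg:op-pattern");
  }
  void populateDialectCanonicalizationPatterns(
      PatternList &patterns) const override {
    patterns.push_back("linalg:dialect-pattern");
  }
};

// Usage: callers only ever need the helper.
inline void usageExample() {
  PatternList patterns;
  ToyArith().populateAllCanonicalizationPatterns(patterns);
  ToyLinalg().populateAllCanonicalizationPatterns(patterns);
  assert(patterns.size() == 3);
}
```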

Flipping a default silently and workloads failing confusingly I’d claim is worse. We can add the two new populate methods without changing existing APIs (with the dialect one calling the existing get one in the 3 cases it exists). Folks can then get off the get one that actually populates at their leisure, as it is just not considered a thing or used in core.

I’d just have ODS generate these. One can do the same as is done on the Python side and have a single file that includes the others; having a top-level file per dialect seems good to me. I think of this primarily for getting all the dialect’s op canonicalization patterns. For the non-op ones, they are IMHO in their own populate function and done as today for a start. One could indeed later add a TD class for such, but I think the common case today (as only 3 dialects have non-op ones) is the former, and this addresses that simply with little code needed.

The ordering they are run in (and whether all op ones have converged) is not determined just by how they are added, but also by IR ordering and walk order, as well as benefit. If folks want all op patterns to run to completion before dialect ones, one has to do additional work (e.g., two greedy rewriter applications). Which is another point towards having 2 populate methods if this is desired behavior for some flows.

Yes, separate discussion, but indeed, this is very pertinent. For now, nothing changes, as the canonicalize pass ignores the dependencies and just registers everything, one by one. Special passes (like the tiling one) just register what they need and nothing more.

But when we start loading dependent rewrites, we’ll have to define what we mean by dependency and how that will be represented. We’ll also have to “deduplicate” somehow if two deps have the same dep and both try to register it.
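One possible way to deduplicate, sketched on plain strings: walk the dependency lists depth-first and remember which dialects were already scheduled. Everything here is hypothetical (the DependencyMap type, the function names, and the example edges); how dependencies would actually be represented is exactly the open question above.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical dependency table: dialect name -> dialects it depends on.
using DependencyMap = std::map<std::string, std::vector<std::string>>;

// Depth-first walk that appends each dialect once, dependencies first.
inline void collectDialects(const std::string &name, const DependencyMap &deps,
                            std::set<std::string> &visited,
                            std::vector<std::string> &order) {
  if (!visited.insert(name).second)
    return; // already scheduled through another dependency (or a cycle)
  auto it = deps.find(name);
  if (it != deps.end())
    for (const std::string &dep : it->second)
      collectDialects(dep, deps, visited, order);
  order.push_back(name);
}

// Returns `root` plus everything it transitively depends on, without repeats.
inline std::vector<std::string>
dialectsToCanonicalize(const std::string &root, const DependencyMap &deps) {
  std::set<std::string> visited;
  std::vector<std::string> order;
  collectDialects(root, deps, visited, order);
  return order;
}

// Usage: arith is reachable three ways but is registered only once.
inline void usageExample() {
  DependencyMap deps = {{"linalg", {"tensor", "memref", "arith"}},
                        {"tensor", {"arith"}},
                        {"memref", {"arith"}}}; // illustrative edges only
  std::vector<std::string> order = dialectsToCanonicalize("linalg", deps);
  assert(order.size() == 4);
  assert(order.front() == "arith" && order.back() == "linalg");
}
```

Marking a dialect as visited before recursing also keeps the walk finite if two dialects list each other as dependencies.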

This is the key change in this (small) proposal, yes. We can ignore the general case for now, which remains the same.

That is not the proposal here, to be clear. The default will still be the same, and if we ever “flip” it, it will be after extensive research, testing, validation and benchmarking, both upstream and downstreams.

Right now, this is not that simple, nor a single case either. I want to get the general agreement first.

Fair, though, this was just a comment why I added them in that way, not a general design decision that we should propagate. I don’t think there’s a difference, just wanted to expose my thought process (in case I was totally wrong).

Thank you for putting this together. I appreciate the effort and the detailed write-up. That said, I’m currently struggling to understand the concrete problems this proposal is trying to solve and how the proposed changes address them. I’d need a bit more clarity before I can form an opinion.

Overall motivation

I’m having trouble following the introduction already. It would really help to start with a concrete example:

  • What is the current behavior?
  • What is the specific limitation or pain point?
  • What breaks or becomes difficult today?
  • What would the user experience look like after this change?

Right now, I’m not able to connect the high-level goals with the proposed API changes.


1. Dialects specifying which dialects to apply canonicalization on

Allow canonicalization transform dialect ops to name which dialects they want to apply on. Similar behaviour for the pass, which could have internal options.

What is the concrete limitation in the current APIs?

Today, we can already collect canonicalization patterns from specific dialects and operations. If I only want patterns from a single dialect, what exactly prevents me from doing that with the existing infrastructure?

It would help to see:

  • A minimal example that cannot be implemented cleanly today.
  • A short code snippet illustrating why current APIs are insufficient.

Without that, I’m not seeing why we need to change the core mechanism.


2. Registering canonicalization of dependent dialects first

Allow canonicalization of one dialect to register canonicalization of dependent dialects first (e.g. linalg depends on tensor, memref, arith, scf, etc), without hard-coding which dialects they are.

I’m not following what this achieves in practice.

  • What observable issue are we fixing?
  • What does “register canonicalization of dependent dialects first” change semantically?
  • What goes wrong today if we don’t do this?

If this is about ordering effects, determinism, or convergence properties, please spell that out explicitly. A small example showing incorrect or suboptimal behavior today would make this much clearer.


3. Selecting different dialects to canonicalize at different stages

Allow pipelines to select different dialects to canonicalize at different stages, reducing the complexity of canonicalization by reducing the number of rewrites to (greedily) pass at any given time.

I’m unclear on the goal here as well.

  • Is this about reducing compile time?
  • Improving convergence behavior?
  • Improving predictability?
  • Avoiding pathological rewrite interactions?

If the idea is to filter which patterns are applied, why is this not better expressed as a driver-level or pass-level change that filters operations or patterns, rather than modifying the core dialect API?

Again, a concrete example pipeline before/after would be very helpful.


Change to Dialect::getCanonicalizationPatterns()

You propose adding a boolean flag (default false) to also retrieve operations’ patterns.

I’m having trouble connecting this specific API change to the three motivation points above.

  • How does this flag enable (1), (2), or (3)?
  • What exact behavior changes when the flag is set?
  • Why is a boolean the right abstraction here?

At the moment, the proposal feels like it jumps from high-level goals to a specific API tweak without clearly showing the causal link.


What would help

To move this forward, it would be very helpful to see:

  1. A minimal, concrete example IR and pipeline that behaves suboptimally today.
  2. A short explanation of what you would like to happen instead.
  3. How the proposed changes make that possible.
  4. Why alternative approaches (e.g. driver-side filtering or custom passes) are insufficient.

I’m very open to improving canonicalization infrastructure if there are real limitations. I just need a clearer articulation of the problem and the design space before we change a core API.

Thanks again for taking the time to draft this — I’m sure I’m missing context here, so more concrete examples would go a long way.

Sorry, this is an LLM written response that I won’t engage. If you have an actual concern that wasn’t asked by an LLM, let me know.

I wrote my concerns by hand and just iterated with an AI to help with the form. If you prefer something straight to the point and very direct, I am happy to oblige, but historically this has not necessarily been better, so I attempted a different approach to engage.

Canonicalization has been pretty well anchored in the framework from the beginning, so any changes likely need to be very incremental and clearly motivated, with a direct connection from a clear problem to a solution.
Here, I’m very confused about the problems to solve and the overall motivation. I already cannot follow the explanation in the introduction, so it is difficult to form an opinion on this and the overall plan.

Point by point, to begin with:

  1. Allow canonicalization transform dialect ops to name which dialects they want to apply on (ref). Similar behaviour for the pass, which could have internal options.

What do we need to change to achieve this? Why can’t we already just get patterns from a single dialect and its operation with current APIs? I don’t connect this to the “proposal”.

  2. Allow canonicalization of one dialect to register canonicalization of dependent dialects first (e.g. linalg depends on tensor, memref, arith, scf, etc), without hard-coding which dialects they are (ie. take from dependency lists).

I don’t follow what this is trying to achieve?

  3. Allow pipelines to select different dialects to canonicalize at different stages, reducing the complexity of canonicalization by reducing the number of rewrites to (greedily) pass at any given time.

I’m again confused about what this is trying to achieve: are we talking about an attempt at compile-time reduction? Why isn’t this a driver-change proposal to filter operations to match?

I can’t connect adding this boolean to the description of the problems and the 1/2/3 “patterns” mentioned in the motivation. I’m likely not understanding what this is all about, so I’ll need something a bit more concrete (some examples maybe?). It’s unfortunate, but we start pretty far from having a shared understanding of the “problem”, which I’m expecting isn’t gonna be conducive to converging.

Currently, ctx->getRegisteredOperations() returns all registered operations. If I want to get only the operations of a single dialect, I need to scan the whole list and filter. For the current canonicalize pass, that’s ok, since it always registers all operation canonicalizations. But for other passes that need fewer canonicalization patterns, this has to be done “by hand”, like the tiling and elementwise fusion patterns.

As exposed by a few LLVM talks (Alex, Javed, refs in previous docs), the canonicalize pass is too large, too long and too broad. If you’re building a compiler with MLIR for a longer pipeline (ingress, transforms, egress dialects), then greedily running the canonicalization for all dialects multiple times in this pipeline is wasteful.

Being able to run just ingress dialect (ex. torch, tosa) canonicalization at ingress, transform dialects (ex. linalg, tensor, vector) in between and only egress dialects (ex. llvm) at the end reduces the complexity (and thus compile time) considerably.

But knowing which dialects I need to run at every stage is knowledge that is kept both at the pipeline level and in each dialect’s list of dependencies. So, instead of hard-coding particular lists at each stage, you can just call “canonicalize linalg” and “canonicalize llvm” and that would pull in the dependent dialects’ canonicalization patterns together.

This would allow the core infrastructure to evolve independent of the multiple compilers that use it without the need to synchronize changes.

This is a proposal to simplify building drivers to achieve compile-time reduction. The ultimate decision is in the driver, but some details belong to the dialects themselves.

The key here is that there is no current way to get the operation canonicalization patterns for a given dialect.

The boolean is just to avoid breaking change in the API. We can also add a new method to get the operations’ canonicalization patterns.

Today, to get linalg patterns, I have to do:

  auto linalg = context->getLoadedDialect<LinalgDialect>();
  linalg->getCanonicalizationPatterns(patterns);
  // This scans the whole list for every dialect, when adding multiple.
  // For M total operations and N selected dialects this is O(M*N)
  for (RegisteredOperationName op : context->getRegisteredOperations()) {
    if (&op.getDialect() == linalg) // or something similar
      op.getCanonicalizationPatterns(patterns, context);
  }

I wanted to do only:

  auto linalg = context->getLoadedDialect<LinalgDialect>();
  linalg->getCanonicalizationPatterns(patterns, /*getOps=*/true); // Change current API

or

  auto linalg = context->getLoadedDialect<LinalgDialect>();
  linalg->getCanonicalizationPatterns(patterns); // Preserve current API
  linalg->getOpCanonicalizationPatterns(patterns); // New method
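To make the cost argument concrete, here is a toy comparison that counts how many entries each approach inspects. It is a sketch under stated assumptions: ToyOp, OpsByDialect and both functions are hypothetical and use strings in place of MLIR types. The second function models a dialect that can hand over its own op list directly.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

struct ToyOp {
  std::string dialect;
  std::string name;
};

// Today's shape: filter the global list once per selected dialect.
// Inspects M * N entries for M registered ops and N selected dialects.
inline std::size_t collectByScanning(const std::vector<ToyOp> &registered,
                                     const std::vector<std::string> &selected,
                                     std::vector<std::string> &out) {
  std::size_t inspected = 0;
  for (const std::string &dialect : selected)
    for (const ToyOp &op : registered) {
      ++inspected;
      if (op.dialect == dialect)
        out.push_back(op.name);
    }
  return inspected;
}

// Proposed shape: each dialect knows its own ops, so only those are visited.
using OpsByDialect = std::map<std::string, std::vector<std::string>>;

inline std::size_t collectPerDialect(const OpsByDialect &ops,
                                     const std::vector<std::string> &selected,
                                     std::vector<std::string> &out) {
  std::size_t inspected = 0;
  for (const std::string &dialect : selected) {
    auto it = ops.find(dialect);
    if (it == ops.end())
      continue;
    for (const std::string &name : it->second) {
      ++inspected;
      out.push_back(name);
    }
  }
  return inspected;
}

// Usage: same result, different amount of work.
inline void usageExample() {
  std::vector<ToyOp> registered = {{"arith", "addi"},
                                   {"arith", "muli"},
                                   {"linalg", "generic"},
                                   {"scf", "for"},
                                   {"tensor", "cast"}};
  OpsByDialect byDialect;
  for (const ToyOp &op : registered)
    byDialect[op.dialect].push_back(op.name);

  std::vector<std::string> selected = {"linalg", "tensor"};
  std::vector<std::string> scanned, direct;
  std::size_t scanCost = collectByScanning(registered, selected, scanned);
  std::size_t directCost = collectPerDialect(byDialect, selected, direct);
  assert(scanCost == 10); // M = 5, N = 2
  assert(directCost == 2);
  assert(scanned == direct);
}
```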

@jpienaar proposed using table-gen for this, which would considerably reduce the amount of code in my branch. We can start with a new gen for the op canon today, and add one later for the dialect rewrites, since it will need some new fields in the table-gen dialect definition.

But before doing any of it, I wanted to gather the opinion of folks on what’s the best way forward.

The ability to easily obtain a per-dialect subset of canonicalization patterns sounds like a nice improvement. I’d like to better understand the practical trade-offs.

Do you have any numbers on the actual savings or gains when using a reduced pattern set? It would be great to understand the practical impact.

Dependent patterns look useful at first glance, but I wonder how much they help in practice. I suspect that if a sufficiently high-level dialect is selected, the dependency expansion could end up pulling in most patterns anyway. Do you have any insight into what the dependency tree looks like for the upstream dialects?

Also, have you considered a mechanism that aligns more directly with a concrete compiler pipeline and/or at a specific stage of it?

@ftynse’s presentation has more info on the trade-offs of applying all versus a minimal set:

Correct. It can reduce the savings substantially.

I wasn’t clear in my original post, but bringing the dependency patterns should be optional.

The key thing in this proposal is to just allow the dialects to register their own operation canonicalizations. Dependency tracking and API changes are all optional and mostly orthogonal at this point.

The proposal is orthogonal to actual compiler pipelines because everyone has their own. But the concrete examples upstream (tiling and fusion) are the kind of passes we need to allow to exist without having to manually select operations’ patterns.

These, or similar passes, exist in virtually every tensor compiler using MLIR from a high-level entry.