Introduction
This is an RFC to change how we populate dialect canonicalization patterns in MLIR.
Currently, “dialect canonicalization patterns” are only those declared explicitly as dialect-level rewrites and added to the canonicalization pattern list (example). Operation canonicalizations are declared individually and are not part of that bundle.
Therefore, both the canonicalize transform and the canonicalize pass loop over all dialects for their patterns, and then over all registered operations. Other passes that apply canonicalization patterns (like Tiling) have to register those by hand.
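To make the current behaviour concrete, here is a minimal sketch of the two-step scan using stand-in types. Everything in it (`Op`, `Dialect`, `collectAll`, patterns as plain strings) is hypothetical and only models the shape of the loop; it is not the MLIR API.

```cpp
#include <string>
#include <vector>

// Stand-in: a "pattern" is just a name here, not an mlir::RewritePattern.
using PatternList = std::vector<std::string>;

// Hypothetical model of an operation that declares its own patterns.
struct Op {
  std::string name;
  PatternList patterns;
};

// Hypothetical model of a dialect with dialect-level patterns and ops.
struct Dialect {
  std::string name;
  PatternList dialectPatterns;
  std::vector<Op> ops;
};

// Models the full scan: first every dialect's own patterns, then the
// patterns of every registered operation, regardless of its dialect.
PatternList collectAll(const std::vector<Dialect> &dialects) {
  PatternList result;
  for (const Dialect &d : dialects)
    result.insert(result.end(), d.dialectPatterns.begin(),
                  d.dialectPatterns.end());
  for (const Dialect &d : dialects)
    for (const Op &op : d.ops)
      result.insert(result.end(), op.patterns.begin(), op.patterns.end());
  return result;
}
```

Note that the caller has no way to ask for one dialect only: the scan is all-or-nothing, which is the limitation this RFC is about.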
To make it easier to compose dialect canonicalization passes, we want to allow a pattern similar to dialect->getCanonicalizationPatterns() without having to include *.cpp.inc files from unrelated dialects or pick individual operations to get the required patterns.
We should be able to register a single dialect’s patterns, and even register its dependencies, without needing calls like other_dialect::SomeOp::getCanonicalizationPatterns(). The decision of which ops canonicalize in a particular dialect should belong to that dialect and not to other dialects, passes or general code.
Motivation
In addition to avoiding the pitfalls described above (mainly static knowledge of operations or catch-all inclusion), this proposal will help the following patterns:
- Allow canonicalization transform dialect ops to name which dialects they want to apply to (ref). Similar behaviour for the pass, which could have internal options.
- Allow canonicalization of one dialect to register canonicalization of dependent dialects first (e.g. linalg depends on tensor, memref, arith, scf, etc.), without hard-coding which dialects they are (i.e. take them from dependency lists).
- Allow pipelines to select different dialects to canonicalize at different stages, reducing the complexity of canonicalization by reducing the number of rewrites to (greedily) apply at any given time.
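The dependency-first registration in the second point can be sketched as a small recursive walk. This is only an illustration under assumptions: `DialectInfo`, `Registry`, the `dependencies` field and `collectWithDependencies` are hypothetical names, and the sketch assumes a dialect's dependency list can be queried by name.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Stand-in: a "pattern" is just a name here, not an mlir::RewritePattern.
using PatternList = std::vector<std::string>;

// Hypothetical per-dialect record: what it depends on and what it provides.
struct DialectInfo {
  std::vector<std::string> dependencies;
  PatternList patterns;
};

using Registry = std::map<std::string, DialectInfo>;

// Adds the patterns of `name` after the patterns of its dependencies.
// `visited` keeps each dialect from being added twice and stops cycles.
void collectWithDependencies(const Registry &registry, const std::string &name,
                             std::set<std::string> &visited,
                             PatternList &results) {
  if (!visited.insert(name).second)
    return; // already handled
  auto it = registry.find(name);
  if (it == registry.end())
    return; // unknown dialect: nothing to add
  for (const std::string &dep : it->second.dependencies)
    collectWithDependencies(registry, dep, visited, results);
  results.insert(results.end(), it->second.patterns.begin(),
                 it->second.patterns.end());
}
```

A pipeline stage that only cares about one dialect would call this with that dialect's name and get its dependencies' patterns without naming them.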
Proposal
There are three ways we can do that:
- Change the core Dialect::getCanonicalizationPatterns() function to add a boolean flag (default false) to also cover the operations’ patterns in the retrieval. This is done here.
- Same as (1), but with a ternary flag (get dialect patterns, operation patterns, or both).
- Add a new method, Dialect::getOpCanonicalizationPatterns(), that does the second part. Users can call both or either.
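The three options can be compared side by side in a small sketch. It uses a stand-in class rather than mlir::Dialect: `ToyDialect`, `PatternScope`, and the parameter name `includeOps` are hypothetical, and the real signatures in the prototype may differ. Options (1) and (2) are alternatives, so the two overloads appear together here only for comparison. Operation patterns are added before dialect patterns, matching the ordering described under Considerations.

```cpp
#include <string>
#include <utility>
#include <vector>

// Stand-in: a "pattern" is just a name here, not an mlir::RewritePattern.
using PatternList = std::vector<std::string>;

// Hypothetical ternary selector for option (2).
enum class PatternScope { DialectOnly, OpsOnly, Both };

class ToyDialect {
public:
  ToyDialect(PatternList dialectLevel, PatternList opLevel)
      : dialectPatterns(std::move(dialectLevel)),
        opPatterns(std::move(opLevel)) {}

  // Option (3): a separate method that returns only the operations' patterns.
  void getOpCanonicalizationPatterns(PatternList &results) const {
    results.insert(results.end(), opPatterns.begin(), opPatterns.end());
  }

  // Option (1): boolean flag, off by default so existing callers are unchanged.
  void getCanonicalizationPatterns(PatternList &results,
                                   bool includeOps = false) const {
    if (includeOps)
      getOpCanonicalizationPatterns(results);
    results.insert(results.end(), dialectPatterns.begin(),
                   dialectPatterns.end());
  }

  // Option (2): ternary flag selecting dialect patterns, op patterns, or both.
  void getCanonicalizationPatterns(PatternList &results,
                                   PatternScope scope) const {
    if (scope != PatternScope::DialectOnly)
      getOpCanonicalizationPatterns(results);
    if (scope != PatternScope::OpsOnly)
      results.insert(results.end(), dialectPatterns.begin(),
                     dialectPatterns.end());
  }

private:
  PatternList dialectPatterns;
  PatternList opPatterns;
};
```

With option (3), a caller that wants everything has to call getOpCanonicalizationPatterns() and getCanonicalizationPatterns() in sequence, which in this sketch yields the same list as passing true to the flagged version.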
The main effects of this change are:
- All dialects that have operation canonicalization patterns need to implement new methods, regardless of choice above. We can reduce the scope by making it opt-in and continue using the full scan on canonicalize pass/transform, but then we end up with two ways of doing the same thing.
- TableGen needs to generate the new declaration for all dialects. This will affect downstream dialects if they already have dialect canonicalization patterns. You can see that in my prototype above.
- Downstream dialects with operation patterns will need to adapt. This change will force those dialects to implement the operation pattern retrieval if they use operation canonicalization AND the upstream canonicalization pass.
Change the API (1 and 2)
The main benefit is that this can become the default behaviour. When adding canonicalization patterns for a particular dialect, it’s reasonable to expect that the operations’ patterns will also be registered, especially because only a few dialects (e.g. Linalg and Tensor) have non-operation canonicalization patterns. Flipping the flag’s default value later is much easier than changing the API again.
The main problem is the change in API. Churn and all.
Adding a new function to the API (3)
The main benefit is much smaller churn and being easier on downstreams.
The main problem is that every time we need to take a dialect’s canonicalization patterns we need to remember to call both functions.
Process
I have listed all dialects that have a non-empty GET_OP_LIST in their include files, since each of those needs to add its ops to the dialect canonicalization method. Some dialects (Linalg, ArmSME) have more than one Ops file, so all of them need to be included.
I tried to use Claude to find those for me, but it failed consistently and repeatedly, so I ended up using grep, sort, vim and good old copy & paste.
I also moved Linalg’s CanonicalizationPatternList to a common header, which simplifies the hand-picked operation pattern list. This is currently in Passes.h but there could be a better place to put it.
Considerations
Both the canonicalization pass and the transform can change to use the new flag instead of iterating through all registered operations (I have done that in my prototype). This should have no functional effect, since the same operation and dialect patterns are being registered, only in a different order. Since both are currently greedy, the order should not affect much, if at all.
Some dialects (linalg, tensor) had pre-existing dialect canonicalization patterns, so I’ve added them after the operation ones, so that we know that the operations are in their canonical forms before the dialect ones run.
The builtin dialect has only two operations (module and unrealized_cast), neither of which has canonicalization patterns, so it can be ignored in this process.
Transform ops are declared per dialect, but they also don’t seem to have canonicalization, so I ignored them. Let me know if that wasn’t correct.
There could be an option to auto-generate the canonicalization pattern function from TableGen. That would probably need a new generator, since it needs to query the dialect canonicalization property, every operation canonicalization (across multiple TD files) and the dialect-specific canonicalization patterns (i.e. linalg & tensor). This RFC does not include such a change.
Prototype
References
@jpienaar @matthias-springer @ftynse @kuhar @Groverkss @nicolasvasilache @mehdi_amini @javedabsar @banach-space @asiemien @KFAF @krzysz00 @zero9178