[MLIR][PDL] Extending PDL/PDLInterp/ByteCode to Enable Commutative Matching

We want to enable commutative matching of operands via PDL. Suggestions on the desired functionality, and on how to implement it, are welcome!

I think there are broadly two ways to achieve commutativity in PDL:

  1. PDL pattern marked commutative → a single PDLInterp “matcher” function that represents commutativity using specialized ops → ByteCode executes the matching for multiple operand permutations based on the PDLInterp pattern. This has been partly implemented (the PDL-to-PDLInterp lowering is still missing) in D118683 [MLIR][PDL] Add optional attribute to enable commutativity in PDL, D118684 [MLIR][PDLInterp] Define new ops in PDLInterp to support commutativity, and D118689 [MLIR][PDL] Execute nondeterministic bytecode & lower some PDLInterp ops.
  2. PDL pattern marked commutative → multiple PDL patterns are created within the same module, one per operand permutation, and the original pattern with the commutative marker is removed → no other changes are required in PDLInterp or ByteCode. This has been suggested here.

Note that approach 2 requires fewer changes than approach 1, which is why I lean toward implementing approach 2, even though approach 1 is partly done.

CC: @ChiaHungDuan @Mogball


Thanks for the RFC!

My primary opposition to the bytecode change is that we should try to keep the bytecode’s opset small: anything that can easily be implemented as higher-level rewrites should be.

First, there are no PDL-to-PDL rewrites today. We will have to pick a spot for them, either when the module is loaded or when the patterns are compiled. Second, there is the question of generating the patterns. Generating one pattern for every permutation of operands would work but would be inefficient and quickly explode in size. I wonder if we couldn’t do something like generating “canonicalization” patterns for the commutative op:

foo.op commutative (A, B, C) -> Bar

Rewrites to

foo.op (x, A, x) -> foo.op (A, x, x)
foo.op (x, x, A) -> foo.op (A, x, x)
foo.op (B, x, x) -> foo.op (x, B, x)
foo.op (x, x, B) -> foo.op (x, B, x)
foo.op (C, x, x) -> foo.op (x, x, C)
foo.op (x, C, x) -> foo.op (x, x, C)
foo.op (A, B, C) -> Bar

Evidently, the patterns should be set to be non-recursive. This would limit the number of generated patterns to be quadratic in the number of operands instead of factorial…
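The rule list above can be generated mechanically. Below is a minimal sketch, assuming a toy textual representation rather than real PDL (`SwapRule`, `generateSortingRules` and `render` are hypothetical names). Constraint i's canonical slot is position i, and each rule swaps an operand matching that constraint into its slot. For n operands this emits n(n-1) rules plus the final matching pattern: 21 patterns instead of 5! = 120 for five operands. For three operands it is 7 versus 6, so the savings only start at four operands.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// One generated "sorting" rewrite (hypothetical representation, not PDL):
// an op whose operand at position `from` satisfies `constraint` has that
// operand swapped into position `to`, the constraint's canonical slot.
struct SwapRule {
  std::string constraint;
  std::size_t from;
  std::size_t to;
};

// Generates n * (n - 1) rules for an op with n named operand constraints,
// where constraint i's canonical position is i.
std::vector<SwapRule> generateSortingRules(const std::vector<std::string> &names) {
  std::vector<SwapRule> rules;
  for (std::size_t to = 0; to < names.size(); ++to)
    for (std::size_t from = 0; from < names.size(); ++from)
      if (from != to) rules.push_back({names[to], from, to});
  return rules;
}

// Renders a rule in the notation used above ("x" stands for any operand).
std::string render(const SwapRule &r, std::size_t numOperands) {
  std::vector<std::string> lhs(numOperands, "x"), rhs(numOperands, "x");
  lhs[r.from] = r.constraint;
  rhs[r.to] = r.constraint;
  auto join = [](const std::vector<std::string> &v) {
    std::string s = "(";
    for (std::size_t i = 0; i < v.size(); ++i) s += (i ? ", " : "") + v[i];
    return s + ")";
  };
  return "foo.op " + join(lhs) + " -> foo.op " + join(rhs);
}
```

For `{"A", "B", "C"}` this reproduces the six rules above in the same order. The non-recursive requirement matters: if one operand satisfied both A and B, the A rule and the B rule would keep swapping it back and forth.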

I understand, @Mogball. So, at this point, would you say that we have discarded approach 1, settled on approach 2, and will from here on discuss the specifics of approach 2? Or are we still waiting for others to confirm whether we go with approach 2?

I’m very excited to see this area being pushed! Thanks for bringing it up.

On to my thoughts here…

I’d like to see a bit more discussion around what the exact details of each approach would be. There are pros/cons and different ways to approach both, and it’s a tad too early for me to say which one I would prefer at a glance.

The thing about 1) is that it really depends on the constructs that we would need to add to the interpreter/byte code to effectively support this localized commutative matching. @Mogball summed up the big question here, which is that we need to be extremely careful about how we design the layers of the PDL compilation stack to properly compose (and be mindful about overlapping abstractions).

The thing about 2) that seems to be missing from the discussion above is that it can easily explode the complexity of lowering the PDL. It seems like right now there is an implicit assumption that either there is only one commutative operation, or that the PDL patterns are very small (neither of which we should assume). Without modifying how we specify patterns at the PDL level today, we are going to generate an exponential number of patterns (each commutative operation adds another dimension of duplication). An important thing to keep in mind here is that we compile PDL patterns at runtime (i.e. every time you run the compiler!), so we will end up paying whatever generation cost here consistently.

– River

This is a very good discussion. While using MLIR, I have always felt that a simpler way to describe a PatternRewriter is missing, so in our project we use MLIR itself to dynamically match patterns in the module. We have also run into this problem and have not found a suitable solution for the following situation:

pattern:
commutative_op(commutative_op(A, B), commutative_op(C, D))
target module:
commutative_op(commutative_op(A0, B0), commutative_op(C0, D0))

We can first try to match commutative_op, and then match its inputs A, B, C, D.
At this point there are two situations:

  1. If A, B, C, D can be sorted, we can sort A, B, C, D in the pattern and A0, B0, C0, D0 in the module using the same ordering rules, and then match the sorted arrays.
  2. Unfortunately, there is another situation:
    A and B can be sorted, but C and D cannot, for example because they are produced by ops of the same type, such as Conv or Dot. Then we need N! patterns, or a dynamic backtracking method that tries match(C, D) and then match(D, C).
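The dynamic backtracking option in the second case can be sketched as a toy matcher. This is a hypothetical illustration (`Node`, `makeOp`, `makeVar` and `matchAll` are not PDL or MLIR API): it collects every consistent set of variable bindings and, for a commutative op, tries every operand order. For an op with n operands it may try n! orders, which is exactly the cost that the sorting ideas later in this thread try to avoid.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

// Toy IR/pattern node (hypothetical, not MLIR or PDL API).
struct Node {
  std::string name;          // op name, or variable name for pattern variables
  bool isVar = false;        // pattern variable: matches any IR subtree
  bool commutative = false;  // operands may be matched in any order
  std::vector<std::shared_ptr<Node>> operands;
};
using NodePtr = std::shared_ptr<Node>;
using Bindings = std::map<std::string, const Node *>;

NodePtr makeOp(std::string name, std::vector<NodePtr> operands, bool commutative = false) {
  auto n = std::make_shared<Node>();
  n->name = std::move(name);
  n->commutative = commutative;
  n->operands = std::move(operands);
  return n;
}

NodePtr makeVar(std::string name) {
  auto n = std::make_shared<Node>();
  n->name = std::move(name);
  n->isVar = true;
  return n;
}

// Returns every set of bindings (extending `b`) under which `pat` matches `ir`.
// Because all partial results are kept, a choice made for one operand is
// effectively revisited when a later sibling fails to match.
std::vector<Bindings> matchAll(const NodePtr &pat, const NodePtr &ir, const Bindings &b) {
  if (pat->isVar) {
    auto it = b.find(pat->name);
    if (it != b.end())
      return it->second == ir.get() ? std::vector<Bindings>{b} : std::vector<Bindings>{};
    Bindings extended = b;
    extended[pat->name] = ir.get();
    return {extended};
  }
  if (pat->name != ir->name || pat->operands.size() != ir->operands.size())
    return {};
  std::vector<std::size_t> order(ir->operands.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::vector<Bindings> results;
  do {
    // Match pattern operand i against IR operand order[i].
    std::vector<Bindings> partial{b};
    for (std::size_t i = 0; i < order.size() && !partial.empty(); ++i) {
      std::vector<Bindings> next;
      for (const Bindings &pb : partial)
        for (Bindings &nb : matchAll(pat->operands[i], ir->operands[order[i]], pb))
          next.push_back(std::move(nb));
      partial = std::move(next);
    }
    results.insert(results.end(), partial.begin(), partial.end());
  } while (ir->commutative && std::next_permutation(order.begin(), order.end()));
  return results;
}
```

With unconstrained variables, as in the commutative_op(commutative_op(A, B), commutative_op(C, D)) example, every one of the 2 × 2 × 2 = 8 operand orders succeeds; distinguishing sub-patterns (such as requiring a Conv in one position and a Dot in the other) prune the search to the orders that actually fit.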

Is there anything that prevents some flows from precompiling PDL?
JIT and AOT seem decoupled from the system, don’t they?


Currently, we don’t support AOT compilation. We generate the interpreter code from the PDL level, and that is the level of integration exposed to clients. That isn’t to say it isn’t possible, but it hasn’t been developed. It also raises questions such as whether we want to support intermixed JIT and AOT use cases, how those are exposed to the interpreter, what the performance characteristics are, etc. All of that is mostly orthogonal to the discussion here, though, given that realistically I’d prefer we avoid defining constructs that only work well in certain modes.

AOT is something that I have openly been deferring until the JIT side is in a healthy, production-level state. At that point, adding AOT isn’t a huge problem technically, but there are a lot of logistical questions around how we want to structure it. For example, determining which scenarios we want to support, how that gets exposed to clients, etc.

– River

This discussion starts to feel like we are going deeper into “solve an inverse problem” territory.
I’ve been thinking about this “match equivalent forms of a piece of IR” but from a different angle.

Basically, assume we have “ir_we_want_to_match” and “actual_ir” that are “equivalent”.
The problem is that equivalence is defined in terms of all canonicalization, foldings, commutativity, associativity, CSE and DCE (+ maybe a bunch of other patterns) in the IR.

Loosely speaking, my thinking is to build IR for sub(actual_ir, ir_we_want_to_match) and apply all those rules. If this all canonicalizes/folds away to “zero”, then there is a match.

It would likely involve cloning IR so I don’t imagine it would become commonplace.
Instead of going for combinatorial numbers of matchers, we’d still have a linear number of rules, but we’d need to play them forward and see the result.

Would such a mechanism be feasible?


If we can tackle this problem by canonicalization, that would be great! I am not knowledgeable in this area though, so I will defer to others to see how achievable it might be.

If we choose to formulate this as a combinatorial search, we would want to reduce the number of new operations in PDLInterp. Yes, we can leverage pdl_interp.foreach and pdl_interp.extract (perhaps with some generalizations), and pdl_interp.is_commutative_op could be generalized to other traits. It’s the pdl_interp.get_permutations that bothers me the most – I cannot see it being used for anything else. The whole “get all operands, permute them, iterate over the permutations, and extract operands back one by one” approach is too complex. Maybe loops are not appropriate here. Instead, we could allow ourselves to define subroutines and implement a new op pdl_interp.call. Then we could write:

%0 = pdl_interp.get_operand 0 of %op
%1 = pdl_interp.get_operand 1 of %op
pdl_interp.call match_operands(%0, %1)
pdl_interp.check_is_commutative %op -> ^bb0, ^done
^bb0:
pdl_interp.call match_operands(%1, %0)
^done:
pdl_interp.finalize

+1. With subroutine calls we can also implement a kind of “position-independent” outlining to reduce the size of the bytecode as well. It would also position PDL for better extensibility (new subroutines instead of bytecode ops). But by introducing calls, we would need to rethink the PDL-to-PDLInterp lowering. Would the subroutine participate in the matcher optimization? If so, how?

In approach (2), the number of patterns won’t increase exponentially because the canonicalization/sorting patterns would be generated independently for each commutative op. That is still a quadratic number of patterns per commutative op, but it’s better than factorial…
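To make the scaling argument concrete, here is a small counting sketch (the function names are hypothetical). Full expansion needs the product of n_i! over the commutative ops in a pattern, because the permutations of different ops multiply. Per-op sorting rules need 1 + sum of n_i(n_i - 1), because each op contributes its own swap rules on top of the one original matching pattern. For the commutative_op(commutative_op(A, B), commutative_op(C, D)) example above, the counts are 8 and 7; with ten binary commutative ops they are 1024 and 21.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Number of patterns if every operand permutation of every commutative op is
// expanded into a separate pattern: the product of n_i! over the ops.
std::uint64_t expandedPatternCount(const std::vector<unsigned> &arities) {
  std::uint64_t total = 1;
  for (unsigned n : arities)
    for (unsigned i = 2; i <= n; ++i) total *= i;
  return total;
}

// Number of patterns with per-op sorting rules: n_i * (n_i - 1) swap rules for
// each commutative op, plus the one original matching pattern.
std::uint64_t sortingPatternCount(const std::vector<unsigned> &arities) {
  std::uint64_t total = 1;
  for (unsigned n : arities)
    if (n > 1) total += std::uint64_t{n} * (n - 1);
  return total;
}
```

This only counts patterns; it says nothing about how many rewrites fire at runtime, and the sorting rules still rely on the operands being distinguishable.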

Fundamentally, there must be some difference between the operands, maybe not in the operands themselves but in their parents. Otherwise, they would be identical and the pattern wouldn’t need to be commutative.

Thanks for your comment, @nicolasvasilache. This is a nice abstract solution. Shall we discuss it in more detail so that I can formulate an implementation? First, could you share an example of such a sub IR? I want to understand your vision here.

Based on the various suggestions here about canonicalizing commutative operations (by sorting their operands), I am proposing a solution below that is not PDL-specific. Rather, it is a generic solution that will provide commutativity support to both PDL and C++ patterns. Please share your views.


Solution:
We add a new pass to llvm-project called the canonicalize-commutative-ops pass. This pass walks through each op from top to bottom. When it finds a commutative op, it sorts its operands based on each operand’s origins.

For example, if we have the following pattern and foo is a commutative op:

s = sub a, b
d = div a, b
z0 = arg0
f = foo z0, d, s

Then, the canonicalize-commutative-ops pass will sort the operands of the foo op to "d, s, z0" (div < sub < arg0). Basically, while sorting, ops come before block arguments, ops are ordered alphabetically by name, and block arguments are ordered numerically by index (arg0 < arg3).

And, if two operands come from the same kind of op, we will backtrack and look further up their use-def chains to sort them. This backtracking can be done with a breadth-first traversal.

This algorithm is complete because two different operands cannot have the same breadth-first traversal (this can be enforced by running CSE before this pass). Note that this pass will also solve the recursion problem (which was a rightful concern of @Mogball) because the ops are traversed from top to bottom exactly once.
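The proposed ordering can be sketched on a toy value representation instead of MLIR's Value/Operation classes (`Value`, `makeArg`, `makeOp`, `bfsKey` and `sortCommutativeOperandsSketch` are hypothetical names). The sketch also assumes that a and b in the example above are block arguments 1 and 2. Each operand's key is the sequence of labels met in a breadth-first walk of its backward slice, so two operands defined by the same kind of op are ordered by what feeds them.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <deque>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Toy SSA value (hypothetical, not MLIR API): a block argument or an op result.
struct Value {
  bool isBlockArg = false;
  unsigned argIndex = 0;                         // set for block arguments
  std::string opName;                            // set for op results
  std::vector<std::shared_ptr<Value>> operands;  // operands of the defining op
};
using ValuePtr = std::shared_ptr<Value>;

ValuePtr makeArg(unsigned index) {
  auto v = std::make_shared<Value>();
  v->isBlockArg = true;
  v->argIndex = index;
  return v;
}

ValuePtr makeOp(std::string name, std::vector<ValuePtr> operands) {
  auto v = std::make_shared<Value>();
  v->opName = std::move(name);
  v->operands = std::move(operands);
  return v;
}

// Label of one visited value: ops (kind 0) come before block arguments
// (kind 1); ops compare by name, block arguments by index.
using Label = std::tuple<int, std::string, unsigned>;

Label labelOf(const Value &v) {
  return v.isBlockArg ? Label{1, "", v.argIndex} : Label{0, v.opName, 0u};
}

// Key of an operand: labels of its backward slice in breadth-first order.
// Assumes acyclic use-def chains; shared values are simply visited again.
std::vector<Label> bfsKey(const ValuePtr &root) {
  std::vector<Label> key;
  std::deque<ValuePtr> worklist{root};
  while (!worklist.empty()) {
    ValuePtr v = worklist.front();
    worklist.pop_front();
    key.push_back(labelOf(*v));
    for (const ValuePtr &operand : v->operands) worklist.push_back(operand);
  }
  return key;
}

// Sorts operands by BFS key (lexicographic), so a tie at the first label is
// broken further back in the slice; std::stable_sort keeps equal keys in order.
void sortCommutativeOperandsSketch(std::vector<ValuePtr> &operands) {
  std::vector<std::pair<std::vector<Label>, ValuePtr>> keyed;
  for (const ValuePtr &v : operands) keyed.emplace_back(bfsKey(v), v);
  std::stable_sort(keyed.begin(), keyed.end(),
                   [](const auto &a, const auto &b) { return a.first < b.first; });
  for (std::size_t i = 0; i < operands.size(); ++i) operands[i] = keyed[i].second;
}
```

Materializing full keys is the simplest way to show the order; a real implementation would more likely compare two slices level by level and stop at the first difference.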


We can see that this pass will reduce the number of checks we need to write in a PDL/C++ pattern to match something.
For the above pattern:
1. With canonicalize-commutative-ops: No. of checks = 3.
2. Without canonicalize-commutative-ops: No. of checks = 15.

CSE won’t get rid of side-effecting ops, but if the side-effecting ops look the same, it won’t matter anyways (from a matching perspective).

You will also need to account for ops’ attributes (e.g. cmpi sle vs cmpi sgt).

I’m not 100% behind the idea of a pass. Some commutative ops prefer a particular form. E.g. arith.addi/f will always move constants to the RHS. We wouldn’t want to indiscriminately run this sorting pass. It sounds more useful as a utility: sortCommutativeOperands(Operation *) which can be called by either C++ or PDL patterns.

The tricky part (especially for hand-written C++ patterns) is that the pattern will need to be aware of the sorting order and order its operand constraints accordingly. I would imagine a commutative pattern comprised of two parts:

  1. Sort the commutative operands of an op. If the order has changed, update the op in-place and return success. Otherwise return failure.
  2. Match the commutative op with the operand constraints sorted the same way as the operands would be.

E.g.

%a = a
%b = b
%0 = add %b, %a // add is commutative
%1 = sub %a, %b
%2 = foo %1, %0 // foo is commutative

And the pattern is meant to match foo(sub(a, b), add(b, a)). The pattern will need to be sorted to become foo(add(a, b), sub(a, b)) as well as the IR.
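The two-part scheme can be illustrated on toy expression trees. This is a hypothetical sketch (`Expr`, `print`, `sortCommutativeOperandsRec` and `matchesSorted` are not MLIR API), and it orders operands by their printed form rather than by the op-name/BFS order discussed earlier, purely to keep it short. Part 1 sorts bottom-up and reports whether anything moved (the success/failure signal); part 2 sorts the pattern the same way, so matching reduces to a structural comparison.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy expression tree: an op name plus operand subtrees; leaves have none.
struct Expr {
  std::string name;
  bool commutative = false;
  std::vector<Expr> operands;
};

std::string print(const Expr &e) {
  if (e.operands.empty()) return e.name;
  std::string s = e.name + "(";
  for (std::size_t i = 0; i < e.operands.size(); ++i)
    s += (i ? ", " : "") + print(e.operands[i]);
  return s + ")";
}

// Part 1: sort the operands of commutative ops bottom-up, in place.
// Returns true iff some operand order changed.
bool sortCommutativeOperandsRec(Expr &e) {
  bool changed = false;
  for (Expr &operand : e.operands) changed |= sortCommutativeOperandsRec(operand);
  if (e.commutative) {
    auto byPrint = [](const Expr &a, const Expr &b) { return print(a) < print(b); };
    if (!std::is_sorted(e.operands.begin(), e.operands.end(), byPrint)) {
      std::stable_sort(e.operands.begin(), e.operands.end(), byPrint);
      changed = true;
    }
  }
  return changed;
}

// Part 2: with pattern and IR sorted the same way, compare them structurally.
bool matchesSorted(Expr pattern, Expr ir) {
  sortCommutativeOperandsRec(pattern);
  sortCommutativeOperandsRec(ir);
  return print(pattern) == print(ir);
}
```

On the example above, foo(sub(a, b), add(b, a)) becomes foo(add(a, b), sub(a, b)); sorting it a second time reports no change, and sub(b, a) still does not match sub(a, b) because sub is not commutative.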

Thanks for your reply, @Mogball !

Yes!

Oh yes. In the worst case we will have to consider all attributes of an op (though that case should be rare).

This also sounds okay, but calling this utility in every pattern may add significant compile-time overhead. Instead, could we special-case “constants on the RHS” and call sortCommutativeOperands() or the canonicalize-commutative-ops pass only once, before all the patterns of a given pass? Under our sorting rule, we could simply say that constant ops come even after block arguments. Also, I think “constants on the RHS” is the only special case we have to handle. Is that true?

Also, here:

What would the Operation * argument be?

The commutative operation.

The folder will move constant operands (stable order) to the RHS of operations. So as long as the sort has a special check to move operands produced by ConstantLike operations to the RHS, it will not conflict with that logic. That being said, some operations may have canonicalizers that move operands around. Having a separate pass could also create phase ordering problems where the subsequent pass contains patterns that depend on the sorting but which may produce operations that need to have their operands sorted.

It would be the operation that is being matched. E.g.

LogicalResult matchCommutativeOp(MyOp op, PatternRewriter &r) {
  sortCommutativeOperands(op);
  // Do the matching logic here. Maybe restore the operand order if the match failed?
  return failure(); // Placeholder result until the matching logic is written.
}

Operations whose operands are already sorted will be re-sorted. But if that runtime cost is really a concern, a utility also lets you call it once in a separate pass when you know you don’t need to re-sort any operands.

Yes, exactly. That is why we would run the canonicalize-commutative-ops pass before every pass whose patterns rely on the sorted order, just as the canonicalize pass is run multiple times in a pipeline. That said, like the canonicalize pass, canonicalize-commutative-ops has an inherent limitation: inside a given pass, if pattern 1 creates an op whose operands are not sorted and that op is an input to pattern 2, pattern 2 may see unsorted operands and need some extra handling to match. But I think this issue can be discussed later.

I did not understand this. Could you please elaborate?

There is an issue with this. Sorting the operands of a MyOp op requires sorting all the ops that produce its operands. So, roughly, we will be sorting the MyOp op plus every op that comes before it (which is basically the worst-case scenario). Do you agree?

By the way, I agree with the approach in general. Please feel free to move forward!

Yes. But that cost will have to be paid anyways, either by a pattern or by a pass.

If a pass is pattern-based, the canonicalization patterns can always be added to the pattern set. That being said, I still think the functionality should be exposed through a utility function so that it can be used anywhere: inside patterns, by a pass, or added as part of an op’s canonicalizer.

Sure, @Mogball. I understand your point and I will add a utility function to do this. Thanks for all your inputs.

Concluding comment: it has been decided that the following will be implemented:

A utility function sortCommutativeOperands(Operation * commOp) will be added. It will take a commutative op commOp as input and then sort its operands based on each operand’s origins.

For example, if we have the following pattern and foo is the commOp:

s = sub a, b
d = div a, b
z0 = arg0
f = foo z0, d, s

Then, the function sortCommutativeOperands(Operation * commOp) will sort the operands of the foo op to "d, s, z0" (div < sub < arg0). Basically, while sorting, ops (except ConstantLike ops) come before block arguments, block arguments come before ConstantLike ops, ops are ordered alphabetically, and block arguments are ordered numerically (arg0 < arg3).

And, if two operands come from the same kind of op, we will backtrack and look further up their use-def chains to sort them. This backtracking can be done with a breadth-first traversal.

We may also need to account for an op’s attributes (e.g. cmpi sle vs cmpi sgt) when deciding the sorting order. But this scenario is rare and can be handled in a follow-up revision, if not in the initial one.

This algorithm is complete because two different operands cannot have the same breadth-first traversal. Note that this approach also solves the recursion problem (which was a rightful concern of @Mogball) because the sorting is handled by a utility function instead of by patterns.


We can see that this function will reduce the number of checks we need to write in a PDL/C++ pattern to match something.
For the above pattern:
1. With sortCommutativeOperands: No. of checks = 3.
2. Without sortCommutativeOperands: No. of checks = 15.


This is how the function can be used in pattern matching:

LogicalResult matchCommutativeOp(MyOp op, PatternRewriter &r) {
  sortCommutativeOperands(op.getOperation());
  // Do the matching logic here. Maybe restore the operand order if the match failed?
  return failure(); // Placeholder result until the matching logic is written.
}

Revision implementing the solution: D124750 [MLIR] Add a utility to sort the operands of commutative ops.

Note that the previous comment states that each block argument is arranged numerically (arg0 < arg3).

This is incorrect: the revision I created does not order block arguments among themselves, and all block arguments are considered equal. There is no logical reason to order them, and such an ordering would not simplify pattern matching.
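The first-level ordering from the concluding comment, amended so that block arguments are not ordered among themselves, can be sketched as a comparator plus a stable sort. `Origin`, `Operand`, `operandLess` and `sortOperandsFinalOrder` are hypothetical names, not the code in D124750. Ties between ops with the same name would still need the breadth-first comparison described earlier.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy description of where an operand comes from (hypothetical).
enum class Origin { NonConstantOp, BlockArgument, ConstantLikeOp };

struct Operand {
  Origin origin;
  std::string opName;  // defining op name; empty for block arguments
  std::string label;   // only used to identify operands in the example
};

// Non-constant ops first, ordered by name; then block arguments, all equal;
// then ConstantLike ops, all equal, so constants stay on the RHS.
bool operandLess(const Operand &a, const Operand &b) {
  if (a.origin != b.origin) return a.origin < b.origin;
  if (a.origin == Origin::NonConstantOp) return a.opName < b.opName;
  return false;
}

// std::stable_sort keeps the original relative order of operands that compare
// equal, e.g. block arguments, or constants already moved right by the folder.
void sortOperandsFinalOrder(std::vector<Operand> &operands) {
  std::stable_sort(operands.begin(), operands.end(), operandLess);
}
```

Because the sort is stable, arg3 stays ahead of arg0 if it was already ahead, which is what "all block arguments are considered equal" means in practice.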