[MLIR][PDL] Extending PDL/PDLInterp/ByteCode to Enable Commutative Matching

Concluding comment: it has been decided that the following will be implemented.

A utility function sortCommutativeOperands(Operation *commOp) will be added. It will take a commutative op, commOp, as input and sort its operands based on where each operand comes from (its defining op or block argument).

For example, if we have the following pattern and foo is the commOp:

```
s = sub a, b
d = div a, b
z0 = arg0
f = foo z0, d, s
```

Then sortCommutativeOperands(Operation *commOp) will sort the operands of foo to "d, s, z0" (div < sub < arg0). The sorting rules are: operands defined by ops (other than ConstantLike ops) come before block arguments, block arguments come before operands defined by ConstantLike ops, ops are ordered alphabetically by name, and block arguments are ordered numerically by argument number (arg0 < arg3).
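A minimal sketch of the single-level ordering rule, written against a toy operand description rather than MLIR's Value API. OriginKind, ToyOperand, operandLess and sortOperandsSingleLevel are hypothetical names used only for illustration. Operands that still tie (e.g. two results of sub ops) are left in their original order; breaking those ties is the job of the backtracking step.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical classification of where an operand comes from. The enumerator
// order encodes the rule: non-constant ops < block arguments < ConstantLike ops.
enum class OriginKind { NonConstantOp, BlockArgument, ConstantLikeOp };

// Hypothetical toy description of one operand; not MLIR's Value API.
struct ToyOperand {
  std::string valueName;  // SSA name, e.g. "d"
  OriginKind kind;        // where the operand comes from
  std::string opName;     // defining op name (ignored for block arguments)
  unsigned argNumber;     // block argument number (ignored for ops)
};

// Strict weak ordering implementing the single-level rule.
bool operandLess(const ToyOperand &a, const ToyOperand &b) {
  if (a.kind != b.kind)
    return a.kind < b.kind;            // ops, then block args, then constants
  if (a.kind == OriginKind::BlockArgument)
    return a.argNumber < b.argNumber;  // numeric: arg0 < arg3
  return a.opName < b.opName;          // alphabetical: div < sub
}

// Stable sort, so operands that still tie keep their original relative order.
void sortOperandsSingleLevel(std::vector<ToyOperand> &operands) {
  std::stable_sort(operands.begin(), operands.end(), operandLess);
}
```

For the foo example above, the operand list z0, d, s comes out as d, s, z0.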

If two operands come from the same kind of op (e.g. both are results of sub ops), we backtrack and look further up their use-def chains to sort them. This backtracking can be done with a breadth-first traversal.
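The breadth-first tie-break can be sketched on a toy, acyclic IR: each operand gets a key listing its producer, then the producer's operands level by level, and keys are compared lexicographically. ToyNode, Token, bfsKey and sortOperandsBfs are hypothetical names, not MLIR API, and the rank/name/argNumber token is an assumed encoding of the ordering rules above.

```cpp
#include <algorithm>
#include <cassert>
#include <deque>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical toy IR node; not MLIR's API. The graph is assumed acyclic.
struct ToyNode {
  int rank;            // 0 = non-constant op, 1 = block argument, 2 = ConstantLike op
  std::string opName;  // op name; empty for block arguments
  unsigned argNumber;  // block argument number; 0 for ops
  std::vector<const ToyNode *> operands;  // producers of this node's operands
};

// One visited node: rank first, then name (alphabetical), then arg number.
using Token = std::tuple<int, std::string, unsigned>;

// Breadth-first key of an operand: its producer's token, then the tokens of
// the producer's operands, level by level.
std::vector<Token> bfsKey(const ToyNode *root) {
  std::vector<Token> key;
  std::deque<const ToyNode *> worklist{root};
  while (!worklist.empty()) {
    const ToyNode *node = worklist.front();
    worklist.pop_front();
    key.emplace_back(node->rank, node->opName, node->argNumber);
    for (const ToyNode *operand : node->operands)
      worklist.push_back(operand);
  }
  return key;
}

// Compares breadth-first keys lexicographically, so a tie at one level is
// broken by the first level where the two keys differ.
void sortOperandsBfs(std::vector<const ToyNode *> &operands) {
  std::stable_sort(operands.begin(), operands.end(),
                   [](const ToyNode *a, const ToyNode *b) {
                     return bfsKey(a) < bfsKey(b);
                   });
}
```

For example, with a = arg0 and b = arg1, "sub a, b" sorts before "sub b, a" because the keys first differ at the second level. This sketch recomputes full keys on every comparison for simplicity; a real implementation would likely cache keys and stop traversing as soon as two keys diverge.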

We may also need to account for an op's attributes (e.g. cmpi sle vs. cmpi sgt) while deciding the sorting order. This scenario is rare, though, and can be handled in a follow-up revision if it is not done in the initial one.
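One possible way to account for attributes is to compare op names first and fall back to an attribute summary only on a tie. In this sketch, ToyOpInfo and lessWithAttributes are hypothetical, and attrSummary is an assumed canonical string for the op's attributes (e.g. the cmpi predicate); how a real implementation would order MLIR attributes is left open.

```cpp
#include <cassert>
#include <string>
#include <tuple>

// Hypothetical toy description of an operand's defining op.
struct ToyOpInfo {
  std::string opName;       // e.g. "cmpi"
  std::string attrSummary;  // hypothetical canonical attribute string, e.g. "sle"
};

// Orders by op name first; attributes only break ties between same-named ops.
bool lessWithAttributes(const ToyOpInfo &a, const ToyOpInfo &b) {
  return std::tie(a.opName, a.attrSummary) < std::tie(b.opName, b.attrSummary);
}
```

With names alone, cmpi sle and cmpi sgt would tie; including the predicate gives them a fixed order.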

This algorithm is complete because two structurally different operands cannot have the same breadth-first traversal. Note that this approach also solves the recursion problem (a rightful concern raised by @Mogball), because the sorting is handled by a utility function instead of by patterns.


This function reduces the number of checks we need to write in a PDL/C++ pattern to match a commutative op.
For the above pattern:
1. With sortCommutativeOperands: 3 checks.
2. Without sortCommutativeOperands: 15 checks.


This is how the function can be used in pattern matching:

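```cpp
LogicalResult matchCommutativeOp(MyOp op, PatternRewriter &rewriter) {
  // Put the operands of the commutative op into a canonical order first.
  sortCommutativeOperands(op.getOperation());
  // Do the matching logic on the sorted operands here. Maybe restore the
  // original operand order if the match fails?
  return failure(); // placeholder until the matching logic is written
}
```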
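On the open question of restoring the order: MLIR's pattern rewriting documentation generally expects a pattern that returns failure() to leave the IR unchanged, which argues for undoing the sort on a failed match. A minimal, MLIR-independent sketch of that save/sort/restore flow follows; matchAfterSorting is a hypothetical helper, and the operand list is modelled as a plain std::vector with a caller-supplied ordering and matcher.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical helper: sort the operands into canonical order, run the
// matcher, and put the original order back if the match fails.
template <typename T, typename Less, typename Matcher>
bool matchAfterSorting(std::vector<T> &operands, Less less, Matcher matcher) {
  std::vector<T> saved = operands;  // remember the original order
  std::stable_sort(operands.begin(), operands.end(), less);
  if (matcher(operands))
    return true;                    // keep the canonical order on success
  operands = std::move(saved);      // undo the reordering on failure
  return false;
}
```

Keeping the canonical order on success is a design choice: it leaves the op in sorted form for later patterns, while a failed match leaves no trace.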