[RFC] Canonicalize batched linalg operations with singleton batch

I propose adding canonicalization patterns to `linalg.batch_matmul`, and all of its variants, that convert them to their non-batched counterparts when the batch size is statically known to be 1.

For example:

```mlir
func.func @batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
      outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
  return %1 : tensor<1x?x?xf32>
}
```

would canonicalize to:

```mlir
func.func @batch_matmul_tensor(%arg0: tensor<1x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
  %c2 = arith.constant 2 : index
  %c1 = arith.constant 1 : index
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x?x?xf32> into tensor<?x?xf32>
  %collapsed_0 = tensor.collapse_shape %arg1 [[0, 1], [2]] : tensor<1x?x?xf32> into tensor<?x?xf32>
  %collapsed_1 = tensor.collapse_shape %arg2 [[0, 1], [2]] : tensor<1x?x?xf32> into tensor<?x?xf32>
  %0 = linalg.matmul ins(%collapsed, %collapsed_0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%collapsed_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %dim = tensor.dim %arg2, %c1 : tensor<1x?x?xf32>
  %dim_2 = tensor.dim %arg2, %c2 : tensor<1x?x?xf32>
  %expanded = tensor.expand_shape %0 [[0, 1], [2]] output_shape [1, %dim, %dim_2] : tensor<?x?xf32> into tensor<1x?x?xf32>
  return %expanded : tensor<1x?x?xf32>
}
```

In my mind, batched operations with a batch size of 1 are not canonical. Can anyone think of a reason I haven’t thought of not to do this?

I already have changes for this. It’s currently a draft PR, but if there are no objections I’ll make it ready for review: [mlir][linalg] Implement canonicalizer for batched linalg operations with batch size 1 by srcarroll · Pull Request #95710 · llvm/llvm-project · GitHub.

@matthias-springer @nicolasvasilache @ftynse @makslevental

@qed

@chelini @asiemien

This could just be a pattern in the DropUnitDims pass: llvm-project/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp at main · llvm/llvm-project · GitHub, which would also allow controlling how the unit dim is dropped (reshape/slicing). I will defer the question of whether this should be a canonicalization to others.
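To illustrate the reshape-vs-slicing choice, here is a hand-written sketch (static shapes and function names invented for brevity, not taken from the PR) of how the unit batch could instead be dropped with rank-reducing slices rather than collapse/expand reshapes:

```mlir
// Slicing-based variant: remove the unit batch with rank-reducing
// extract_slice / insert_slice instead of collapse_shape / expand_shape.
func.func @batch_matmul_slice(%a: tensor<1x4x8xf32>, %b: tensor<1x8x16xf32>,
                              %c: tensor<1x4x16xf32>) -> tensor<1x4x16xf32> {
  %sa = tensor.extract_slice %a[0, 0, 0] [1, 4, 8] [1, 1, 1]
      : tensor<1x4x8xf32> to tensor<4x8xf32>
  %sb = tensor.extract_slice %b[0, 0, 0] [1, 8, 16] [1, 1, 1]
      : tensor<1x8x16xf32> to tensor<8x16xf32>
  %sc = tensor.extract_slice %c[0, 0, 0] [1, 4, 16] [1, 1, 1]
      : tensor<1x4x16xf32> to tensor<4x16xf32>
  %0 = linalg.matmul ins(%sa, %sb : tensor<4x8xf32>, tensor<8x16xf32>)
      outs(%sc : tensor<4x16xf32>) -> tensor<4x16xf32>
  %r = tensor.insert_slice %0 into %c[0, 0, 0] [1, 4, 16] [1, 1, 1]
      : tensor<4x16xf32> into tensor<1x4x16xf32>
  return %r : tensor<1x4x16xf32>
}
```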

Ah, yes, I think that would be a valid approach too.

I am not even sure you need a new pattern in DropUnitDims. Can DropUnitDims already handle this case?

The thing proposed here would change the name from batch_matmul to just matmul (or eventually one of the transposed variants), as far as I understand. Drop unit dims works on generic, and should keep doing that IMO.

I would be more in favor of making this rewrite optional.
It can be easier to handle a single operation in terms of matching or fusion rather than having to filter out all the shape changes.

In itself, a batch dimension of size one is a valid representation. In this case, what is canonical comes down to the chosen convention, IMO.

Yup. See: RFC : Update to "General Design" section of Operation Canonicalizations in MLIR

Batched operations are semantically different from their non-batched versions, even if the batch size is 1. True, they have “the same lowering”, but only if done straight away. If there’s a pass in the middle that looks for batched ops (even with batch = 1), those passes will fail to match and you may lower to less optimal code. The same is true for N-D shapes whose “extra dimensions” are all ones, etc.

Other things look easier, like constant folding, but even those can be tricky. So I’d treat all of those patterns as optional, chosen by the compiler (whoever builds the pipeline) and not enforced by some generic catch-all canonicalization pass.

Perhaps we should be discussing how we’re going to parametrize the canonicalizer pass instead?

First of all, thanks for taking the time to ask whether this should be a canonicalization. This is a great precedent to set!

I am -1 on adding this as a canonicalization. batch_matmul -> matmul by itself is a canonicalization, but the introduction of reshapes into the program can be hard to account for. That, to me, makes it not a canonicalization.

Adding it to FoldUnitDims makes sense to me, but start with it not kicking in by default, and then we can make it the default.

So the reason this came to my mind in the first place is that I see patterns from downstream projects that expand tensors to use in a singleton-batch operation, which IMO is not canonical. So in that context the proposed pattern would be a canonicalization, since it would fold with the downstream patterns I mentioned.

However, I am convinced by the comments that this should not be a canonicalization pattern. So now I guess the question is whether it belongs in DropUnitDims or not. I think this would be OK, but the issue is that, generically, it would introduce unit dims in other ops, like collapse_shape and expand_shape. So I think I should extend my implementation to look for expansions to singleton dims on operands and fold to the non-batched form in that case. Does that make sense?
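For concreteness, here is a hand-written sketch (static shapes invented, not from the PR) of the downstream shape I had in mind: operands get expanded just to feed a singleton-batch op, so folding the expand_shapes into the batch_matmul would leave a plain matmul with no reshapes behind:

```mlir
// Downstream pattern expands 2-D operands just to call the batched op.
func.func @expanded_batch(%a: tensor<4x8xf32>, %b: tensor<8x16xf32>,
                          %c: tensor<4x16xf32>) -> tensor<4x16xf32> {
  %ae = tensor.expand_shape %a [[0, 1], [2]] output_shape [1, 4, 8]
      : tensor<4x8xf32> into tensor<1x4x8xf32>
  %be = tensor.expand_shape %b [[0, 1], [2]] output_shape [1, 8, 16]
      : tensor<8x16xf32> into tensor<1x8x16xf32>
  %ce = tensor.expand_shape %c [[0, 1], [2]] output_shape [1, 4, 16]
      : tensor<4x16xf32> into tensor<1x4x16xf32>
  %0 = linalg.batch_matmul ins(%ae, %be : tensor<1x4x8xf32>, tensor<1x8x16xf32>)
      outs(%ce : tensor<1x4x16xf32>) -> tensor<1x4x16xf32>
  %r = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<1x4x16xf32> into tensor<4x16xf32>
  return %r : tensor<4x16xf32>
}
// After folding the unit batch, this would reduce to a single
//   linalg.matmul ins(%a, %b ...) outs(%c ...) -> tensor<4x16xf32>
// with no reshapes left in the program.
```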

Drop unit dims works on generic, and should keep doing that IMO.

@ftynse, could you briefly explain why it should only be on generics?

Another option would be to just make it a pass, like convert-batched-ops-to-unbatched or something. I will defer to others on this choice. It’s easy enough to put the pattern rewrite I implemented anywhere.

I have repeatedly claimed that there is no single canonical form. We can have a bunch of “enabling” passes for specific transforms instead. I have so far ignored that thread, but it’s something we should indeed discuss.

Similar reasoning to the folks above: a generic with a dimension dropped is still a generic. A batched matmul without the batch dimension is just a matmul, but a batched matmul without one of the other dimensions is not. I suppose we should also have similar transforms to go from matmul to matvec to dot.

I also see why one would want these to be applied in some fix-point manner, so maybe define the patterns and the populate function and let folks include them in passes as they see fit?
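To illustrate the matmul → matvec step mentioned above, here is a hand-written sketch (shapes and function names are mine, not from the PR): a matmul whose RHS and output carry a static unit trailing dimension is equivalent to a matvec on the collapsed operands:

```mlir
// Before: a matmul whose RHS/output have a static unit trailing dim.
func.func @matmul_unit_n(%a: tensor<?x?xf32>, %b: tensor<?x1xf32>,
                         %c: tensor<?x1xf32>) -> tensor<?x1xf32> {
  %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x1xf32>)
      outs(%c : tensor<?x1xf32>) -> tensor<?x1xf32>
  return %0 : tensor<?x1xf32>
}

// After: roughly what the rank-reduced rewrite would produce.
func.func @as_matvec(%a: tensor<?x?xf32>, %b: tensor<?x1xf32>,
                     %c: tensor<?x1xf32>) -> tensor<?x1xf32> {
  %c0 = arith.constant 0 : index
  %bv = tensor.collapse_shape %b [[0, 1]] : tensor<?x1xf32> into tensor<?xf32>
  %cv = tensor.collapse_shape %c [[0, 1]] : tensor<?x1xf32> into tensor<?xf32>
  %0 = linalg.matvec ins(%a, %bv : tensor<?x?xf32>, tensor<?xf32>)
      outs(%cv : tensor<?xf32>) -> tensor<?xf32>
  %dim = tensor.dim %c, %c0 : tensor<?x1xf32>
  %expanded = tensor.expand_shape %0 [[0, 1]] output_shape [%dim, 1]
      : tensor<?xf32> into tensor<?x1xf32>
  return %expanded : tensor<?x1xf32>
}
```

The same reasoning applied once more (unit M dimension on a matvec) would take it down to a dot.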


Strong +1


That was my point. There is no “true” canonical form, therefore we need to parametrize it, so that it’s easy for any compiler to create their own canonical form without competing with other compilers for space.

If we stop discussing what a canonical form is and start discussing how to create multiple canonical forms, we might make more productive progress.


but a batched matmul without one of the other dimensions is not

Ah right, DropUnitDims drops any unit dim, not just the leading one. OK, then I would agree it doesn’t necessarily belong there.

I suppose we should also have similar transforms to go from matmul to matvec to dot.

Yeah, I was considering doing that too.

maybe define the patterns and the populate function and let folks include them in passes as they see fit?

I can live with that.

therefore we need to parametrize it, so that it’s easy for any compiler to create their own canonical form

Could this be something similar to the transform dialect extension?

DropUnitDims is configured with a callback to control the application of its patterns. I would be happy having the patterns available for named ops as well (named op → generic), so users can limit application to generics if that fits their use case. W.r.t. batch_matmul → matmul, specialization patterns could be used to recover the matmul. If written as a single pattern, though, I agree it might be better suited for a different pattern set.

There is a lot of design space here. The simplest option is folks constructing their own forms via populate methods; then there is folks forking passes and running them (which is similar to populate but doesn’t assume it’s all pattern based; maybe only helpers are shared); hooks, as Quinn mentioned; or some agreed-upon phases with invariants. Then there is something like specifying preconditions and automatically “grouping” everything that fits in this declarative-esque specification. Lots of potential area for exploration :slight_smile:

Going from named to generic is a semantics-destroying transformation. Dropping unit dims is not, and is often called as a clean-up. So it may not be desirable to mix the two. One can already convert named to generic and drop unit dims on generics without the need for additional transformations.

This is a bit inverse to the core design principle of structured ops: preserving semantic information instead of destroying it only to recover it at some later point. I’m not saying I’m opposed to having such “raising” functionality, but we shouldn’t rely on it excessively in regular flows.
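For reference, here is a rough hand-written sketch of the unit-batch matmul as a `linalg.generic` (approximately what `linalg.batch_matmul` generalizes to); dropping the unit dimension from this form still yields a generic with the same body, not a named matmul, which is the asymmetry being discussed:

```mlir
// Hand-written generic equivalent of linalg.batch_matmul with batch = 1.
#mapA = affine_map<(b, m, n, k) -> (b, m, k)>
#mapB = affine_map<(b, m, n, k) -> (b, k, n)>
#mapC = affine_map<(b, m, n, k) -> (b, m, n)>
func.func @generic_batch_matmul(%a: tensor<1x?x?xf32>, %b: tensor<1x?x?xf32>,
                                %c: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
  %0 = linalg.generic
      {indexing_maps = [#mapA, #mapB, #mapC],
       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%a, %b : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
      outs(%c : tensor<1x?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %out, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<1x?x?xf32>
  return %0 : tensor<1x?x?xf32>
}
```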

We can probably close this, but if there are any further discussions on canonicalization in general, see the RFC that @rengolin linked: [RFC] Canonicalize batched linalg operations with singleton batch - #8 by asiemien.

I pivoted my original draft PR to implement pattern rewrites and a populate function as @ftynse suggested, so people can use them as needed. It’s basically done, but still in draft atm. If anyone is interested in chiming in there, please feel free: [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops by srcarroll · Pull Request #95710 · llvm/llvm-project · GitHub