How to Convert linalg.batch_matmul to linalg.matmul When the Batch Dimension Is 1?

Hello,

I am working on optimizing an MLIR model that involves the linalg.batch_matmul operation. In my case, the batch dimension is always 1, so the batch_matmul could theoretically be replaced with a simpler linalg.matmul.

Here is a simplified example of the IR:

%expanded = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<2x3xf32> into tensor<1x2x3xf32>
%expanded_0 = tensor.expand_shape %arg1 [[0, 1], [2]] : tensor<3x2xf32> into tensor<1x3x2xf32>
%result = linalg.batch_matmul ins(%expanded, %expanded_0 : tensor<1x2x3xf32>, tensor<1x3x2xf32>) -> tensor<1x2x2xf32>
%collapsed = tensor.collapse_shape %result [[0, 1], [2]] : tensor<1x2x2xf32> into tensor<2x2xf32>

In this case, the expand_shape and collapse_shape operations seem unnecessary, as the batch dimension is 1. Ideally, I would like to convert the linalg.batch_matmul operation into linalg.matmul to simplify the computation.
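For concreteness, here is a rough sketch of the IR I would hope to end up with after such a rewrite. The function name and the zero-initialized `%init` output tensor are my own assumptions for illustration (the simplified snippet above omits the `outs` operand), so please treat this as the intended shape of the result rather than the output of any existing pass:

```mlir
// Hypothetical target IR: the unit batch dimension is gone, so no
// expand_shape / collapse_shape is needed around the contraction.
func.func @matmul_unit_batch(%arg0: tensor<2x3xf32>,
                             %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> {
  // Assumed accumulator: a 2x2 tensor filled with zeros.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<2x2xf32>
  %init = linalg.fill ins(%cst : f32)
                      outs(%empty : tensor<2x2xf32>) -> tensor<2x2xf32>
  // Plain 2-D matmul operating directly on the original operands.
  %result = linalg.matmul ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
                          outs(%init : tensor<2x2xf32>) -> tensor<2x2xf32>
  return %result : tensor<2x2xf32>
}
```

In other words, the rewrite would drop the leading unit dimension from both operands and the result, and use `%arg0` and `%arg1` directly.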

Is there any existing MLIR pass that can automatically perform this kind of transformation? If not, what would be the best approach to implement a custom pass or pattern to achieve this?

Thank you in advance for any insights or suggestions!
