Hello,
I am working on optimizing an MLIR model that involves the `linalg.batch_matmul` operation. In my case, the batch dimension is always 1, so the `batch_matmul` could theoretically be replaced with a simpler `linalg.matmul`.
Here is a simplified example of the IR:
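```mlir
%expanded = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<2x3xf32> into tensor<1x2x3xf32>
%expanded_0 = tensor.expand_shape %arg1 [[0, 1], [2]] : tensor<3x2xf32> into tensor<1x3x2xf32>
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<1x2x2xf32>
%init = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
%result = linalg.batch_matmul ins(%expanded, %expanded_0 : tensor<1x2x3xf32>, tensor<1x3x2xf32>) outs(%init : tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
%collapsed = tensor.collapse_shape %result [[0, 1], [2]] : tensor<1x2x2xf32> into tensor<2x2xf32>
```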
In this case, the `expand_shape` and `collapse_shape` operations seem unnecessary: because the batch dimension is 1, they only add and then remove a unit dimension. Ideally, I would like to convert the `linalg.batch_matmul` operation into `linalg.matmul` to simplify the computation.
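As a sanity check on why such a rewrite is legal, here is a plain-Python model of the IR above. Nested lists stand in for tensors, and the helper names (`expand_leading_unit`, `collapse_leading_unit`, etc.) are made up for illustration; they are not MLIR APIs. The sketch assumes the accumulator is zero-initialized, as the `linalg.fill` with 0.0 produces. It shows that collapsing the result of a unit-batch `batch_matmul` over expanded operands gives exactly the same values as a plain `matmul` on the original 2-D operands.

```python
# Plain-Python model of the example IR (no MLIR involved).
# Assumes a zero-initialized accumulator, matching the linalg.fill with 0.0.

def matmul(a, b):
    """Reference 2-D matmul: (M x K) @ (K x N) -> (M x N)."""
    k = len(b)
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(len(b[0]))]
            for i in range(len(a))]

def batch_matmul(a, b):
    """Batched matmul: (B x M x K) @ (B x K x N) -> (B x M x N)."""
    return [matmul(a_i, b_i) for a_i, b_i in zip(a, b)]

def expand_leading_unit(x):
    """Mimics tensor.expand_shape [[0, 1], [2]]: (M x K) -> (1 x M x K)."""
    return [x]

def collapse_leading_unit(x):
    """Mimics tensor.collapse_shape [[0, 1], [2]]: (1 x M x N) -> (M x N)."""
    assert len(x) == 1, "only valid when the batch dimension is 1"
    return x[0]

a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]       # stands in for tensor<2x3xf32>
b = [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]  # stands in for tensor<3x2xf32>

via_batch = collapse_leading_unit(
    batch_matmul(expand_leading_unit(a), expand_leading_unit(b)))
direct = matmul(a, b)

print(via_batch == direct)  # True
print(direct)               # [[58.0, 64.0], [139.0, 154.0]]
```

The equivalence holds for any shapes as long as the batch dimension is exactly 1; with a larger batch, `collapse_leading_unit` rejects the input, mirroring why the rewrite would have to check that dimension first.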
Is there any existing MLIR pass that can automatically perform this kind of transformation? If not, what would be the best approach to implement a custom pass or pattern to achieve this?
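If a custom pattern does turn out to be necessary, the matching logic it needs can be sketched independently of MLIR's C++ API. The following is a hypothetical toy in plain Python: `Op`, `value`, `is_unit_expand`, and `fold_unit_batch_matmul` are invented for illustration, and shapes are plain tuples. It matches `collapse_shape(batch_matmul(expand_shape(a), expand_shape(b)))` where the batch dimension is 1 and each reshape only adds or removes that leading unit dimension, and replaces the whole chain with `matmul(a, b)`.

```python
from dataclasses import dataclass, field

@dataclass
class Op:
    """Hypothetical toy op: an op name, its operand ops, and its result shape."""
    name: str
    operands: list = field(default_factory=list)
    shape: tuple = ()

def value(shape, label):
    """Hypothetical stand-in for a block argument such as %arg0."""
    return Op(f"arg:{label}", [], shape)

def is_unit_expand(op):
    """True if op is an expand_shape that only prepends a unit dimension."""
    return (op.name == "tensor.expand_shape"
            and op.shape[0] == 1
            and op.operands[0].shape == op.shape[1:])

def fold_unit_batch_matmul(root):
    """Rewrite collapse(batch_matmul(expand(a), expand(b))) into matmul(a, b)
    when the batch dimension is 1; otherwise return root unchanged."""
    if root.name != "tensor.collapse_shape":
        return root
    bmm = root.operands[0]
    if bmm.name != "linalg.batch_matmul" or bmm.shape[0] != 1:
        return root
    # The collapse must drop exactly the leading unit batch dimension.
    if root.shape != bmm.shape[1:]:
        return root
    lhs, rhs = bmm.operands
    if not (is_unit_expand(lhs) and is_unit_expand(rhs)):
        return root
    a, b = lhs.operands[0], rhs.operands[0]
    return Op("linalg.matmul", [a, b], bmm.shape[1:])

# Build the chain from the example IR.
a = value((2, 3), "arg0")
b = value((3, 2), "arg1")
e0 = Op("tensor.expand_shape", [a], (1, 2, 3))
e1 = Op("tensor.expand_shape", [b], (1, 3, 2))
bmm = Op("linalg.batch_matmul", [e0, e1], (1, 2, 2))
collapsed = Op("tensor.collapse_shape", [bmm], (2, 2))

rewritten = fold_unit_batch_matmul(collapsed)
print(rewritten.name, rewritten.shape)  # linalg.matmul (2, 2)

# A batch of 4 does not match, so the original op is returned unchanged.
big = Op("linalg.batch_matmul", [value((4, 2, 3), "x"), value((4, 3, 2), "y")], (4, 2, 2))
untouched = Op("tensor.collapse_shape", [big], (8, 2))
print(fold_unit_batch_matmul(untouched) is untouched)  # True
```

In MLIR itself this logic would normally live in a C++ rewrite pattern applied by the greedy pattern driver. The toy checks shapes rather than reassociation indices, and it ignores the `outs` accumulator, which a real pattern would also have to rank-reduce.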
Thank you in advance for any insights or suggestions!