Tile linalg.batch_matmul

Hello, everyone.
I'm new to MLIR, and I'm currently trying to tile a linalg.batch_matmul on a big tensor into several smaller tensors, then sum or concat all the results from the small batch_matmuls.
Such as: [b, m, k] · [b, k, n] = [b, m, n], tiled into [b/x, m/y, k/z] · [b/x, k/z, n/q] = [b/x, m/y, n/q], then summing all the [b/x, m/y, n/q] results back into [b, m, n].
E.g.:

func.func @main(%arg0 : tensor<1x50x16xf32>, %arg1 : tensor<1x16x32xf32>, %arg2 : tensor<1x50x32xf32>) -> tensor<1x50x32xf32> {
  %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x50x16xf32>, tensor<1x16x32xf32>) outs(%arg2: tensor<1x50x32xf32>) -> tensor<1x50x32xf32>
  return %1 : tensor<1x50x32xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %func : (!transform.op<"func.func">) -> !transform.any_op
    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
      by tile_sizes = [1, 25, 8, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

I've already tried transform.structured.tile_using_for with the expected tile_sizes, but the result seemed incorrect to me, as in this post: Question about linalg matmul tile method - MLIR - LLVM Discussion Forums.
Then I tried transform.structured.tile_reduction_using_for to do the tiling.
Got this:

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  func.func @main(%arg0: tensor<1x50x16xf32>, %arg1: tensor<1x16x32xf32>, %arg2: tensor<1x50x32xf32>) -> tensor<1x50x32xf32> {
    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    %c25 = arith.constant 25 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %c50 = arith.constant 50 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1x50x32x4xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x50x32x4xf32>) -> tensor<1x50x32x4xf32>
    %2 = scf.for %arg3 = %c0 to %c50 step %c25 iter_args(%arg4 = %1) -> (tensor<1x50x32x4xf32>) {
      %3 = scf.for %arg5 = %c0 to %c32 step %c8 iter_args(%arg6 = %arg4) -> (tensor<1x50x32x4xf32>) {
        %4 = scf.for %arg7 = %c0 to %c16 step %c4 iter_args(%arg8 = %arg6) -> (tensor<1x50x32x4xf32>) {
          %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg7] [1, 25, 4] [1, 1, 1] : tensor<1x50x16xf32> to tensor<1x25x4xf32>
          %extracted_slice_0 = tensor.extract_slice %arg1[0, %arg7, %arg5] [1, 4, 8] [1, 1, 1] : tensor<1x16x32xf32> to tensor<1x4x8xf32>
          %extracted_slice_1 = tensor.extract_slice %arg8[0, 0, 0, 0] [1, 25, 8, 4] [1, 1, 1, 1] : tensor<1x50x32x4xf32> to tensor<1x25x8x4xf32>
          %5 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<1x25x4xf32>, tensor<1x4x8xf32>) outs(%extracted_slice_1 : tensor<1x25x8x4xf32>) {
          ^bb0(%in: f32, %in_2: f32, %out: f32):
            %6 = arith.mulf %in, %in_2 : f32
            %7 = arith.addf %out, %6 : f32
            linalg.yield %7 : f32
          } -> tensor<1x25x8x4xf32>
          %inserted_slice = tensor.insert_slice %5 into %arg8[0, 0, 0, 0] [1, 25, 8, 4] [1, 1, 1, 1] : tensor<1x25x8x4xf32> into tensor<1x50x32x4xf32>
          scf.yield %inserted_slice : tensor<1x50x32x4xf32>
        }
        scf.yield %4 : tensor<1x50x32x4xf32>
      }
      scf.yield %3 : tensor<1x50x32x4xf32>
    }
    %reduced = linalg.reduce ins(%2 : tensor<1x50x32x4xf32>) outs(%arg2 : tensor<1x50x32xf32>) dimensions = [3]
      (%in: f32, %init: f32) {
        %3 = arith.addf %in, %init : f32
        linalg.yield %3 : f32
      }
    return %reduced : tensor<1x50x32xf32>
  }
}

It's correct! But the result uses a linalg.generic to do the computation on the small tensors, which loses the semantic information of the matmul.
Could you provide suggestions or point me towards resources that might help achieve this? Any insights or examples would be greatly appreciated.

Thank you in advance for your time and assistance.

@asiemien @rolfmorel @chelini

Hi @Lvhuichen

Could you elaborate on what’s incorrect for you when using transform.structured.tile_using_for %1 tile_sizes [1, 25, 8, 4]?

When I run the following file with mlir-opt -transform-interpreter -cse -canonicalize %s

func.func @main(%arg0 : tensor<1x50x16xf32>, %arg1 : tensor<1x16x32xf32>, %arg2 : tensor<1x50x32xf32>) -> tensor<1x50x32xf32> {
  %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x50x16xf32>, tensor<1x16x32xf32>) outs(%arg2: tensor<1x50x32xf32>) -> tensor<1x50x32xf32>
  return %1 : tensor<1x50x32xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %func : (!transform.op<"func.func">) -> !transform.any_op
    %1, %loops:4 = transform.structured.tile_using_for %0
      tile_sizes [1, 25, 8, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

I get the following IR, which looks as intended to me and keeps the linalg.batch_matmul intact:

  func.func @main(%arg0: tensor<1x50x16xf32>, %arg1: tensor<1x16x32xf32>, %arg2: tensor<1x50x32xf32>) -> tensor<1x50x32xf32> {
    %c0 = arith.constant 0 : index
    %c50 = arith.constant 50 : index
    %c32 = arith.constant 32 : index
    %c16 = arith.constant 16 : index
    %c25 = arith.constant 25 : index
    %c8 = arith.constant 8 : index
    %c4 = arith.constant 4 : index
    %0 = scf.for %arg3 = %c0 to %c50 step %c25 iter_args(%arg4 = %arg2) -> (tensor<1x50x32xf32>) {
      %1 = scf.for %arg5 = %c0 to %c32 step %c8 iter_args(%arg6 = %arg4) -> (tensor<1x50x32xf32>) {
        %2 = scf.for %arg7 = %c0 to %c16 step %c4 iter_args(%arg8 = %arg6) -> (tensor<1x50x32xf32>) {
          %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg7] [1, 25, 4] [1, 1, 1] : tensor<1x50x16xf32> to tensor<1x25x4xf32>
          %extracted_slice_0 = tensor.extract_slice %arg1[0, %arg7, %arg5] [1, 4, 8] [1, 1, 1] : tensor<1x16x32xf32> to tensor<1x4x8xf32>
          %extracted_slice_1 = tensor.extract_slice %arg8[0, %arg3, %arg5] [1, 25, 8] [1, 1, 1] : tensor<1x50x32xf32> to tensor<1x25x8xf32>
          %3 = linalg.batch_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<1x25x4xf32>, tensor<1x4x8xf32>) outs(%extracted_slice_1 : tensor<1x25x8xf32>) -> tensor<1x25x8xf32>
          %inserted_slice = tensor.insert_slice %3 into %arg8[0, %arg3, %arg5] [1, 25, 8] [1, 1, 1] : tensor<1x25x8xf32> into tensor<1x50x32xf32>
          scf.yield %inserted_slice : tensor<1x50x32xf32>
        }
        scf.yield %2 : tensor<1x50x32xf32>
      }
      scf.yield %1 : tensor<1x50x32xf32>
    }
    return %0 : tensor<1x50x32xf32>
  }

I also checked that this IR still gives the same result on some random matrices as the original IR with just the linalg.batch_matmul. (Small note: I did not check on latest llvm-project/main - so it could be broken there, in principle.)


Thank you for your reply, @rolfmorel.
I got the same result as you.

%3 = linalg.batch_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<1x25x4xf32>, tensor<1x4x8xf32>) outs(%extracted_slice_1 : tensor<1x25x8xf32>) -> tensor<1x25x8xf32>

It does the matmul on the small tensors tensor<1x25x4xf32> and tensor<1x4x8xf32>, and gets a tensor<1x25x8xf32> result. That's all right.
BUT the problem is here:

%inserted_slice = tensor.insert_slice %3 into %arg8[0, %arg3, %arg5] [1, 25, 8] [1, 1, 1] : tensor<1x25x8xf32> into tensor<1x50x32xf32>

It inserts the result into the original tensor directly, which doesn't match the rules of matmul.
We should concat 4 tensor<1x25x8xf32> into a tensor<1x25x32xf32>, then sum 4 of those tensor<1x25x32xf32> into one tensor<1x25x32xf32>, and finally concat 2 tensor<1x25x32xf32> into tensor<1x50x32xf32>.
I hope I can describe this question clearly, so I wrote this simple, illegal tiling pseudocode below.

// expected tile loop,  illegal, just for example
%lhs : [1, 50, 16]
%rhs : [1, 16, 32]
%out : [1, 50, 32]

for iter0 (0 to 50 step=25):
  %concat0 = alloc:[1, 25, 32]
  for iter1 (0 to 32 step=16):
    %sum0 = alloc:[1, 25, 16], init=0
    for iter2 (0 to 16 step=8):
      %0 = extract[1, 25, 8] from [1, 50, 16] by(iter0, iter1, iter2)
      %1 = extract[1, 8, 16] from [1, 16, 32] by(iter0, iter1, iter2)
      %result = matmul([1, 25, 8], [1, 8, 16]) to [1, 25, 16]
      %sum1 = %sum0 + %result
      write %sum1 to %sum0
    write %sum0 to %concat0[1, 25, iter1]
  write %concat0 to %out[1, iter0, 32]

I hope I haven’t made any common sense errors and have clearly described this issue.

I am not sure what you mean by it being wrong. I read through the linked post and I don't see anything wrong with it.

It is one way of doing tiling. The first method just tiles the "parallel" iteration dimensions of the matrix multiply, while this transform op tiles both the parallel and reduction dimensions; it is effectively implementing split-k for batch matmul operations. It loses semantic information because the operation used for the split part does not have the same indexing maps and iterator types that a batch matmul has, so it cannot use a batch_matmul op for it.
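
For reference, this is roughly what linalg.batch_matmul looks like when written out as a linalg.generic (a sketch of its generic form, with illustrative SSA names):

#map_lhs = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map_rhs = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map_out = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// Three parallel dimensions (batch, m, n) and one reduction dimension (k).
%res = linalg.generic {indexing_maps = [#map_lhs, #map_rhs, #map_out], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%lhs, %rhs : tensor<1x50x16xf32>, tensor<1x16x32xf32>) outs(%acc : tensor<1x50x32xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
  %mul = arith.mulf %in, %in_0 : f32
  %add = arith.addf %out, %mul : f32
  linalg.yield %add : f32
} -> tensor<1x50x32xf32>

Compare this with the 4-d generic in the split IR above: all four of its iterators are parallel and its output map is (d0, d1, d2, d3), which no batch_matmul can express.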

I am not sure what you are expecting… I think you are reading the IR wrong?

I think this is the key here. There are multiple ways of tiling that operation and this is just one of them. The problem is that the transform is called tile_using_for, which leaves to the imagination what that actually means, even if the tile sizes give some indication.

Perhaps we should be thinking of some public API for transforms where generic names do generic things and specific names call the generic functions with specific parameters.

Here, we could have a tile_k and tile_batch calling tile_using_for with the right parameters (and maybe even doing a bit more on the side).
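
Something along these lines, sketched as named sequences living next to @__transform_main (hypothetical: @tile_k and @tile_batch do not exist upstream, and the tile-size positions assume batch_matmul's [batch, m, n, k] iterator order):

// Hypothetical wrapper: tile only the reduction (k) dimension.
transform.named_sequence @tile_k(%op: !transform.any_op {transform.consumed}) {
  %tiled, %loop = transform.structured.tile_using_for %op tile_sizes [0, 0, 0, 4]
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  transform.yield
}

// Hypothetical wrapper: tile only the batch dimension.
transform.named_sequence @tile_batch(%op: !transform.any_op {transform.consumed}) {
  %tiled, %loop = transform.structured.tile_using_for %op tile_sizes [1, 0, 0, 0]
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  transform.yield
}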

With many of us working in the upstream library of transforms and schedule composition, this was bound to be a problem sooner or later. @rolfmorel is already working on that area, but this needs more eyes from outside of our group to make sure the design we come up with fits the expectations of other users.

Hi @MaheshRavishankar. Thank you for your reply.
Yep, I had a misunderstanding before.
Consider a standard matrix multiplication A·B = C. When we tile this operation, it can be visualized as:

+----+----+     +----+----+     +-------------+-------------+
| A1 | A2 |     | B1 | B2 |     | A1.B1+A2.B3 | A1.B2+A2.B4 |
+----+----+ dot +----+----+  =  +-------------+-------------+
| A3 | A4 |     | B3 | B4 |     | A3.B1+A4.B3 | A3.B2+A4.B4 |
+----+----+     +----+----+     +-------------+-------------+

Back to MLIR. When applying this concept to linalg.batch_matmul, I was expecting an explicit addition operation within the innermost loop where %3 is calculated.
I thought %3 = linalg.batch_matmul(...) corresponded to A1.B1 above, so in the next iteration we should add %3 to the result of A2.B3.
But the result below just uses tensor.insert_slice to write the result back to %arg2, with no explicit add operation. So I was confused and came here for help.

// from this
func.func @main(%arg0 : tensor<1x50x16xf32>, %arg1 : tensor<1x16x32xf32>, %arg2 : tensor<1x50x32xf32>) -> tensor<1x50x32xf32> {
  %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x50x16xf32>, tensor<1x16x32xf32>) outs(%arg2: tensor<1x50x32xf32>) -> tensor<1x50x32xf32>
  return %1 : tensor<1x50x32xf32>
}

//to this
func.func @main(%arg0: tensor<1x50x16xf32>, %arg1: tensor<1x16x32xf32>, %arg2: tensor<1x50x32xf32>) -> tensor<1x50x32xf32> {
  %c0 = arith.constant 0 : index
  %c50 = arith.constant 50 : index
  %c32 = arith.constant 32 : index
  %c16 = arith.constant 16 : index
  %c25 = arith.constant 25 : index
  %c8 = arith.constant 8 : index
  %0 = scf.for %arg3 = %c0 to %c50 step %c25 iter_args(%arg4 = %arg2) -> (tensor<1x50x32xf32>) {
    %1 = scf.for %arg5 = %c0 to %c32 step %c16 iter_args(%arg6 = %arg4) -> (tensor<1x50x32xf32>) {
      %2 = scf.for %arg7 = %c0 to %c16 step %c8 iter_args(%arg8 = %arg6) -> (tensor<1x50x32xf32>) {
        %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg7] [1, 25, 8] [1, 1, 1] : tensor<1x50x16xf32> to tensor<1x25x8xf32>
        %extracted_slice_0 = tensor.extract_slice %arg1[0, %arg7, %arg5] [1, 8, 16] [1, 1, 1] : tensor<1x16x32xf32> to tensor<1x8x16xf32>
        %extracted_slice_1 = tensor.extract_slice %arg8[0, %arg3, %arg5] [1, 25, 16] [1, 1, 1] : tensor<1x50x32xf32> to tensor<1x25x16xf32>
        %3 = linalg.batch_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<1x25x8xf32>, tensor<1x8x16xf32>) outs(%extracted_slice_1 : tensor<1x25x16xf32>) -> tensor<1x25x16xf32>
        // <<I previously expected an add operation here: add(%3, %extracted_slice_1)>>
        %inserted_slice = tensor.insert_slice %3 into %arg8[0, %arg3, %arg5] [1, 25, 16] [1, 1, 1] : tensor<1x25x16xf32> into tensor<1x50x32xf32>
        scf.yield %inserted_slice : tensor<1x50x32xf32>
      }
      scf.yield %2 : tensor<1x50x32xf32>
    }
    scf.yield %1 : tensor<1x50x32xf32>
  }
  return %0 : tensor<1x50x32xf32>
}

However, after lowering linalg.batch_matmul to loops, it became clear that the outs(%result) parameter serves as an accumulator.
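
(For reference, one way to get such loops, assuming a recent mlir-opt; the exact flags are an assumption on my part and may vary across versions:)

mlir-opt input.mlir --one-shot-bufferize="bufferize-function-boundaries" --convert-linalg-to-loops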

  func.func @main(%arg0: memref<1x50x16xf32>, %arg1: memref<1x16x32xf32>, %arg2: memref<1x50x32xf32>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c50 = arith.constant 50 : index
    %c32 = arith.constant 32 : index
    %c16 = arith.constant 16 : index
    scf.for %arg3 = %c0 to %c1 step %c1 {
      scf.for %arg4 = %c0 to %c50 step %c1 {
        scf.for %arg5 = %c0 to %c32 step %c1 {
          scf.for %arg6 = %c0 to %c16 step %c1 {
            %0 = memref.load %arg0[%arg3, %arg4, %arg6] : memref<1x50x16xf32>
            %1 = memref.load %arg1[%arg3, %arg6, %arg5] : memref<1x16x32xf32>
            
             // load value from outs()
            %2 = memref.load %arg2[%arg3, %arg4, %arg5] : memref<1x50x32xf32>
            %3 = arith.mulf %0, %1 : f32
            
            // add operation is here.
            %4 = arith.addf %2, %3 : f32
            memref.store %4, %arg2[%arg3, %arg4, %arg5] : memref<1x50x32xf32>
          }
        }
      }
    }
    return
  }

This realization helped me understand why there wasn't an explicit add operation in the high-level MLIR representation. I finally got it.

E.g. linalg.batch_matmul ins(%lhs, %rhs) outs(%result):

// what linalg.batch_matmul does:
%temp_result = matmul(%lhs, %rhs)
%final_result = %temp_result + %result
write %final_result to %result

// not:
%temp_result = matmul(%lhs, %rhs)
write %temp_result to %result

So the tiling result from tile_using_for is correct! The slice extracted from %arg8 already holds the partial sum from the previous k iteration, and the accumulation happens inside the small batch_matmul through its outs operand, so the insert_slice alone is enough.
The writer of this post (Question about linalg matmul tile method - #4 by xinyuGG) may have had the same misunderstanding as me. :face_with_peeking_eye:
As a beginner, thank you sincerely! :kissing_heart: :smiling_face_with_three_hearts: :partying_face: @rengolin @MaheshRavishankar @rolfmorel


Thanks for circling back, and great explanation. Lowering to loops is a good thing to suggest to newbies. Will keep that in mind.