Can linalg.softmax be tiled?

Hi,

I tried to use the transform dialect to tile and fuse the linalg.conv_2d_nhwc_hwcf, linalg.softmax, and linalg.add operators, but the result is not what I expected.

Here is my code:

module {
  func.func @conv2d_softmax_add(%input: tensor<4x66x82x64xf32>, %input1: tensor<4x64x80x64xf32>, %filter: tensor<3x3x64x64xf32>, %elementwise: tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32> {
    %init = tensor.empty() : tensor<4x64x80x64xf32>
    %conv = linalg.conv_2d_nhwc_hwcf
          {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
          ins(%input, %filter : tensor<4x66x82x64xf32>, tensor<3x3x64x64xf32>)
          outs(%init : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>

    %init1 = tensor.empty() : tensor<4x64x80x64xf32>
    %softmax = linalg.softmax
        dimension(3)
        ins(%conv: tensor<4x64x80x64xf32>) 
        outs(%init1 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>

    %init2 = tensor.empty() : tensor<4x64x80x64xf32>
    %add = linalg.add
        ins(%softmax, %conv : tensor<4x64x80x64xf32>, tensor<4x64x80x64xf32>) 
        outs(%init2 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>

    return %add : tensor<4x64x80x64xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %add = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %softmax = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op

      // Tile add
      %tiledOp_0, %forOp_0 = transform.structured.tile_using_for %add [1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

      // Fuse (conv & softmax) into scf.for
      %fused_1_0, %containing_1_0 = transform.structured.fuse_into_containing_op %softmax into %forOp_0
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      %fused_1_1, %containing_1_1 = transform.structured.fuse_into_containing_op %conv into %forOp_0
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

      // apply dce to func
      %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      transform.apply_dce to %func : !transform.any_op

      transform.yield
    }
  }
}

The op graph corresponding to the input is shown below:

   Conv2d
    |   \
    |  Softmax
    |   /
    add

After running the mlir-opt --transform-interpreter --cse linalg_conv2d_softmax_add.mlir command, the result is shown below:

module {
  func.func @conv2d_softmax_add(%arg0: tensor<4x66x82x64xf32>, %arg1: tensor<4x64x80x64xf32>, %arg2: tensor<3x3x64x64xf32>, %arg3: tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32> {
    %0 = tensor.empty() : tensor<4x64x80x64xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %1 = scf.for %arg4 = %c0 to %c4 step %c1 iter_args(%arg5 = %0) -> (tensor<4x64x80x64xf32>) {
      %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg2 : tensor<4x66x82x64xf32>, tensor<3x3x64x64xf32>) outs(%0 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>
      %3 = linalg.softmax dimension(3) ins(%2 : tensor<4x64x80x64xf32>) outs(%0 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>
      %extracted_slice = tensor.extract_slice %3[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
      %extracted_slice_0 = tensor.extract_slice %arg0[%arg4, 0, 0, 0] [1, 66, 82, 64] [1, 1, 1, 1] : tensor<4x66x82x64xf32> to tensor<1x66x82x64xf32>
      %extracted_slice_1 = tensor.extract_slice %arg2[0, 0, 0, 0] [3, 3, 64, 64] [1, 1, 1, 1] : tensor<3x3x64x64xf32> to tensor<3x3x64x64xf32>
      %extracted_slice_2 = tensor.extract_slice %0[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
      %4 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x66x82x64xf32>, tensor<3x3x64x64xf32>) outs(%extracted_slice_2 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
      %extracted_slice_3 = tensor.extract_slice %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
      %5 = linalg.add ins(%extracted_slice, %4 : tensor<1x64x80x64xf32>, tensor<1x64x80x64xf32>) outs(%extracted_slice_3 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
      %inserted_slice = tensor.insert_slice %5 into %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<1x64x80x64xf32> into tensor<4x64x80x64xf32>
      scf.yield %inserted_slice : tensor<4x64x80x64xf32>
    }
    return %1 : tensor<4x64x80x64xf32>
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.add"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.match ops{["linalg.softmax"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %tiled_linalg_op, %loops = transform.structured.tile_using_for %0[1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1 into %loops : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      %fused_op_0, %new_containing_op_1 = transform.structured.fuse_into_containing_op %2 into %loops : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      %3 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_dce to %3 : !transform.any_op
      transform.yield 
    }
  }
}

As you can see, the softmax is not tiled, which is contrary to my expectations.

After adding --debug-only=linalg-transforms, the output is shown below:

[linalg-transforms]: Try to fuse a direct extract use
[linalg-transforms]: resultNumber: 0
[linalg-transforms]: Try to fuse an extract use through block argument
[linalg-transforms]: Try to fuse an use by cloning
[linalg-transforms]: resultNumber: 0
[linalg-transforms]: 
Fused an use by cloning
%5 = scf.for %arg4 = %c0 to %c4 step %c1 iter_args(%arg5 = %4) -> (tensor<4x64x80x64xf32>) {
  %6 = linalg.softmax dimension(3) ins(%1 : tensor<4x64x80x64xf32>) outs(%2 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>
  %extracted_slice = tensor.extract_slice %6[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %extracted_slice_0 = tensor.extract_slice %1[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %extracted_slice_1 = tensor.extract_slice %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %7 = linalg.add ins(%extracted_slice, %extracted_slice_0 : tensor<1x64x80x64xf32>, tensor<1x64x80x64xf32>) outs(%extracted_slice_1 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
  %inserted_slice = tensor.insert_slice %7 into %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<1x64x80x64xf32> into tensor<4x64x80x64xf32>
  scf.yield %inserted_slice : tensor<4x64x80x64xf32>
}
[linalg-transforms]: Try to fuse a direct extract use
[linalg-transforms]: resultNumber: 0
[linalg-transforms]: tiledProducer: %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x66x82x64xf32>, tensor<3x3x64x64xf32>) outs(%extracted_slice_2 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
[linalg-transforms]: 
Fused a direct extract use
%5 = scf.for %arg4 = %c0 to %c4 step %c1 iter_args(%arg5 = %4) -> (tensor<4x64x80x64xf32>) {
  %6 = linalg.softmax dimension(3) ins(%1 : tensor<4x64x80x64xf32>) outs(%2 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>
  %extracted_slice = tensor.extract_slice %6[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %extracted_slice_0 = tensor.extract_slice %arg0[%arg4, 0, 0, 0] [1, 66, 82, 64] [1, 1, 1, 1] : tensor<4x66x82x64xf32> to tensor<1x66x82x64xf32>
  %extracted_slice_1 = tensor.extract_slice %arg2[0, 0, 0, 0] [3, 3, 64, 64] [1, 1, 1, 1] : tensor<3x3x64x64xf32> to tensor<3x3x64x64xf32>
  %extracted_slice_2 = tensor.extract_slice %0[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x66x82x64xf32>, tensor<3x3x64x64xf32>) outs(%extracted_slice_2 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
  %extracted_slice_3 = tensor.extract_slice %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %8 = linalg.add ins(%extracted_slice, %7 : tensor<1x64x80x64xf32>, tensor<1x64x80x64xf32>) outs(%extracted_slice_3 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
  %inserted_slice = tensor.insert_slice %8 into %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<1x64x80x64xf32> into tensor<4x64x80x64xf32>
  scf.yield %inserted_slice : tensor<4x64x80x64xf32>
}
[linalg-transforms]: Try to fuse a direct extract use
[linalg-transforms]: Try to fuse an extract use through block argument
[linalg-transforms]: Try to fuse an use by cloning
[linalg-transforms]: resultNumber: 0
[linalg-transforms]: 
Fused an use by cloning
%5 = scf.for %arg4 = %c0 to %c4 step %c1 iter_args(%arg5 = %4) -> (tensor<4x64x80x64xf32>) {
  %6 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg2 : tensor<4x66x82x64xf32>, tensor<3x3x64x64xf32>) outs(%0 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>
  %7 = linalg.softmax dimension(3) ins(%6 : tensor<4x64x80x64xf32>) outs(%2 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>
  %extracted_slice = tensor.extract_slice %7[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %extracted_slice_0 = tensor.extract_slice %arg0[%arg4, 0, 0, 0] [1, 66, 82, 64] [1, 1, 1, 1] : tensor<4x66x82x64xf32> to tensor<1x66x82x64xf32>
  %extracted_slice_1 = tensor.extract_slice %arg2[0, 0, 0, 0] [3, 3, 64, 64] [1, 1, 1, 1] : tensor<3x3x64x64xf32> to tensor<3x3x64x64xf32>
  %extracted_slice_2 = tensor.extract_slice %0[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x66x82x64xf32>, tensor<3x3x64x64xf32>) outs(%extracted_slice_2 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
  %extracted_slice_3 = tensor.extract_slice %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
  %9 = linalg.add ins(%extracted_slice, %8 : tensor<1x64x80x64xf32>, tensor<1x64x80x64xf32>) outs(%extracted_slice_3 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
  %inserted_slice = tensor.insert_slice %9 into %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<1x64x80x64xf32> into tensor<4x64x80x64xf32>
  scf.yield %inserted_slice : tensor<4x64x80x64xf32>
}
module {
  func.func @conv2d_softmax_add(%arg0: tensor<4x66x82x64xf32>, %arg1: tensor<4x64x80x64xf32>, %arg2: tensor<3x3x64x64xf32>, %arg3: tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32> {
    %0 = tensor.empty() : tensor<4x64x80x64xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %1 = scf.for %arg4 = %c0 to %c4 step %c1 iter_args(%arg5 = %0) -> (tensor<4x64x80x64xf32>) {
      %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg2 : tensor<4x66x82x64xf32>, tensor<3x3x64x64xf32>) outs(%0 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>
      %3 = linalg.softmax dimension(3) ins(%2 : tensor<4x64x80x64xf32>) outs(%0 : tensor<4x64x80x64xf32>) -> tensor<4x64x80x64xf32>
      %extracted_slice = tensor.extract_slice %3[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
      %extracted_slice_0 = tensor.extract_slice %arg0[%arg4, 0, 0, 0] [1, 66, 82, 64] [1, 1, 1, 1] : tensor<4x66x82x64xf32> to tensor<1x66x82x64xf32>
      %extracted_slice_1 = tensor.extract_slice %arg2[0, 0, 0, 0] [3, 3, 64, 64] [1, 1, 1, 1] : tensor<3x3x64x64xf32> to tensor<3x3x64x64xf32>
      %extracted_slice_2 = tensor.extract_slice %0[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
      %4 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x66x82x64xf32>, tensor<3x3x64x64xf32>) outs(%extracted_slice_2 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
      %extracted_slice_3 = tensor.extract_slice %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<4x64x80x64xf32> to tensor<1x64x80x64xf32>
      %5 = linalg.add ins(%extracted_slice, %4 : tensor<1x64x80x64xf32>, tensor<1x64x80x64xf32>) outs(%extracted_slice_3 : tensor<1x64x80x64xf32>) -> tensor<1x64x80x64xf32>
      %inserted_slice = tensor.insert_slice %5 into %arg5[%arg4, 0, 0, 0] [1, 64, 80, 64] [1, 1, 1, 1] : tensor<1x64x80x64xf32> into tensor<4x64x80x64xf32>
      scf.yield %inserted_slice : tensor<4x64x80x64xf32>
    }
    return %1 : tensor<4x64x80x64xf32>
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.add"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.match ops{["linalg.softmax"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %tiled_linalg_op, %loops = transform.structured.tile_using_for %0[1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1 into %loops : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      %fused_op_0, %new_containing_op_1 = transform.structured.fuse_into_containing_op %2 into %loops : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      %3 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_dce to %3 : !transform.any_op
      transform.yield 
    }
  }
}

It seems that tiling the softmax is simply not performed, which contradicts my understanding: as far as I know, softmax can be tiled.

Is there a solution?

Thanks,

sheen

This was discussed at the EuroLLVM round table and we’re working on it.

Technically, linalg.softmax doesn’t yet have a formal definition. We reached consensus that the semantics of all named ops need to be encoded explicitly, not just be a product of their lowering, so we’ll have to first determine the formal semantics of this op before we can implement tiling for it.

But ultimately, softmax is a composition of element-wise operations (exp, div), reduction operations (max, sum) and broadcast operations. All of these operations can be tiled, but their tiling decisions may be different, depending on the memory model of the underlying system.
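
To make that concrete, writing x for the softmax input, y for its result, and taking the reduction along the last dimension c (as in your example, dimension 3), the usual numerically stable decomposition is:

  m(n, h, w)    = max over c of x(n, h, w, c)
  e(n, h, w, c) = exp(x(n, h, w, c) - m(n, h, w))
  s(n, h, w)    = sum over c of e(n, h, w, c)
  y(n, h, w, c) = e(n, h, w, c) / s(n, h, w)

The subtract, exp, and divide are element-wise; m and s are reductions over c that then have to be broadcast back to the full shape. A tiling strategy has to decide separately how to handle the element-wise parts and how to handle the two reductions.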

One simple technique is to isolate the reductions in the middle, fuse the pre-reduction operations with the preceding producer, and fuse the post-reduction operations with the following consumer. But that still leaves the middle part to handle, so we can’t simply break linalg.softmax into a tiled version of itself inside loops the way we can for element-wise operations.

Our proposed approach (at the round table) is to lower linalg.softmax to the canonical sequence of linalg operations, tile those, fuse them with producers/consumers as needed, and leave the reductions/broadcasts to be handled by special passes/transforms that know more about the target.
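
As a sketch only: assuming your MLIR build has the transform.structured.decompose_interface op (which expands aggregate ops such as linalg.softmax through their decomposition interface; please verify the op name and exact signature against your version), the script could be reshaped along these lines:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Expand linalg.softmax into its constituent linalg ops
    // (roughly: max-reduce, subtract, exp, sum-reduce, divide)
    // before any tiling takes place.
    %softmax = transform.structured.match ops{["linalg.softmax"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %decomposed = transform.structured.decompose_interface %softmax
      : (!transform.any_op) -> !transform.any_op

    // The pieces are now ordinary structured ops, so the usual
    // tile_using_for / fuse_into_containing_op flow from your script
    // applies to them as well as to the conv and add.
    %add = transform.structured.match ops{["linalg.add"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %tiled, %loop = transform.structured.tile_using_for %add [1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

After the decomposition, the resulting ops can be matched and tiled individually, so fuse_into_containing_op should be able to produce genuinely tiled versions of them instead of falling back to cloning.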

Thanks for your response. As you said, when I debugged it I found that tiling currently only supports the ops defined in LinalgStructuredOps.td, and your answer seems to confirm that conclusion. But I still have a question: although linalg.softmax encapsulates its internal computation, in my opinion it could still be tiled as a whole. The reduction dimension I specified is 3, while the tiling and fusion here only operate on dimensions 0, 1, and 2, so even taking its internal computation into account it should be possible to tile and fuse it as a whole. So I don’t understand why linalg.softmax doesn’t have a formal definition yet. Also, where can I find the material (videos or detailed notes) from the round table discussion on linalg.softmax?