Is synchronization missing for RAW-dependent ops during thread distribution inside IREE?

I am learning the code logic of IREE's CUDA backend, and I have a question about synchronization during distribution.
LLVMGPUTileAndDistribute is responsible for tiling and distributing to warps/threads inside a workgroup (a.k.a. a block in CUDA terminology).

For the matmul case (see below), there is one linalg.fill and one linalg.matmul. The linalg.fill initializes the output matrix buffer, and both ops are tiled and distributed to several warps.
There is a RAW (and also WAW) dependency between linalg.fill and linalg.matmul, and that dependency should be preserved after distribution to warps.
After reading the dump log and part of the IREE code, it seems the tiling and distribution is done for every op individually, and there is no attempt to check for and insert synchronization (__syncthreads in CUDA).

I am not sure whether this just happens to be correct (linalg.fill and linalg.matmul use the same tiling and distribution parameters, so dependent workloads are distributed to the same warp), or whether there is some logic that guarantees the synchronization is unnecessary.
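To state my concern concretely, here is a purely illustrative CUDA-style sketch (not code IREE generates; all names are made up). If the fill and the update of the same element were distributed with different thread-to-data mappings, a __syncthreads() between the two distributed ops would be mandatory:

__global__ void fill_then_update_mismatched(float *C, const float *P, int n) {
  // Purely illustrative: a "linalg.fill" distributed so that thread t owns
  // element t, followed by a "linalg.matmul"-like update distributed with a
  // reversed mapping, so thread t touches an element another thread filled.
  int t = threadIdx.x;
  if (t < n)
    C[t] = 0.0f;                    // distributed fill
  __syncthreads();                  // required here, because the mappings differ
  if (t < n)
    C[n - 1 - t] += P[n - 1 - t];   // distributed update with a different mapping
}

If instead both ops used the same mapping, each element's fill and update would happen in the same thread, and no barrier would be needed; I am guessing this is what the pass relies on.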

Thanks for any answers.

The case:

func.func @dot_1024() -> tensor<1024x1024xf32> {
  %lhs = util.unfoldable_constant dense<1.0> : tensor<1024x1024xf32>
  %rhs = util.unfoldable_constant dense<2.0> : tensor<1024x1024xf32>
  %res = "mhlo.dot"(%lhs, %rhs) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
  return %res: tensor<1024x1024xf32>
}

The command:

iree-translate --iree-input-type=mhlo --iree-hal-target-backends=cuda --iree-mlir-to-vm-bytecode-module  mhlo_dot1024.mlir -o mhlo_dot1024.linalg.vmfb  -mlir-print-ir-after-all -mlir-print-ir-before-all

IR before LLVMGPUTileAndDistribute; there are a linalg.fill and a linalg.matmul:

// -----// IR Dump Before LLVMGPUTileAndDistribute //----- //
func.func @dot_1024_dispatch_0() {
  %c1024 = arith.constant 1024 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<1024x1024xf32>
  memref.assume_alignment %0, 64 : memref<1024x1024xf32>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<1024x1024xf32>
  memref.assume_alignment %1, 64 : memref<1024x1024xf32>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<1024x1024xf32>
  memref.assume_alignment %2, 64 : memref<1024x1024xf32>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
  %4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_y]
  scf.for %arg0 = %3 to %c1024 step %4 {
    %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
    %6 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_x]
    scf.for %arg1 = %5 to %c1024 step %6 {
      %7 = memref.subview %0[%arg0, 0] [32, 1024] [1, 1] : memref<1024x1024xf32> to memref<32x1024xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
      %8 = memref.subview %1[0, %arg1] [1024, 128] [1, 1] : memref<1024x1024xf32> to memref<1024x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
      %9 = memref.subview %2[%arg0, %arg1] [32, 128] [1, 1] : memref<1024x1024xf32> to memref<32x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
      linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 128, 32]]>} ins(%cst : f32) outs(%9 : memref<32x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
      linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 128, 32]]>} ins(%7, %8 : memref<32x1024xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, memref<1024x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>) outs(%9 : memref<32x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
    }
  }
  return
}

IR after LLVMGPUTileAndDistribute; linalg.fill and linalg.matmul are each tiled and distributed to warps individually:

// -----// IR Dump After LLVMGPUTileAndDistribute //----- //
func.func @dot_1024_dispatch_0() {
  %c128 = arith.constant 128 : index
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %c1024 = arith.constant 1024 : index
  %c32 = arith.constant 32 : index
  %0 = memref.alloc() : memref<32x128xf32, 3>
  %1 = memref.alloc() : memref<32x32xf32, 3>
  %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<1024x1024xf32>
  memref.assume_alignment %2, 64 : memref<1024x1024xf32>
  %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<1024x1024xf32>
  memref.assume_alignment %3, 64 : memref<1024x1024xf32>
  %4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<1024x1024xf32>
  memref.assume_alignment %4, 64 : memref<1024x1024xf32>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
  %6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_y]
  scf.for %arg0 = %5 to %c1024 step %6 {
    %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
    %8 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_x]
    scf.for %arg1 = %7 to %c1024 step %8 {
      %9 = memref.subview %2[%arg0, 0] [32, 1024] [1, 1] : memref<1024x1024xf32> to memref<32x1024xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
      %10 = memref.subview %3[0, %arg1] [1024, 128] [1, 1] : memref<1024x1024xf32> to memref<1024x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
      %11 = memref.subview %4[%arg0, %arg1] [32, 128] [1, 1] : memref<1024x1024xf32> to memref<32x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
      %12 = gpu.thread_id  x
      %13 = gpu.thread_id  y
      %14 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%13]
      scf.for %arg2 = %14 to %c32 step %c32 {
        %15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%12]
        scf.for %arg3 = %15 to %c128 step %c128 {
          %16 = memref.subview %11[%arg2, %arg3] [4, 4] [1, 1] : memref<32x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
          linalg.fill {__internal_linalg_transform__ = "vectorize", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 128, 32]]>} ins(%cst : f32) outs(%16 : memref<4x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
        }
      }
      scf.for %arg2 = %c0 to %c1024 step %c32 {
        %15 = memref.subview %9[0, %arg2] [32, 32] [1, 1] : memref<32x1024xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<32x32xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
        %16 = memref.subview %10[%arg2, 0] [32, 128] [1, 1] : memref<1024x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<32x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
        gpu.barrier
        memref.copy %15, %1 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<32x32xf32, 3>
        memref.copy %16, %0 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<32x128xf32, 3>
        gpu.barrier
        %17 = gpu.thread_id  x
        %18 = gpu.thread_id  y
        %19 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%18]
        scf.for %arg3 = %19 to %c32 step %c32 {
          %20 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%17]
          scf.for %arg4 = %20 to %c128 step %c128 {
            %21 = memref.subview %1[%arg3, 0] [4, 32] [1, 1] : memref<32x32xf32, 3> to memref<4x32xf32, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>
            %22 = memref.subview %0[0, %arg4] [32, 4] [1, 1] : memref<32x128xf32, 3> to memref<32x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>
            %23 = memref.subview %11[%arg3, %arg4] [4, 4] [1, 1] : memref<32x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
            linalg.matmul {__internal_linalg_transform__ = "vectorize", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 128, 32]]>} ins(%21, %22 : memref<4x32xf32, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, memref<32x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>) outs(%23 : memref<4x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
          }
        }
      }
    }
  }
  return
}

@ThomasRaoux

@chongxing, you are correct: we don't insert synchronization and don't explicitly check that we don't need it. By construction we apply the same distribution to all the ops here, so we know that each thread/warp accesses the same data for each op.

This relies on assumptions about how dispatch regions are created and how the configuration is set.
Ideally we would do tile-and-fuse instead of tiling operations individually, which would enforce the correctness of the transformation more strongly, but in general we want to tile the reduction dimension first so that we can promote subviews to shared memory.
We could potentially add a verifier to make sure the assumption always holds, but this hasn't been a strong concern so far. That being said, extra comments would probably be useful for understanding the logic, and I'll look at documenting this part more.
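To make the reasoning concrete, here is a rough CUDA-flavored sketch of what the distributed IR above corresponds to. This is not actual IREE output; the tile sizes match the dump (32x128 block tile, K-tile of 32, 4x4 per thread) and a 32x8 thread block is assumed, everything else is illustrative. The only barriers are the ones protecting the cooperative copies into shared memory, which guard a different (all-threads-write / all-threads-read) dependency; the fill -> matmul dependency stays within each thread by construction.

__global__ void dot_1024_tile(float *C, const float *A, const float *B) {
  __shared__ float As[32][32];     // memref.alloc : memref<32x32xf32, 3>
  __shared__ float Bs[32][128];    // memref.alloc : memref<32x128xf32, 3>
  const int rowBlock = blockIdx.y * 32;
  const int colBlock = blockIdx.x * 128;
  const int row0 = rowBlock + threadIdx.y * 4;   // this thread's 4x4 C sub-tile
  const int col0 = colBlock + threadIdx.x * 4;
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;  // 0..255

  // Distributed linalg.fill: each thread zero-fills its own 4x4 sub-tile.
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      C[(row0 + i) * 1024 + (col0 + j)] = 0.0f;

  for (int k0 = 0; k0 < 1024; k0 += 32) {
    __syncthreads();   // first gpu.barrier
    // Cooperative copies to workgroup memory (the memref.copy ops).
    for (int i = tid; i < 32 * 32; i += 256)
      As[i / 32][i % 32] = A[(rowBlock + i / 32) * 1024 + (k0 + i % 32)];
    for (int i = tid; i < 32 * 128; i += 256)
      Bs[i / 128][i % 128] = B[(k0 + i / 128) * 1024 + (colBlock + i % 128)];
    __syncthreads();   // second gpu.barrier

    // Distributed linalg.matmul: the same thread that filled this sub-tile
    // accumulates into it, so the fill -> matmul dependency needs no barrier.
    for (int kk = 0; kk < 32; ++kk)
      for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
          C[(row0 + i) * 1024 + (col0 + j)] +=
              As[threadIdx.y * 4 + i][kk] * Bs[kk][threadIdx.x * 4 + j];
  }
}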

I hope this clarifies things.

@ThomasRaoux Thanks for the explanation. It's very clear.

Could I ask another question? You have achieved a lot in optimizing matmul for CUDA. Have you thought about implementing an efficient conv2d for CUDA inside IREE? (This is what drove me to learn a bit more about linalg, which then led to the question above.)

As I believe you know, implicit gemm is one efficient way to implement conv on GPU, and it seems non-trivial to express in linalg. I once asked a related question and got some explanations/answers in "Is it possible to add parameter for indexing_maps of linalg.generic?": roughly, mod and div in the index mapping are not supported, since hyper-rectangular shapes are required for subview/tensor_insert/tensor_extract operations.
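To show where the mod and div come from, here is a rough CUDA-style sketch (my own illustration, assuming an NHWC layout, stride 1 and no padding) of the gather an implicit-gemm kernel performs instead of materializing the im2col matrix:

__device__ float implicit_gemm_a(const float *input, int m, int k,
                                 int H, int W, int C,
                                 int OH, int OW, int KW) {
  // Decode the virtual GEMM row index m into (n, oh, ow).
  int n  = m / (OH * OW);
  int oh = (m / OW) % OH;
  int ow = m % OW;
  // Decode the reduction index k into (kh, kw, c) -- the div/mod that do not
  // fit a plain affine indexing_map over hyper-rectangular subviews.
  int kh = k / (KW * C);
  int kw = (k / C) % KW;
  int c  = k % C;
  // Gather on the fly from the NHWC input; the im2col "A matrix" of shape
  // (N*OH*OW) x (KH*KW*C) is never materialized.
  return input[((n * H + (oh + kh)) * W + (ow + kw)) * C + c];
}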

Since I see several attractive features in IREE, I would like to revisit whether it is feasible to implement implicit gemm for GPU inside IREE. I have no answer yet: could the infrastructure be extended to support non-hyper-rectangular shapes, or is there some solution that does not break the hyper-rectangular assumption?
Thanks.

Great to hear that it piqued your interest. We do have thoughts on supporting convolutions but haven't dug into it as much, in order to stay focused.

Yes, we looked at implicit gemm in the past and some people are trying to revive it (cc @cbate, @nicolasvasilache). I'm not sure we would need mod and div support in the index mapping to be able to implement implicit gemm. I haven't looked at it in detail, but implicit gemm seems similar to im2col; it is just a matter of where it is applied, and im2col could be applied at different levels.

Glad to hear that you are looking into IREE. Feedback is very much welcome. Do you have a specific need for non-hyper-rectangular shapes, or is this mainly to support implicit gemm? If you would like more interactive discussions about IREE, feel free to drop into the IREE Discord:

IREE has an img2col transformation that demonstrates decomposing the conv2d into

GEMM A: input_tensor -> linalg.generic (img2col) -> reshape (m, k)
GEMM B: filter_tensor -> reshape (k, n)

Then both of those feed into linalg.matmul.

Note that the linalg.generic is effectively a pointwise operation/view into the input, but the read dimensions are not a “projected permutation”.
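For concreteness, here is a rough CUDA sketch of what that linalg.generic plus tensor.collapse_shape pair computes for an NHWC input (illustrative only; stride 1 and no padding assumed). The read at (oh + kh, ow + kw) is what makes the access pattern more than a projected permutation:

__global__ void img2col(const float *input, float *gemm_a,
                        int N, int H, int W, int C,
                        int OH, int OW, int KH, int KW) {
  // Materialize the explicit (N*OH*OW) x (KH*KW*C) GEMM-A matrix: each
  // element is a copy of one input element, gathered at (oh + kh, ow + kw).
  int M = N * OH * OW;
  int K = KH * KW * C;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= M * K) return;
  int m = idx / K, k = idx % K;
  int n = m / (OH * OW), oh = (m / OW) % OH, ow = m % OW;
  int kh = k / (KW * C), kw = (k / C) % KW, c = k % C;
  gemm_a[m * K + k] =
      input[((n * H + (oh + kh)) * W + (ow + kw)) * C + c];
}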

So you could conceivably ask whether one could get the "implicit" version by just running the linalg "tile and fuse on tensors" transform rooted at the linalg.matmul operation. Unfortunately, the reshape on the A operand will prevent you from pulling the linalg.generic directly into the tiled matmul, because it collapses several dimensions which are not projected permutations and thus can't be "pulled"/commuted with the linalg.generic.

There are other options, then. You could directly produce the correct IR at the block-tile level, but that would be a custom rewrite and would not really fit in with the rest of the tile-and-fuse sequence. Other options would be to change tiling in some manner, or to add a custom tiled op representing the img2col view. It would be cool if we could get the linalg_ext tiling interface upstream (I just asked about this in another thread today as well). @ThomasRaoux or @nicolasvasilache, do you know if/when that could happen?

Thanks for the invitation. I have logged in and will try to get involved.

That's a good analytical perspective; it helps me better understand the challenge.
Here is the IR after img2col (from the sample test):

  func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %0 = linalg.init_tensor [1, 14, 14, 3, 3, 4] : tensor<1x14x14x3x3x4xf32>
    %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x16x16x4xf32>) outs(%0 : tensor<1x14x14x3x3x4xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):
      linalg.yield %arg3 : f32
    } -> tensor<1x14x14x3x3x4xf32>
    %2 = tensor.collapse_shape %1 [[0, 1, 2], [3, 4, 5]] : tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
    %3 = tensor.collapse_shape %arg1 [[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
    %4 = tensor.collapse_shape %arg2 [[0, 1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<196x16xf32>
    %5 = linalg.matmul ins(%2, %3 : tensor<196x36xf32>, tensor<36x16xf32>) outs(%4 : tensor<196x16xf32>) -> tensor<196x16xf32>
    %6 = tensor.expand_shape %5 [[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
    return %6 : tensor<1x14x14x16xf32>
  }

Let me try to keep up with you. With regard to the "img2col view", there would be one op that converts A into the img2col view (replacing the linalg.generic and reshape above); we could extract a subview of the img2col view during tiling, and finally lower it when promoting to shared memory (or when writing back to global memory for the output). Is this understanding right?

Yes, but you can’t do an “output promote to shared memory” on an operation that is operating on tensor types, so the creation of shared memory buffers would need to happen by implementing the bufferization interfaces. After this, you would write a transform to convert the op into a set of vector dialect transfer_read ops.

So that’s a way you could get it working quickly, but this is not a particularly nice solution because it’s not very generic and you can’t hook into the “tile and fuse” flow of linalg.

I’m experimenting with other solutions which wouldn’t require a separate op, but I need some more time to verify which one would mesh well with the existing pointwise fusion transforms in linalg.

OK. As a solution, being able to hook into the "tile and fuse" flow of linalg would be better.
It's not clear to me whether that is possible: could the op be allowed to hook into extract_slice and some other interfaces?

This sounds very attractive. Would you try to directly pull the tensor.collapse_shape and the img2col linalg.generic into the matmul tile-and-fuse?

With regard to the img2col view, its "subview" is not hyper-rectangular; will this be one more issue?

[update] I have found the related RFC, "RFC for `TilingInterface` for tiling operations that dont fit into Linalg Structured Operation definition", so this looks like it is not an issue.

Actually, we are getting closer to being able to do that thanks to bufferize.alloc_tensor. This would still require adding memory-space information to the tensor type. This was discussed in the past, but maybe we are now reaching a good point in time to do it. As is often the case, initial prototyping could start by just adding an attribute on the alloc_tensor for now.

@matthias-springer
