[RFC] Simplifying Dynamic Shared Memory Access in GPU

I’ve been exploring the use of the GPU’s dynamic shared memory in MLIR. While the gpu.launch Op allows setting the size via the dynamic_shared_memory_size argument, accessing this dynamic shared memory is convoluted. I’m proposing a new Op, gpu.dynamic.shared.memory, that aims to simplify the use of dynamic shared memory.

Background

Dynamic shared memory is a powerful feature that enables shared memory to be allocated at runtime, as part of the kernel launch on the host. Afterwards, the memory can be accessed directly from the device. I believe a similar story exists for AMDGPU.

LLVM’s NVPTX backend kind of abuses a 0-sized array declared as a global object to access dynamic shared memory. Pointer arithmetic is done as usual via getelementptr. Here is a CUDA example compiled with clang.
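For illustration, here is a minimal sketch of the pattern (my own sketch, not the example referenced above): the extern __shared__ array is declared without a size, NVPTX models it as a 0-sized addrspace(3) global, and the real size is supplied by the host at launch time.

// Illustrative sketch: the extern array has no size; the actual size comes
// from the third launch-configuration parameter on the host.
__global__ void kernel(float *out) {
  extern __shared__ float shmem[];   // dynamic shared memory
  shmem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = shmem[threadIdx.x];
}

// Host side:
// kernel<<<gridDim, blockDim, /*dynamicSmemBytes=*/1024>>>(out);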

MLIR is great for hiding this kind of implementation detail, so I would like to leverage that.

Current Challenges in MLIR

Let me illustrate the challenges of using dynamic shared memory in MLIR with an example below. The process involves several steps:

  1. memref.global: create the 0-sized array NVPTX expects
  2. dynamic_shared_memory_size: set the size of dynamic shared memory
  3. memref.get_global: access the global symbol
  4. reinterpret_cast and subview: many Ops for pointer arithmetic

I’ve found these steps to be somewhat convoluted, especially the 1st and 2nd steps. Additionally, the 4th step requires a combination of memref.reinterpret_cast and multiple memref.subview Ops, which can be challenging to manage. Shared memory is a small memory space. The compiler always knows its size, so it can also know the offsets. I believe a simple offset value should be sufficient.

// Step 1. Create 0-sized global symbol. Manually set the alignment
memref.global "private" @dynamicShmem  : memref<0xf16, 3> { alignment = 16 }
func.func @main() {  
  // Step 2. Allocate shared memory
  gpu.launch blocks(...) threads(...) 
    dynamic_shared_memory_size %c10000 {
    // Step 3. Access the global object
    %shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    // Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations.
    %4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
    %5 = memref.subview %4[7, 0, 0][7, 64, 128][1,1,1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
    %6 = memref.subview %5[2, 0, 0][1, 64, 128][1,1,1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
    %7 = memref.subview %6[0, 0][64, 64][1,1]  : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
    %8 = memref.subview %6[32, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>    
    // Step 5. Use the shared memory
    "test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
    gpu.terminator
  }
  return
}

Proposal

I propose a new Op, gpu.dynamic.shared.memory, that simplifies the 1st, 3rd, and 4th steps. The implementation would be very straightforward. Below is an example IR code showcasing its usage:

func.func @main() {
    gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
        // Step 1: Obtain shared memory directly
        %7 = gpu.dynamic.shared.memory offset = 73728 : memref<64x64xf16, 3>
        %8 = gpu.dynamic.shared.memory offset = 77824 : memref<64x64xf16, 3>
        
        // Step 2: Utilize the shared memory
        "test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
        "test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
        gpu.terminator
    }
    return
}

Op Features

  1. No more 0-Sized Global Symbol Generation: The lowering will hide the 1st and 3rd steps.
  2. Simplified Shared Memory Access: No need for reinterpret_cast or subview. The offset argument will be sufficient.
  3. Compile-time Bound Check: The Op verifier checks offset < dynamic_shared_memory_size when both are compile-time constants.
  4. Runtime Bound Check: We can add a {dynamicBoundCheck} attribute that checks offset < dynamic_shared_memory_size at runtime. This is optional and definitely adds overhead, but it can be beneficial for debugging.

Example from IREE

IREE is another notable example that leverages dynamic shared memory in GPU programming. It has a custom approach, see the IR snippet below (see full example). It’s worth noting that this method involves the use of memref.alloc, which, from a semantic perspective, is somewhat unconventional since it implies dynamic memory allocation when memref.alloc is executed. In fact, dynamic memory allocation is done by the host.

This approach makes it hard to reuse existing shared memory, since it always requires a memref.alloc. I guess there is another pass or code that helps with that, or one can generate a cast.

%0 = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
memref.store %f0, %0[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
memref.dealloc %0 : memref<1xf32, #gpu.address_space<workgroup>>

To make this IR lower properly to LLVM, IREE has two passes.

1st Pass involves the following operations:

  1. Traverse memref.alloc() Ops with the #gpu.address_space<workgroup> address space.
  2. Generate llvm.global objects.
  3. Create llvm.addressof to access the global memory.
  4. Perform address arithmetic using llvm.gep.

2nd Pass

  • Removes memref.dealloc operations on #gpu.address_space<workgroup> memory. This is necessary because shared memory, once allocated, cannot be deallocated in the typical sense.

IREE could get rid of these two passes and take advantage of the new gpu.dynamic.shared.memory Op.

Update

I have updated the proposal. The offsets can be multi-dimensional and can be constants, dynamic SSA values, or a mix of both.

Let’s write the program above with that:

func.func @main() {
    gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
        %i = arith.constant 18 : index
        // Step 1: Obtain shared memory directly
        %7 = gpu.dynamic.shared.memory [%i,0,0] : memref<64x64xf16, 3>
        %c1 = arith.constant 1 : index
        %i2 = arith.addi %i, %c1 : index
        %8 = gpu.dynamic.shared.memory [%i2,0,0] : memref<64x64xf16, 3>
        
        // Step 2: Utilize the shared memory
        "test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
        "test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
        gpu.terminator
    }
    return
}

Offsets are used to generate the base pointer with llvm.getelementptr:

gpu.dynamic.shared.memory [%i,0,0] -> %basePtr = llvm.getelementptr %0[%i, 0, 0] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, !llvm.array<64 x array<64 x f16>>
.... build-memref-descriptor(%basePtr)

@MaheshRavishankar @qed

Do you need to subview? Can’t you reinterpret_cast directly into the same memref without subview?

I see the subview as a structured / safe way of getting the resulting memref, while the reinterpret_cast is just bypassing all this.

If so, the op you’re introducing is just fusing the “get_global” and “reinterpret_cast” as far as I can tell?

It might be achieved with a single reinterpret_cast or even using memref.view. My point here is that today this requires additional memref Ops, while all we really need is an offset.

Yes it does fuse them and it goes a step further. The operation will also:

  1. Generate a global symbol.
    a) Could set the correct alignment that enables vectorized ld/st
  2. Perform bound checks at compile-time, and maybe even runtime bound checks.

The Op isn’t large, but it will provide a canonical way of using dynamic shared memory and simplify it.

This proposal and the original example look dedicated to the CUDA programming model.
In OpenCL, for example, it’s valid to have multiple instances of memref.alloc <#gpu.address_space<workgroup>> with dynamic sizes on the host side, and the pointers can be passed as kernel arguments. Further lowering just needs to generate host-side clSetKernelArg calls with the correct information (likely just the sizes) before launching the kernel.

Ideally, we should assume that gpu.launch can use memref.alloc <#gpu.address_space<workgroup>> of

  • static size allocations inside the kernel and
  • dynamic size allocations from the host.

There could be a potential new analysis later in the pipeline to fold all the dynamically sized #gpu.address_space<workgroup> allocations and split them, with offsets, into one large allocation (and a global symbol, if needed). The total size of the merged allocation can then be used as the dynamic size parameter for the CUDA invocation.

I completely understand your concern, and I think it’s not just limited to CUDA. It’s important to note that the behavior of the new Op (gpu.dynamic.shared.memory) is a result of how the hardware works, rather than being specific to the programming model. Shared memory cannot be allocated dynamically inside the device. The kernel launch needs to allocate dynamic shared memory, and the kernel then uses this allocated memory on the device.

Additionally, it’s worth noting that the launch Op in the GPU dialect already has the dynamic_shared_memory_size attribute. This makes the gpu dialect an excellent choice for implementing dynamic shared memory operations, so I think having this new Op in the gpu dialect is a good fit.

Your suggestion sounds like a great idea; I think we should have it for other GPU programming models as well. However, it’s orthogonal to the proposed Op. The compiler doesn’t always have visibility into the host code generation, and it doesn’t necessarily have to. In that case it isn’t possible to detect memref.alloc <#gpu.address_space<workgroup>>, for example when the compiler only generates device code, and this is where the proposed Op comes in very handy.

FWIW, a static size might not be sufficient for static shared memory, especially on NVIDIA GPUs. When the size exceeds 48KB, it becomes necessary to allocate it dynamically. Additionally, the appropriate attribute needs to be set using the driver; see MLIR’s CUDA runtime code.
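For reference, here is a rough sketch of the driver API calls involved (my own illustration of the mechanism, not MLIR’s actual runtime code; error handling omitted):

// Using the CUDA driver API (cuda.h). Kernels must opt in before they can
// be launched with more than the default 48KB of dynamic shared memory.
CUfunction function = /* from cuModuleGetFunction */ nullptr;
CUstream stream = /* from cuStreamCreate */ nullptr;
void **kernelParams = nullptr;
int dynamicSmemBytes = 64 * 1024; // 64KB, above the 48KB default limit
cuFuncSetAttribute(function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                   dynamicSmemBytes);
cuLaunchKernel(function, /*grid=*/1, 1, 1, /*block=*/128, 1, 1,
               dynamicSmemBytes, stream, kernelParams, /*extra=*/nullptr);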

I’m generally +1 on removing the inherited LLVM quirk of using a zero-sized global and reading past the known static boundary. This is quite confusing and may get in the way of boundary checkers, for example.

I agree with @mehdi_amini that the proposed op looks like a fused version of get_global and reinterpret_cast, which is (1) incomplete compared to the reinterpret_cast and (2) omits the symbol name. Currently, the proposed op only accepts a static offset, but I think this will be considered a limitation very quickly. The next thing we will want is to accept a dynamic offset as an SSA value. And after that, we will want to compute which offset we need for a (potentially dynamically-shaped) column-major matrix with a stride along the innermost dimension (e.g., because it’s the real part of a complex number). At this point, we will have implemented the functionality of both reinterpret_cast and subview.

My suggestion would be to ditch the global symbol at memref level and instead have a gpu.default_dynamic_shared_memory : () -> memref<?xi8> that we can then memref.view and memref.subview appropriately.

The proposal also mentions excessively complex address computation due to multiple subviews. Assuming we can’t or don’t want to fold a chain of subviews into one (which is possible in theory, although it may be hindered by rank-reducing behavior), we still have the memref-expansion part of the lowering that replaces view/subview with explicit address arithmetic, which we can then optimize with canonicalizations and CSE. It was specifically introduced to simplify address computation in MLIR and not rely on LLVM being able to do it later. Isn’t that sufficient?

cc @qcolombet


It’s pretty similar on other GPU hardware. Shared memory sizes need to be decided before the launch at the latest, so that the HW sequencer can decide whether a warp can be scheduled or not.
OpenCL relies on the driver and compiler to work out the launch-time dynamically sized local buffers. I suppose the underlying implementation actually uses a single allocation and relative offsets for each buffer. So it’s a matter of the programming model.
I thought dynamic_shared_memory_size was also part of the proposal; I didn’t notice it’s already there.

Can’t we just have memref.alloc(%s) inside the kernel, with %s coming from dynamic_shared_memory_size? That could hopefully serve as the gpu.default_dynamic_shared_memory above.

Thanks everyone for your insight!

I’m glad to know that we’re aligned on hiding the zero-sized array.

I’ve actually implemented this with dynamic offsets as SSA values, which are then passed to llvm.getelementptr.

The shared memory is quite small, and the compiler knows its boundaries. Even if the offset is an SSA value, the compiler has knowledge of these boundaries. Given that, I am not sure we need to re-implement reinterpret_cast or subview. I cannot think of a use-case.

Your perspective is entirely valid, and it’s interesting to note that IREE adopts a similar approach, as I mentioned earlier. The limitation here is that it requires compiling host and device together. Also, it is semantically strange to me: it implies dynamic memory allocation when memref.alloc is executed, whereas in fact the dynamic memory allocation is done by the host.


That seems pretty clean to me, what is the downside of this @grypp ?

Sorry, I wasn’t clear about the memref.alloc

module {
  func.func @test(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>, %arg3: i32) {
    %c1 = arith.constant 1 : index
    gpu.launch blocks(%arg4, %arg5, %arg6) in (%arg10 = %c1, %arg11 = %c1, %arg12 = %c1) threads(%arg7, %arg8, %arg9) in (%arg13 = %c1, %arg14 = %c1, %arg15 = %c1) dynamic_shared_memory_size %arg3 {
      %0 = arith.index_cast %arg3 : i32 to index
      %alloc = memref.alloc(%0) : memref<?xf32, #gpu.address_space<workgroup>>
      %1 = "test.use.shared.memory"(%alloc) : (memref<?xf32, #gpu.address_space<workgroup>>) -> index
      gpu.terminator
    }
    return
  }
}

This is valid to have, but I’m not sure whether this still involves the host-part compilation or not.

I also agree with removing the global symbol, but I’m basically hoping to find a way to address this without introducing low-level operations in the gpu dialect. My understanding of the gpu dialect is that it represents what GPU compute APIs have in common, not the GPU hardware interface. Having an operation under the assumption that there’s only a single shared memory instance could be helpful for canonicalizing before lowering to the CUDA target, but I’m worried that the earlier lowering pipelines targeting gpu will come to assume that’s the standard for all GPUs.
I’m still not 100% sure about the scope; maybe it’s totally fine to have them as optional, and we just need to avoid them being misused in the lowerings/transforms.

FYI, technically shared memory allocation is done by the GPU HW (in the part variously called the controller/firmware/sequencer, depending on the vendor) according to the parameters passed by the host. Still, it’s fair to say it’s host-allocated, because the host determines the allocation sizes. So I think it’s just a subtle difference in viewpoint, and it doesn’t really matter because everything is done opaquely under the compute APIs.

It’s not clear to me from your description how that would work: what if we have many allocs and control flow: how do we compute the needed shared memory? Knowing that it’ll have to be some sort of symbolic expression if the memrefs aren’t statically shaped. This also does not work after outlining: the gpu dialect also models GPU kernels without the host involved.
(I don’t know how you connect all this to an OpenCL problem yet.)

Sorry, my previous example was actually not even valid for OpenCL.
This one should be fine.

module {
  func.func @test(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>, %arg3: index) {
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc(%arg3) : memref<?xf32, #gpu.address_space<workgroup>>
    gpu.launch blocks(%arg4, %arg5, %arg6) in (%arg10 = %c1, %arg11 = %c1, %arg12 = %c1) threads(%arg7, %arg8, %arg9) in (%arg13 = %c1, %arg14 = %c1, %arg15 = %c1) {
      %0 = "test.use.shared.memory"(%alloc) : (memref<?xf32, #gpu.address_space<workgroup>>) -> index
      gpu.terminator
    }
    return
  }
}

and after outlining,

module attributes {gpu.container_module} {
  func.func @test(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>, %arg3: index) {
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc(%arg3) : memref<?xf32, #gpu.address_space<workgroup>>
    gpu.launch_func  @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)  args(%alloc : memref<?xf32, #gpu.address_space<workgroup>>)
    return
  }
  gpu.module @test_kernel {
    gpu.func @test_kernel(%arg0: memref<?xf32, #gpu.address_space<workgroup>>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {
      %0 = "test.use.shared.memory"(%arg0) : (memref<?xf32, #gpu.address_space<workgroup>>) -> index
      gpu.return
    }
  }
}

An OpenCL kernel doesn’t need to know the total size; it just gets pointers as the arguments to the kernel. example (another example in OpenCL C, not the same as the above IR)
The implementation details may differ per implementation, but it could lower to dynamic offsets into a single allocation.
After compilation and before the kernel launch, the OpenCL driver collects the size information from the clSetKernelArg calls for all the arguments, calculates the total size, and passes it to the GPU HW on launch.

I’m still trying to understand the CUDA requirement better. Now, I think the outlining might be a good place to handle this, i.e., when gpu.launch specifies dynamic_shared_memory_size, create gpu.default_dynamic_shared_memory and corresponding views only in gpu.func.

In that way, we can still handle the single dynamic shared memory with offsets, while multiple dynamic shared memory allocations remain valid.

This is a great start. I’m on board with avoiding unnecessary additions if memref.view can solve everything. Nevertheless, I also appreciate the elegance of having a simplification Op. I experimented with dynamic SSA values to achieve the same outcome using the original proposal and the new one. To my mind, using memref.view seems a bit more convoluted compared to my original Op. Let’s decide on one of them, and I’ll proceed with the pull request.

Original idea

Simple: no need for view, subview, etc. The verifier could check for out-of-bounds accesses against the gpu.launch Op when it is visible.

%tc = arith.constant 2 : index 
scf.for %i = %c0 to %tc step %c1 {
  // Iteration-0 accesses [0:16384] and Iteration-1 accesses [16384:32768]
  %lhsSlice = gpu.dynamic.shared.memory [%i,0,0] : memref<128x64xf16,3>			
  %v1 = arith.addi %i, %tc : index
  // Iteration-0 accesses [32768:49152] and Iteration-1 accesses [49152:65536]
  %rhsSlice = gpu.dynamic.shared.memory [%v1,0,0] : memref<64x128xf16,3>		
  ...
  use(%lhsSlice, %rhsSlice)
}

@ftynse proposal

Requires calculating offsets in bytes, and it also relies on using memref.view. I’m uncertain whether the verifier is intended to check for out-of-bounds issues on the gpu.launch Op when it’s visible.

%dynamicMem = gpu.dynamic.shared.memory : memref<?xi8, 3>
%tc = arith.constant 2 : index 
%s1 = arith.constant 16384 : index 
%s2 = arith.muli %tc, %s1 : index 
scf.for %i = %c0 to %tc step %c1 {
  %v0 = arith.muli %s1, %i : index
  %v1 = arith.addi %v0, %s2 : index
  // Iteration-0 accesses [0:16384] and Iteration-1 accesses [16384:32768]  
  %lhsSlice = memref.view %dynamicMem[%v0][] : memref<?xi8, 3> to memref<128x64xf16, 3>  
  // Iteration-0 accesses [32768:49152] and Iteration-1 accesses [49152:65536]
  %rhsSlice = memref.view %dynamicMem[%v1][] : memref<?xi8, 3> to memref<64x128xf16, 3>
  ...
  use(%lhsSlice, %rhsSlice)
}

@jungpark, your example can indeed work well. However, for it to work, the compiler always needs the gpu.launch Op so it can perform the outlining.

BTW, how would you handle reusing the allocated memory for different data types, like going from memref<?xf32> to memref<?xf16>?

Just to comment on that, in general you need a subview because reinterpret_cast will not compute the offset properly.

reinterpret_cast sets the offset field in the new memref descriptor, without taking into account what the input memref had.

E.g.,

%mymemref = memref<..., [offset: 10]>
%newmemref = memref.reinterpret_cast to offset: [26], ... : memref<..., [offset:10]> -> memref<..., offset: [26]> 

In this snippet, newmemref will point at mymemref.base + 26, whereas if the intent was to move the offset by 26 (i.e., what a subview would have done: mymemref.base + mymemref.offset + 26), we’d have to redo the subview work “by hand”.
E.g.,

%mymemref = memref<..., [offset: 10]>
%base_offset = memref.get_offset %mymemref : index
%new_offset = arith.add %base_offset, 26
%newmemref = memref.reinterpret_cast to offset: [%new_offset], ... : memref<..., [offset:10]> -> memref<..., offset: [26]> 

The bottom line is eliminating the subview won’t make the code easier.

Now, in practice, if the offset is known at compile time we could avoid the subview, but I don’t believe it is a good thing to rely on. (In particular, it may not be safe to assume that the input shared pointer has an offset of 0.)

I have updated the proposal, see the **Update** section above. TL;DR: The offsets can be multi-dimensional and can be constant or dynamic SSA values, or a mix of both.

I think there are two different goals in this RFC:

  1. Remove LLVM’s weird quirk with this 0-sized global variable
  2. Make memref manipulations more elegant/easier to approach at the gpu dialect level

For #1 we all agree that this is needed and several solutions have been proposed. I, personally, like the modeling of shared pointers through arguments to the kernels. I feel that it eliminates the need for gpu.default_dynamic_shared_memory altogether, but I may be missing something.

For #2, while the goal may be legit, I think it has an implication on the whole gpu dialect, not just this particular operation. That being said, we may not have to even think about it if gpu.default_dynamic_shared_memory is deemed unnecessary. I.e., let’s maybe delay the conversation for #2 until we actually have a need for that.

Shared memory is represented as a global in PTX, thus it’s possible to get the pointer from device functions, e.g.:

int* __device__ getMem() {
  extern __shared__ int tmp[];
  return tmp;
}

So, passing a ptr as a kernel parameter is semantically incorrect on NVIDIA (I think on AMD too). I think this is not the case for Intel; see Intel’s workaround for SYCL lowering to CUDA:
https://intel.github.io/llvm-docs/design/CompilerAndRuntimeDesign.html#local-memory-support
I’m +1 on the op, as it covers 2 of the 3 cases.


I’m a bit confused; I thought of gpu-kernel-outlining, which lowers the gpu.launch op to gpu.launch_func and gpu.func. Or do you want to consume the gpu.launch op without outlining? I was trying to answer @mehdi_amini’s comment that it’s not going to work with outlining, with a correct example this time.

Now I’m assuming either version of gpu.dynamic.shared.memory is required to avoid the 0-sized global symbol when the compute API doesn’t natively support passing shared pointers via kernel arguments. (Adding such support in MLIR is still technically possible, but I’m not sure how much work is needed.)

So, I wonder if we can at least defer it to a slightly later stage, specifically to gpu.func; this way, we can keep both options. (Hoping CUDA also adds support for shared pointers as kernel arguments in the future.)

  1. In case we have the host part of the IR, dynamic memref.allocs are to be placed in the host part.
    In the outlining pass:
    1. When dynamic_shared_memory_size is provided, fold the host memref.allocs into the new gpu.dynamic.shared.memory Op(s).
    2. Without dynamic_shared_memory_size, keep the allocations and pass them as kernel arguments.
  2. A standalone gpu.func op is fine to assume either case, i.e., either gpu.dynamic.shared.memory is provided or dynamically sized memrefs are provided as the arguments.

  3. Everything in the gpu.launch op - I feel this is still too early a stage to lower the dynamic shared memory details. The gpu.launch op still abstracts the interface between host and GPU, so I think these need to stay opaque, but there might be a different view on this.

As I mentioned, it depends on the new feature for case 1-1.
In the 1-2 (OpenCL-like) case, it’d be a view over the kernel argument.

I appreciate @grypp kicking off this discussion. I’m trying not to miss anything I can contribute to, but I might be misunderstanding something, so please feel free to correct me.