I’ve been exploring the usage of the GPU’s dynamic shared memory in MLIR. While the gpu.launch Op allows setting the size via the dynamic_shared_memory_size argument, accessing this dynamic shared memory is convoluted. I’m proposing a new Op, gpu.dynamic.shared.memory, that aims to simplify the utilization of dynamic shared memory.
Background
Dynamic shared memory is a powerful feature that enables the allocation of shared memory at runtime, as part of the kernel launch on the host. Afterwards, the memory can be accessed directly from the device. I believe a similar story exists for AMDGPU.
LLVM’s NVPTX backend essentially abuses a 0-sized array as a global object to access dynamic shared memory; pointer arithmetic is then done as usual via getelementptr. Here is a CUDA example compiled with clang.
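To make the trick concrete, the same pattern can be written in MLIR’s LLVM dialect roughly as follows (a sketch; the symbol name, element type, and offset are made up for illustration):

// A 0-sized external array in address space 3 (shared memory) stands in for the dynamic buffer.
llvm.mlir.global external @dyn_shmem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x f16>

// Take its address and do plain pointer arithmetic to reach the desired offset.
%base   = llvm.mlir.addressof @dyn_shmem : !llvm.ptr<3>
%offset = llvm.mlir.constant(128 : i64) : i64
%slot   = llvm.getelementptr %base[%offset] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16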
MLIR is great for hiding this kind of implementation detail, so I would like to leverage that.
Current Challenges in MLIR
Let me illustrate the challenges of using dynamic shared memory in MLIR with an example below. The process involves several steps:
1. memref.global: Create the 0-sized array NVPTX expects.
2. dynamic_shared_memory_size: Set the size of the dynamic shared memory.
3. memref.get_global: Access the global symbol.
4. reinterpret_cast and subview: Many Ops for pointer arithmetic.
I’ve found these steps to be somewhat convoluted, especially the 1st and 2nd. Additionally, the 4th step requires a combination of memref.reinterpret_cast and multiple memref.subview Ops, which can be challenging to manage. Shared memory is a small memory space, and the compiler always knows its size, so it can compute the offsets; I believe a simple offset value should be sufficient.
// Step 1. Create a 0-sized global symbol. Manually set the alignment.
memref.global "private" @dynamicShmem : memref<0xf16, 3> { alignment = 16 }

func.func @main() {
  // Step 2. Allocate shared memory
  gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
    // Step 3. Access the global object
    %shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    // Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations.
    %4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<14x64x128xf16, 3>
    %5 = memref.subview %4[7, 0, 0][7, 64, 128][1, 1, 1] : memref<14x64x128xf16, 3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
    %6 = memref.subview %5[2, 0, 0][1, 64, 128][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
    %7 = memref.subview %6[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
    %8 = memref.subview %6[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>
    // Step 5. Use the shared memory
    "test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
    gpu.terminator
  }
  return
}
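For reference, the flat element offsets produced by the subview chain above work out as follows (each step along the outermost dimension skips 64 * 128 = 8192 f16 elements):

offset(%7) = (7 + 2) * 8192     = 73728
offset(%8) = 73728 + 32 * 128   = 77824

These are exactly the offsets used in the proposal below.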
Proposal
I propose a new Op, gpu.dynamic.shared.memory, that simplifies the 1st, 3rd, and 4th steps. The implementation would be very straightforward. Below is an example IR showcasing its usage:
func.func @main() {
  gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
    // Step 1: Obtain shared memory directly
    %7 = gpu.dynamic.shared.memory offset = 73728 : memref<64x64xf16, 3>
    %8 = gpu.dynamic.shared.memory offset = 77824 : memref<64x64xf16, 3>
    // Step 2: Utilize the shared memory
    "test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
    gpu.terminator
  }
  return
}
Op Features
- No more 0-Sized Global Symbol Generation: The lowering will hide the 1st and 3rd steps (a possible lowering is sketched after this list).
- Simplified Shared Memory Access: No need for reinterpret_cast or subview; the offset argument will be sufficient.
- Compile-time Bound Check: The Op verifier rejects an offset that exceeds dynamic_shared_memory_size when both are compile-time constants.
- Runtime Bound Check: We can add a {dynamicBoundCheck} attribute that checks the offset against dynamic_shared_memory_size at runtime. This is optional and definitely adds overhead, but it can be beneficial for debugging.
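As a rough sketch of what the lowering could look like (not a final design; the exact result layout is an open detail), the Op would expand to the memref.global pattern from the example above, so nobody has to write it by hand:

// Hypothetical lowering of:  %7 = gpu.dynamic.shared.memory offset = 73728 : memref<64x64xf16, 3>
memref.global "private" @dynamicShmem : memref<0xf16, 3> { alignment = 16 }   // emitted once per module
...
%base = memref.get_global @dynamicShmem : memref<0xf16, 3>
%7 = memref.reinterpret_cast %base to offset: [73728], sizes: [64, 64], strides: [64, 1] : memref<0xf16, 3> to memref<64x64xf16, strided<[64, 1], offset: 73728>, 3>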
Example from IREE
IREE is another notable example that leverages dynamic shared memory in GPU programming. It has a custom approach; see the IR snippet below (see full example). It’s worth noting that this method involves memref.alloc, which is somewhat unconventional from a semantic perspective, since it implies that the dynamic memory allocation happens when memref.alloc is executed. In fact, the dynamic memory allocation is done by the host.
This approach also makes it hard to reuse existing shared memory, since it always requires a memref.alloc; I guess another pass or piece of code helps with that, or one can generate a cast.
%0 = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
memref.store %f0, %0[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
memref.dealloc %0 : memref<1xf32, #gpu.address_space<workgroup>>
To make this IR lower properly to LLVM, IREE has two passes.
The 1st pass:
- Traverses memref.alloc() Ops with the address_space<workgroup> address space.
- Generates llvm.mlir.global objects.
- Creates llvm.mlir.addressof to access the global memory.
- Performs address arithmetic using llvm.getelementptr.
The 2nd pass removes the memref.dealloc Ops on #gpu.address_space<workgroup>. This is necessary because shared memory, once allocated, cannot be deallocated in the typical sense.
IREE could get rid of these two passes and take advantage of the new gpu.dynamic.shared.memory Op.
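With the proposed Op, the IREE snippet above could be expressed roughly as follows (a sketch; the offset of 0 is chosen arbitrarily for illustration):

%0 = gpu.dynamic.shared.memory offset = 0 : memref<1xf32, #gpu.address_space<workgroup>>
memref.store %f0, %0[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
// No memref.dealloc needed: the memory lives for the duration of the kernel launch.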
Update
I have updated the proposal. The offsets can now be multi-dimensional, and each offset can be a compile-time constant or a dynamic SSA value, or a mix of both.
Let’s rewrite the program above accordingly:
func.func @main() {
  gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
    %c1 = arith.constant 1 : index
    %i = arith.constant 18 : index
    // Step 1: Obtain shared memory directly
    %7 = gpu.dynamic.shared.memory [%i, 0, 0] : memref<64x64xf16, 3>
    %i2 = arith.addi %i, %c1 : index
    %8 = gpu.dynamic.shared.memory [%i2, 0, 0] : memref<64x64xf16, 3>
    // Step 2: Utilize the shared memory
    "test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
    gpu.terminator
  }
  return
}
The offsets are used to generate the base pointer with llvm.getelementptr:
gpu.dynamic.shared.memory [%i,0,0]  ->  %basePtr = llvm.getelementptr %0[%i, 0, 0] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, !llvm.array<64 x array<64 x f16>>
.... build-memref-descriptor(%basePtr)
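The rest of the memref descriptor could then be built following the usual memref-to-LLVM convention (allocated pointer, aligned pointer, offset, sizes, strides). A minimal sketch, assuming the 64x64xf16 result type from the example:

!desc = !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>

%c0 = llvm.mlir.constant(0 : index) : i64
%d0 = llvm.mlir.undef : !desc
%d1 = llvm.insertvalue %basePtr, %d0[0] : !desc   // allocated pointer
%d2 = llvm.insertvalue %basePtr, %d1[1] : !desc   // aligned pointer
%d3 = llvm.insertvalue %c0, %d2[2] : !desc        // offset (already folded into %basePtr)
// ... followed by inserting the static sizes [64, 64] and strides [64, 1] into fields 3 and 4.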