Allocating workgroup memory for use by mlir-cuda-runner

Hi all,

This may be a beginner question; apologies if it sounds too naive. There are two parts to my question.
1.) Suppose IR has been lowered from some high-level dialect to the GPU dialect. The gpu.launch op has been created by the lowering path and wants to use workgroup memory (shared memory in CUDA). Suppose that, to model this allocation of buffers in workgroup memory at some high-level dialect, a pass inserts an alloca op with memory attribution of 3. After lowering to the GPU dialect, it ends up in the body of the gpu.launch op. Now, when such code is run through mlir-cuda-runner, it simply crashes. I read that the correct way to use workgroup memory is to attach something like workgroup(%buffer : memref<32xf32>) to the gpu.func op. Is there a way to use an allocation inside the gpu.launch op? If there is no direct way, is there a pass that hoists such alloca ops to the gpu.func when the kernel is outlined, marks them as workgroup memory, and updates all their uses?
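
For reference, the attribution form I read about looks roughly like this (a sketch; the module, kernel, and buffer names and the size are placeholders):

gpu.module @kernels {
  // Workgroup memory is declared on the function itself rather than allocated
  // in the body; there is one buffer per kernel, shared by all threads of a
  // workgroup.
  gpu.func @my_kernel() workgroup(%buffer : memref<32xf32, 3>) kernel {
    // ... use %buffer ...
    gpu.return
  }
}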

2.) Is there a way to allocate memory in GPU global memory for the operands/outputs of an operation that is to be performed on the GPU? There is a gpu.host_register op, which translates to cuMemHostRegister in the case of CUDA, and that may not be the best choice. Instead, is there something, or some way, that would map allocations to cudaMalloc?
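
For reference, the gpu.host_register route looks roughly like this (a sketch based on the mlir-cuda-runner tests; the buffer size and names are placeholders):

%buf = alloc() : memref<16xf32>
%unranked = memref_cast %buf : memref<16xf32> to memref<*xf32>
// Pins the host allocation so the GPU can access it; this is what ends up as
// cuMemHostRegister, rather than a device-side cudaMalloc.
gpu.host_register %unranked : memref<*xf32>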

Thanks!

Can you please paste the crash trace with mlir-cuda-runner and alloca? Is it an assert with a message or a failed cast?

I think it is a failed assertion when converting to NVVM IR.
crash.txt (8.9 KB)

A pass that hoists alloca operations does not exist, but I would welcome such an addition. The reason why the gpu dialect uses function-level annotations is that shared memory (in the CUDA case) is actually modeled as a global value.

Also, note that the workgroup memory annotations are static memory, i.e., there is one buffer for the function and not one for each activation of the function.

Regarding the crash, it is likely the lowering of the alloca operation. While alloca should not occur in a gpu function, it should not crash either.

@csigg is currently revamping the gpu dialect (also for async support) and adding actual allocation is on the roadmap there. However, we currently have no lowering of alloc<...,3> to a gpu.alloc of some form.

Again, contributions in that area are very welcome.

@herhut Thanks for the reply!

The reason why I was thinking it should be present inside the kernel (or the top-level gpu.func op) was that it would resemble what is present in CUDA (declaring workgroup memory via __shared__ inside the kernel). Anyway, since it is not implemented currently, it would make more sense to implement a pass for hoisting rather than trying to make the allocations happen inside the kernels (I have little to no idea how many changes would be required for the latter, or whether it even makes sense).

Nice to know that there is someone already on it. If it is not too rude to ask, will the actual allocation feature be available in the near future (say, a few weeks)? Is there some estimated time?

Thanks!

It isn’t clear to me why an “alloca” obviously maps to workgroup memory. It seems conceptually private to the activation and maps more naturally in CUDA to “local memory” (or a promotion to registers) than to “shared memory”.

I think @navdeepkk also mentions that the memory space on the type being allocated is 3, which might be associated with “shared memory”. It isn’t natural for this memory space to have explicit deallocs, I assume, so they have to be allocas. I’m not sure though; it depends on what the contracts are with respect to the different memory spaces.

@navdeepkk I assume you meant “memory space” when you said “memory attribution”.

Yes, memory space is what I meant by memory attribution.

As @bondhugula also mentioned, since shared memory doesn’t have explicit deallocs (in CUDA), it seemed more appropriate to model shared memory in some higher-level (relative to GPU) dialect using alloca with a memory space of 3. If there is something more suited, can you please share it?

Sure, we can use the memory space to say that we want this to be in shared memory, but it isn’t obvious to me that this creates memory that is “shareable”, in the sense that the pointer returned by std.alloc would be the same in every activation.
I don’t think an LLVM alloca would do this, and I’m not sure an LLVM alloca would even work with this memory space and the PTX backend.

Why not model it simply as a global?
This is what CUDA Clang does when emitting an LLVM module, I believe, so I suspect it would match the PTX backend’s expectations.

+1 to this. I think this is what CUDA Clang (and nvcc) does. On the SPIR-V path, workgroup memory (which is the same as shared memory) is handled by lowering a memref with memory space 3: it is promoted to an spv.globalVariable at spv.module scope.

Re: alloca vs alloc, it is a bit strange. There are no deallocations w.r.t. shared memory, so that matches alloca semantics, but at the same time, an alloca is meant to be “deallocated” when it goes out of scope, with that memory then available for other objects (like a true stack object). That’s not the case for shared memory: its scope is the entire kernel execution. On the SPIR-V side we just handle alloc for now.
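
Roughly, the promotion on that path produces something like the following (a sketch; the variable name and shape are placeholders, and the exact syntax may differ between versions):

spv.module Logical GLSL450 {
  // A memref<32xf32, 3> used inside the kernel is promoted to a module-scope
  // global variable with Workgroup storage class.
  spv.globalVariable @__workgroup_mem : !spv.ptr<!spv.array<32 x f32>, Workgroup>
}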

Thanks for pointing that out. I get what you are saying. If modeled as a global, it will have to be hoisted as workgroup memory when lowered to the GPU dialect. As @MaheshRavishankar mentions, the spv dialect has spv.globalVariable; what is the right op to model this as a global if the spv path is not taken?

If modeled as a global (not sure which op), then this hoisting will no longer be for the alloca but for whatever op is used as the global. Will there be any other use case where alloca ops may require hoisting?

When we lower a gpu.func that has some workgroup memory attribution to LLVM, we create an LLVM global for the attributed memory. You can also look at the tests to get an idea of what the lowering looks like.

As @mehdi_amini mentioned, an alloca typically uses the frame of the activation to allocate memory. The hoisting here would be required because, for the shared address space, the memory is not in the frame but is allocated once for all activations (or threads). This needs to happen before launching the kernel, hence it cannot be done inside a function (or even in a function that is called from a kernel (top-level) function).
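
To illustrate, the lowering produces a module-level global in address space 3 roughly along these lines (a sketch; the name and exact type syntax depend on the MLIR version):

// Created from workgroup(%buf : memref<32xf32, 3>) on the gpu.func; the kernel
// body then obtains a pointer to this global via llvm.mlir.addressof instead
// of allocating anything in its own frame.
llvm.mlir.global internal @kernel_workgroup_mem() {addr_space = 3 : i32} : !llvm.array<32 x f32>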

Thanks! The pointers are very useful. I now see that the workgroup memory is converted to an llvm.mlir.global.

What I gather from this is that there is no wiring/pass that can convert something like:

module {
  func @main() {
    gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %22, %arg7 = %12, %arg8 = %c1_5)
               threads(%arg3, %arg4, %arg5) in (%arg9 = %c1_5, %arg10 = %c1_5, %arg11 = %c1_5) {
      %m0 = alloca() : memref<16x16xf32, 3>
      gpu.terminator
    }
    return
  }
}

To

module attributes {gpu.container_module} {
  func @main() {
    "gpu.launch_func"(%22, %12, %c1_5, %c1_5, %c1_5, %c1_5) {kernel = @main_kernel::@main_kernel} : (index, index, index, index, index, index) -> ()
    return
  }
  gpu.module @main_kernel {
    gpu.func @main_kernel() workgroup(%m0 : memref<16x16xf32, 3>) kernel {
      gpu.return
    }
  }
}

Is that correct?

I don’t read these as the same thing though: the alloca will allocate the memory for each thread, while the workgroup attribution allocates it once and provides the same memory to each thread.
