Hi all,
I’m trying to augment MemRefToSPIRV to lower memref.view to the corresponding spirv ops.
However, it’s unclear to me what the correct lowering should look like.
For example, take the toy IR below:
module attributes {} {
  func.func @foo(%arg0: memref<8192xi8, #spirv.storage_class<CrossWorkgroup>>) {
    %c0 = arith.constant 0 : index
    %view = memref.view %arg0[%c0][] : memref<8192xi8, #spirv.storage_class<CrossWorkgroup>> to memref<32x32xf32, #spirv.storage_class<CrossWorkgroup>>
    %load = memref.load %view[%c0, %c0] : memref<32x32xf32, #spirv.storage_class<CrossWorkgroup>>
    return
  }
}
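For reference, the semantics the toy IR asks for can be modeled outside MLIR. The following is a minimal Python sketch, not MLIR or SPIR-V code: it treats the i8 memref as a flat byte buffer and the memref.view result as a contiguous row-major f32 array starting at a byte shift. The helper name view_load and the little-endian f32 layout are assumptions made purely for illustration.

```python
import struct

F32_SIZE = 4  # bytes per f32 element


def view_load(buf, byte_shift, shape, indices):
    """Model a memref.load on a memref.view of a flat byte buffer.

    Hypothetical helper: assumes a contiguous row-major view and
    little-endian f32 storage.
    """
    # The view must fit inside the underlying byte buffer.
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    if byte_shift + num_elements * F32_SIZE > len(buf):
        raise ValueError("view does not fit in the source buffer")

    # Row-major linearization of the multi-dimensional index.
    linear = 0
    for idx, dim in zip(indices, shape):
        if not 0 <= idx < dim:
            raise IndexError("index out of bounds")
        linear = linear * dim + idx

    # Read four bytes at the computed byte offset and reinterpret them as f32.
    (value,) = struct.unpack_from("<f", buf, byte_shift + linear * F32_SIZE)
    return value


# 8192 x i8 source buffer, viewed as 32 x 32 x f32 at byte shift 0.
buffer = bytearray(8192)
struct.pack_into("<f", buffer, (1 * 32 + 2) * F32_SIZE, 1.5)
print(view_load(buffer, 0, (32, 32), (1, 2)))  # element [1, 2] of the view
```

Note that the 32x32xf32 view only covers 4096 of the 8192 bytes, so the same buffer could hold a second view at byte shift 4096.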
I would assume the resulting IR (after convert-memref-to-spirv) to look something like:
module attributes {} {
  func.func @foo(%arg0: memref<8192xi8, #spirv.storage_class<CrossWorkgroup>>) {
    %0 = unrealized_conversion_cast %arg0 : memref<8192xi8, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.struct<(!spirv.array<2048 x i32>)>, CrossWorkgroup>
    %c0 = arith.constant 0 : index
    %index_cast = unrealized_conversion_cast %c0 : index to i32
    // some spirv pointer arithmetic op(s)
    %base_ptr = <some_spirv_pointer_arithmetic_op> %0 : !spirv.ptr<!spirv.struct<(!spirv.array<2048 x i32>)>, CrossWorkgroup> to !spirv.ptr<!spirv.struct<(!spirv.array<1024 x i32>)>, CrossWorkgroup>
    %load = spirv.AccessChain %base_ptr[%index_cast, %index_cast]
    return
  }
}
I’m wondering if it’s even possible to do something like the above IR. Or maybe I need to convert memref.view to some other dialect which can then be transformed to spirv?
I’d appreciate any help! Thanks!
Right now we don’t support lowering memref.subview ops directly; we expect them to be folded away into the original memref allocation by the FoldMemRefAliasOps pass.
This is mostly historical, though: the initial support focused on Vulkan compute, where we cannot freely have a !spirv.struct containing pointers as local variables. We haven’t scaled out to support other lowering paths yet. I see that what you are doing above is actually for OpenCL, so it’s doable there.
@antiagainst
Right, I’m trying to do it for OpenCL. However, I notice that FoldMemRefAliasOps supports folding memref.subview but not memref.view.
Do you think it would be better to modify that pass to fold memref.view ops as well, or to add a conversion to the MemRefToSPIRV pass?
memref.view is something new. Looking at its definition, it actually supports element type casting. That means it can’t be folded into the original memref allocation, which could have a different element type. So the way to go would be to enhance the conversion to SPIR-V to support the memref descriptor approach for OpenCL. Concretely, that would mean 1) adding a conversion path to !spirv.ptr<!spirv.struct<.., !spirv.ptr, ..>>> in the SPIR-V type converter, gated on having the Kernel capability, and 2) adding the necessary lowering patterns for various memref ops to handle such cases, also gated on the Kernel capability.
Apologies for the dumb questions; I’m relatively new to MLIR. I have questions about the two points you made.
- add conversion path to !spirv.ptr<!spirv.struct<.., !spirv.ptr, ..>>>
  - Is this to provide a conversion of the memref function argument to the corresponding spirv argument?
  - If so, I think that is already being done by FuncToSPIRV, so I guess I may not need to do that after all?
- add necessary lowering patterns for various memref ops
  - Right, this is mostly what I’m looking at. Since I’m not well-versed in spirv, I’m not sure what the lowering path from memref.view to spirv would look like. More specifically, I’m not sure whether there is a way of going from !spirv.ptr<!spirv.struct<(!spirv.array<2048 x i32>)>, CrossWorkgroup> (the original consolidated buffer) to !spirv.ptr<!spirv.struct<(!spirv.array<1024 x i32>)>, CrossWorkgroup> (the view into the consolidated buffer), potentially followed by a spirv.Bitcast?
Some 2c here:
We are using the MLIR OpenCL SPIR-V path internally, and while we still represent memrefs as bare pointers, we have a pass to handle memref.subview (including dynamic shapes): https://github.com/numba/numba-mlir/blob/main/mlir/lib/Conversion/GpuToGpuRuntime.cpp#L713
We basically do all the memref offset calculations manually and then reconstruct the memref via memref.reinterpret_cast. The pass doesn’t support memref.view specifically, but that should be straightforward to add.
We should probably upstream this (among a million other things).
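The manual offset calculation mentioned above can be sketched as follows. This is not the code from the linked pass; it is a minimal Python model of the usual strided-memref arithmetic for a non-rank-reducing memref.subview, and the names StridedLayout and subview_layout are hypothetical.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class StridedLayout:
    """Offset, sizes, and strides of a memref, all counted in elements."""
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]


def subview_layout(src, offsets, sizes, steps):
    """Compute the layout of a subview of `src` (hypothetical helper).

    The result describes the offset, sizes, and strides one would pass to
    memref.reinterpret_cast on the source's base pointer.
    """
    # New base offset: source offset plus each subview offset scaled by
    # the corresponding source stride.
    new_offset = src.offset + sum(o * s for o, s in zip(offsets, src.strides))
    # New strides: source strides scaled by the subview steps.
    new_strides = tuple(step * s for step, s in zip(steps, src.strides))
    return StridedLayout(new_offset, tuple(sizes), new_strides)


# A contiguous 32x32 memref, then a 4x8 subview at [2, 3] with steps [1, 2].
source = StridedLayout(offset=0, sizes=(32, 32), strides=(32, 1))
print(subview_layout(source, offsets=(2, 3), sizes=(4, 8), steps=(1, 2)))
```

With dynamic shapes the same arithmetic would be emitted as index computations in the IR instead of being evaluated ahead of time.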
@antiagainst I have a basic question about the semantics of spirv.AccessChain and spirv.PtrAccessChain.
Say %0 has type !spirv.ptr<!spirv.array<2048xi32>, CrossWorkgroup>. Is it valid to do the following?
// get base pointer for view; return type is !spirv.ptr<i32, CrossWorkgroup>
%view_base_ptr = spirv.AccessChain %0[0] : !spirv.ptr<!spirv.array<2048xi32>, CrossWorkgroup>, i32
%ptr_to_elt = spirv.AccessChain %view_base_ptr[0] : !spirv.ptr<i32, CrossWorkgroup>, i32
%bitcast = spirv.Bitcast %ptr_to_elt : !spirv.ptr<i32, CrossWorkgroup> to !spirv.ptr<f32, CrossWorkgroup>
%load_elt = spirv.Load "CrossWorkgroup" %bitcast : f32
If not, given a !spirv.ptr<!spirv.array<2048xi32>, CrossWorkgroup>, how do I offset into it and get a !spirv.ptr<!spirv.array<1024xi32>, CrossWorkgroup>, i.e., translate memref.view to the corresponding spirv ops?
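One way to think about the arithmetic behind this question is sketched below, in Python rather than SPIR-V. It models the 2048 x i32 array as a list of 32-bit words, turns the byte shift of the view into a word offset, and reinterprets the loaded bits as f32. Whether this maps onto spirv.AccessChain, spirv.PtrAccessChain, or a pointer spirv.Bitcast is exactly the open question; the helper names are hypothetical, and the sketch assumes the byte shift is a multiple of 4 and that the view is contiguous and row-major.

```python
import struct


def i32_bits_to_f32(bits):
    """Reinterpret the bit pattern of a signed 32-bit integer as an f32."""
    return struct.unpack("<f", struct.pack("<i", bits))[0]


def f32_to_i32_bits(value):
    """Reinterpret the bit pattern of an f32 as a signed 32-bit integer."""
    return struct.unpack("<i", struct.pack("<f", value))[0]


def load_f32_from_i32_words(words, byte_shift, shape, indices):
    """Load one f32 element of a view laid over i32 storage (hypothetical helper)."""
    if byte_shift % 4 != 0:
        raise ValueError("byte shift must be a multiple of the 4-byte word size")
    # Base of the view, expressed in 32-bit words instead of bytes.
    word_offset = byte_shift // 4

    # Row-major linearization of the view indices.
    linear = 0
    for idx, dim in zip(indices, shape):
        linear = linear * dim + idx

    # Load the i32 word, then reinterpret its bits as f32.
    return i32_bits_to_f32(words[word_offset + linear])


# 2048 x i32 storage; a 32x32xf32 view starting at byte 4096 (word 1024).
words = [0] * 2048
words[1024 + 5] = f32_to_i32_bits(0.25)
print(load_f32_from_i32_words(words, 4096, (32, 32), (0, 5)))
```

Because f32 and i32 have the same width here, the element index within the view and the word index within the storage differ only by the constant word offset.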