[MLIR] Lower memref.view to spirv

Hi all,

I’m trying to extend MemRefToSPIRV to lower memref.view to the corresponding spirv ops.
However, it’s unclear to me what the correct lowering should look like.

For example, given the toy IR below:

module attributes {} {
  func.func @foo(%arg0: memref<8192xi8, #spirv.storage_class<CrossWorkgroup>>) {
    %c0 = arith.constant 0 : index
    %view = memref.view %arg0[%c0][] : memref<8192xi8, #spirv.storage_class<CrossWorkgroup>> to memref<32x32xf32, #spirv.storage_class<CrossWorkgroup>>
    %load = memref.load %view[%c0, %c0] : memref<32x32xf32, #spirv.storage_class<CrossWorkgroup>>
    return
  }
}

I would expect the resulting IR (after convert-memref-to-spirv) to look something like this:

module attributes {} {
  func.func @foo(%arg0: memref<8192xi8, #spirv.storage_class<CrossWorkgroup>>) {
    %0 = unrealized_conversion_cast %arg0 : memref<8192xi8, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.struct<(!spirv.array<2048 x i32>)>, CrossWorkgroup>
    %c0 = arith.constant 0 : index
    %index_cast = unrealized_conversion_cast %c0 : index to i32
    // some spirv pointer arithmetic op(s)
    %base_ptr = <some_spirv_pointer_arithmetic_op> %0 : !spirv.ptr<!spirv.struct<(!spirv.array<2048 x i32>)>, CrossWorkgroup> to !spirv.ptr<!spirv.struct<(!spirv.array<1024 x i32>)>, CrossWorkgroup>
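    %ptr = spirv.AccessChain %base_ptr[%index_cast, %index_cast] : !spirv.ptr<!spirv.struct<(!spirv.array<1024 x i32>)>, CrossWorkgroup>, i32, i32
    %word = spirv.Load "CrossWorkgroup" %ptr : i32
    %load = spirv.Bitcast %word : i32 to f32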
    return
  }
}
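To make the pointer arithmetic in the sketch above concrete, here is a small Python model (not MLIR) of the index math such a lowering would have to emit. It assumes three things, as in the expected IR: the 8192-byte buffer is addressed as 2048 i32 words, the 32x32xf32 view has a row-major identity layout, and the byte shift is 4-byte aligned. `view_word_index` is a hypothetical helper name.

```python
# Hypothetical model of the index math behind
#   memref.view %arg0[%byte_shift][] : memref<8192xi8> to memref<32x32xf32>
# (storage class omitted) when the source is addressed as 32-bit words.

WORD_BYTES = 4         # sizeof(i32) == sizeof(f32)
SRC_BYTES = 8192       # memref<8192xi8>
VIEW_SHAPE = (32, 32)  # memref<32x32xf32>, row-major identity layout


def view_word_index(byte_shift: int, i: int, j: int) -> int:
    """Return the index of the i32 word in the source buffer holding view[i, j]."""
    if byte_shift % WORD_BYTES != 0:
        raise ValueError("byte_shift must be 4-byte aligned to index i32 words")
    rows, cols = VIEW_SHAPE
    if not (0 <= i < rows and 0 <= j < cols):
        raise IndexError("view index out of bounds")
    if byte_shift < 0 or byte_shift + rows * cols * WORD_BYTES > SRC_BYTES:
        raise ValueError("view does not fit in the source buffer")
    base_word = byte_shift // WORD_BYTES  # first word of the view
    return base_word + i * cols + j       # row-major linearization


print(SRC_BYTES // WORD_BYTES)        # 2048 words in the source
print(VIEW_SHAPE[0] * VIEW_SHAPE[1])  # 1024 words in the view
print(view_word_index(0, 0, 0))       # 0
print(view_word_index(4096, 1, 2))    # 1024 + 32 + 2 = 1058
```

With a zero byte shift, as in the toy IR, the view covers words 0 to 1023 of the source. That is why a 1024-word array type shows up in the expected result.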

I’m wondering whether it’s even possible to do something like the above. Or do I need to convert memref.view to some other dialect first, which can then be lowered to spirv?

I’d appreciate any help! Thanks!

@antiagainst : ^

Right now we don’t support lowering memref.subview ops directly; we expect they are folded away into the original memref allocation using the FoldMemRefAliasOps pass.

This is mostly historical, though. The initial support focused on Vulkan compute, where we cannot freely have a !spirv.struct containing pointers as a local variable, and we haven’t extended it to other lowering paths yet. I see that what you are doing above is actually for OpenCL, so it’s doable there.

@antiagainst

Right, I’m trying to do it for OpenCL. However, I notice that FoldMemRefAliasOps supports folding memref.subview but not memref.view.

Do you think it would be better to modify that pass to fold memref.view as well, or to add a conversion to the MemRefToSPIRV pass?

memref.view is something new. Looking at its definition, it actually supports element type casting. That means it can’t be folded into the original memref allocation, which could have a different element type. So the way to go would be to enhance the conversion to SPIR-V to support the memref descriptor approach for OpenCL. Concretely, that would mean:

1) adding a conversion path to !spirv.ptr<!spirv.struct<(.., !spirv.ptr, ..)>> in the SPIR-V type converter, gated on the Kernel capability;
2) adding the necessary lowering patterns for the various memref ops to handle such cases, also gated on the Kernel capability.
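In the memref descriptor approach, a memref is not lowered to a bare pointer. Instead it is carried as a small struct holding a base pointer plus offset, sizes and strides. On the SPIR-V side, that is the !spirv.struct containing a !spirv.ptr mentioned above. The Python below is a minimal sketch that only models such a descriptor and how a memref.view could derive a new one from a 1-D i8 buffer. It assumes byte addresses and element-counted offsets and strides. `MemRefDescriptor`, `identity_strides` and `view` are hypothetical names, not MLIR APIs.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MemRefDescriptor:
    """Hypothetical model of a ranked, strided memref descriptor.

    base_addr:     byte address of the buffer the descriptor points at
    elem_bytes:    size of one element in bytes
    offset:        offset of element [0, ..., 0], in elements
    sizes/strides: per-dimension extents and strides, in elements
    """
    base_addr: int
    elem_bytes: int
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]

    def address_of(self, *indices: int) -> int:
        """Byte address of element [indices], with bounds checking."""
        if len(indices) != len(self.sizes):
            raise ValueError("rank mismatch")
        linear = self.offset
        for idx, size, stride in zip(indices, self.sizes, self.strides):
            if not 0 <= idx < size:
                raise IndexError("index out of bounds")
            linear += idx * stride
        return self.base_addr + linear * self.elem_bytes


def identity_strides(sizes: Tuple[int, ...]) -> Tuple[int, ...]:
    """Row-major strides (in elements) for an identity layout."""
    strides = []
    running = 1
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def view(src: MemRefDescriptor, byte_shift: int, elem_bytes: int,
         sizes: Tuple[int, ...]) -> MemRefDescriptor:
    """Model memref.view on a 1-D i8 source: the byte shift is folded into
    the base address and the result has offset 0 and identity strides."""
    if src.elem_bytes != 1 or len(src.sizes) != 1:
        raise ValueError("source must be a 1-D i8 memref")
    if byte_shift < 0 or byte_shift + math.prod(sizes) * elem_bytes > src.sizes[0]:
        raise ValueError("view does not fit in the source buffer")
    new_base = src.base_addr + src.offset + byte_shift
    return MemRefDescriptor(new_base, elem_bytes, 0, tuple(sizes),
                            identity_strides(sizes))


# Usage: view the 8192-byte buffer at byte 4096 as a 32x32 f32 matrix.
buf = MemRefDescriptor(base_addr=4096, elem_bytes=1, offset=0,
                       sizes=(8192,), strides=(1,))
mat = view(buf, byte_shift=4096, elem_bytes=4, sizes=(32, 32))
print(mat.strides)           # (32, 1)
print(mat.address_of(1, 2))  # 8192 + (1*32 + 2) * 4 = 8328
```

The sketch folds the byte shift into the base address rather than into `offset`. This is one way to cope with the element type change, since `offset` is counted in elements of the new type. It is similar in spirit to how the MemRefToLLVM lowering treats memref.view.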

Apologies if these are basic questions; I’m relatively new to MLIR. I have a few questions about the two points you made.

  • add conversion path to !spirv.ptr<!spirv.struct<(.., !spirv.ptr, ..)>>
    • Is this to convert the memref function argument to the corresponding spirv argument?
      • If so, I think FuncToSPIRV already does that, so I may not need to do it after all?
  • add necessary lowering patterns for various memref ops
    • Right, this is mostly what I’m looking at. Since I’m not very well-versed in spirv, I’m not quite sure what the lowering of memref.view to spirv should look like. More specifically, I’m not sure whether there is a way to go from !spirv.ptr<!spirv.struct<(!spirv.array<2048 x i32>)>, CrossWorkgroup> (the original consolidated buffer) to
      !spirv.ptr<!spirv.struct<(!spirv.array<1024 x i32>)>, CrossWorkgroup> (the view into the consolidated buffer), and whether that should then be followed by a spirv.Bitcast.
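A note on the bitcast part: in the expected IR, the buffer is typed as i32 words while the view's elements are f32. A same-width bitcast reinterprets the 32 bits unchanged; it does not convert the value numerically. The tiny Python illustration below uses only the standard `struct` module, not any SPIR-V API, and the two helper names are hypothetical.

```python
import struct


def bitcast_i32_to_f32(word: int) -> float:
    """Reinterpret the bits of a signed 32-bit integer as an IEEE-754 f32."""
    return struct.unpack("<f", struct.pack("<i", word))[0]


def bitcast_f32_to_i32(value: float) -> int:
    """Reinterpret the bits of an IEEE-754 f32 as a signed 32-bit integer."""
    return struct.unpack("<i", struct.pack("<f", value))[0]


print(bitcast_f32_to_i32(1.0))         # 1065353216 (0x3F800000)
print(bitcast_i32_to_f32(1065353216))  # 1.0
```

The same reinterpretation applies whether the bitcast is done on the loaded i32 value or on the pointer before the load.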

Some 2c here:

We are using the MLIR OpenCL SPIR-V path internally, and while we still represent memrefs as bare pointers, we have a pass that handles memref.subview (including dynamic shapes): https://github.com/numba/numba-mlir/blob/main/mlir/lib/Conversion/GpuToGpuRuntime.cpp#L713

We basically do all the memref offset calculations manually and then reconstruct the memref via memref.reinterpret_cast. It doesn’t support memref.view specifically, but it should be straightforward to add.

We should probably upstream this (among a million other things).
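Computing the offsets manually and then calling memref.reinterpret_cast boils down to standard strided-layout arithmetic. Below is a minimal Python sketch of that arithmetic for a non-rank-reducing memref.subview, assuming the source layout is known as an offset plus strides, all counted in elements. It is not the linked numba-mlir code, and `subview_layout` is a hypothetical name. With dynamic shapes, the same formulas would be emitted as index arithmetic on SSA values.

```python
from typing import Sequence, Tuple


def subview_layout(src_offset: int,
                   src_strides: Sequence[int],
                   offsets: Sequence[int],
                   sizes: Sequence[int],
                   steps: Sequence[int]) -> Tuple[int, Tuple[int, ...], Tuple[int, ...]]:
    """Return (offset, sizes, strides) of a non-rank-reducing subview, in elements.

    src_offset/src_strides describe the source's strided layout; offsets,
    sizes and steps are the subview's operands. The result could be fed to
    memref.reinterpret_cast on the source's base buffer.
    """
    if not len(src_strides) == len(offsets) == len(sizes) == len(steps):
        raise ValueError("rank mismatch")
    # The first element of the subview sits at the source's linearized
    # position of `offsets`.
    new_offset = src_offset + sum(o * s for o, s in zip(offsets, src_strides))
    # Stepping by `step` along a dimension multiplies that dimension's stride.
    new_strides = tuple(s * st for s, st in zip(src_strides, steps))
    return new_offset, tuple(sizes), new_strides


# Usage: memref<64x64xf32> (offset 0, strides [64, 1]),
# subview [4, 8] [16, 16] [1, 2].
print(subview_layout(0, (64, 1), (4, 8), (16, 16), (1, 2)))
# (264, (16, 16), (64, 2))
```

Because the result is again an (offset, sizes, strides) triple, nested subviews compose by feeding one result back in as the next source layout.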

@antiagainst I have a basic question about the semantics of spirv.AccessChain or spirv.PtrAccessChain.
Say I have

%0 : !spirv.ptr<!spirv.array<2048 x i32>, CrossWorkgroup>

Is it valid to do the following?

%c0 = spirv.Constant 0 : i32
// get base pointer for view; return type is !spirv.ptr<i32, CrossWorkgroup>
%view_base_ptr = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.array<2048 x i32>, CrossWorkgroup>, i32
%ptr_to_elt = spirv.AccessChain %view_base_ptr[%c0] : !spirv.ptr<i32, CrossWorkgroup>, i32
%bitcast = spirv.Bitcast %ptr_to_elt : !spirv.ptr<i32, CrossWorkgroup> to !spirv.ptr<f32, CrossWorkgroup>
%load_elt = spirv.Load "CrossWorkgroup" %bitcast : f32

If not, given a !spirv.ptr<!spirv.array<2048 x i32>, CrossWorkgroup>, how do I offset into it to get a !spirv.ptr<!spirv.array<1024 x i32>, CrossWorkgroup>? In other words, how do I translate memref.view to the corresponding spirv ops?