While working on GPU codegen I ran into a problem dealing with memrefs. When generating code for a matmul with promotion, I get the code below. The memref created by the promoted allocation is used by both linalg.copy and vector.transfer_read. When I lower the linalg.copy I want to load/store chunks of vector<4xi32>, as this is going to be the most efficient memory access on most GPUs. The transfer_read, however, needs to keep the type vector<8x32xi8> so it can potentially be mapped to the GPU-native cooperative matrix type when that is supported.
```mlir
%A = alloc() : memref<128x32xi8, 3>
linalg.copy(%0, %A) : memref<128x32xi8, #map0>, memref<128x32xi8, #map1, 3>
%S = subview %A[%arg6, 0] [64, 32] [1, 1] : memref<128x32xi8, #map1, 3> to memref<64x32xi8, #map1, 3>
%10 = vector.transfer_read %S[%c0, %c0], %c0_i8 : memref<64x32xi8, #map1, 3>, vector<8x32xi8>
```
In order to generate good code, I'm writing a transformation that would rewrite the code as follows.
```mlir
%A = alloc() : memref<128x2xvector<4xi32>, 3>
%22 = load %arg0[%20, %21] : memref<4096x256xvector<4xi32>>
store %22, %A[%17, %19] : memref<128x2xvector<4xi32>, 3>
%OriginalTypeA = memref_cast %A : memref<128x2xvector<4xi32>, 3> to memref<128x32xi8, 3>
%10 = vector.transfer_read %OriginalTypeA[%c0, %c0], %c0_i8 : memref<128x32xi8, 3>, vector<8x32xi8>
```
However, I need to be able to reinterpret the memref with a different shape and element type. This is analogous to the existing vector.type_cast used for vectorization on CPU, but I need a more relaxed operation, since I need to change the element type and the lowest dimension of the shape won't always match the vector size.
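To make the difference concrete, here is a minimal sketch: the first op shows the existing vector.type_cast semantics, while the second is purely hypothetical (placeholder name, generic op syntax) and only illustrates the relaxed semantics I'm after.

```mlir
// Existing op: vector.type_cast wraps the whole static shape into the vector
// element type; the scalar element type itself does not change.
%V = vector.type_cast %B : memref<8x8xi8> to memref<vector<8x8xi8>>

// Hypothetical relaxed cast (placeholder name/syntax): reinterprets both the
// element type and the innermost dimension, 2 x vector<4xi32> <-> 32 x i8.
%OriginalTypeA = "experimental.memref_element_cast"(%A)
    : (memref<128x2xvector<4xi32>, 3>) -> memref<128x32xi8, 3>
```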
One alternative I tried was to lower the copy to transfer_read/transfer_write of vector<16xi8> and lower that to a bitcast from i8* to vector<4xi32>* plus a load of <4xi32>, but SPIR-V for Vulkan doesn't allow casting pointers in the general case. It is supported for some operations, like cooperative matrix load, which is why being able to insert the memref cast explicitly is very useful.
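For context, a rough sketch of that intermediate form, with the loop structure and index computations elided; the index names here are just for illustration.

```mlir
// linalg.copy lowered to 1-D transfers of vector<16xi8> (loops/indices elided).
%v = vector.transfer_read %0[%i, %j], %c0_i8 : memref<128x32xi8, #map0>, vector<16xi8>
vector.transfer_write %v, %A[%i, %j] : vector<16xi8>, memref<128x32xi8, 3>
```

Lowering those transfers to vector<4xi32> accesses is where the pointer bitcast would have to happen.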
Is there any existing operation I can use to do this kind of transformation? Would adding such an operation make sense? I realize this may not be useful for non-GPU targets. Do you see any alternative solution?
I realize there are existing cast operations, and there is an RFC in flight ([RFC][Standard] Memref cast ops), but none of them seems to match what I need.
This is related to the current MLIR code review: D85058 [mlir][vector] Add experimental memref cast operation.