Memref cast

While working on GPU codegen I ran into a problem dealing with memrefs. When generating code for a matmul with promotion I get the code below. The memref created by the promoted allocation is used by both linalg.copy and vector.transfer_read. When I lower the linalg.copy I want to load/store chunks of vector<4xi32>, as this is going to be the most efficient memory access on most GPUs. The transfer_read, however, needs to keep the type vector<8x32xi8> so that it can potentially be mapped to the GPU-native cooperative matrix type when that is supported.

%A = alloc() : memref<128x32xi8, 3>
linalg.copy(%0, %A) : memref<128x32xi8, #map0>, memref<128x32xi8, #map1, 3>
%S = subview %A[%arg6, 0] [64, 32] [1, 1]  : memref<128x32xi8, #map1, 3> to memref<64x32xi8, #map1, 3>
%10 = vector.transfer_read %S[%c0, %c0], %c0_i8 : memref<64x32xi8, #map1, 3>, vector<8x32xi8>

In order to generate good code I’m writing a transformation that would change the code to the following.

%A = alloc() : memref<128x2xvector<4xi32>, 3>
%22 = load %arg0[%20, %21] : memref<4096x256xvector<4xi32>>
store %22, %A[%17, %19] : memref<128x2xvector<4xi32>, 3>
%OriginalTypeA = memref_cast %A : memref<128x2xvector<4xi32>, 3> to memref<128x32xi8, 3>
%10 = vector.transfer_read %OriginalTypeA[%c0, %c0], %c0_i8 : memref<128x32xi8, 3>, vector<8x32xi8>

However, I need to be able to reinterpret the memref with a different shape and element type. This is analogous to the existing vector.type_cast used for vectorization on CPU, but I need a more relaxed operation, as I need to change the element type and the lowest dimension of the shape won’t always match the vector size.
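To make the difference concrete, roughly (some.reinterpret_cast is a hypothetical placeholder name for the operation I am asking about):

// vector.type_cast wraps the whole memref shape into a single vector
// element and keeps the scalar element type:
%V = vector.type_cast %A : memref<128x32xi8, 3> to memref<vector<128x32xi8>, 3>
// What I need changes both the element type (i8 to vector<4xi32>) and the
// innermost dimension (32 to 2):
%C = some.reinterpret_cast %A : memref<128x32xi8, 3> to memref<128x2xvector<4xi32>, 3>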

One alternative I tried was to lower the copy to transfer_read/transfer_write of vector<16xi8> and lower that to a bitcast from i8* to vector<4xi32>* plus a load of <4xi32>, but SPIR-V for Vulkan doesn’t allow casting pointers in general. It is supported for some operations, like cooperative matrix load, which is why being able to insert the memref cast explicitly is very useful.
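Roughly, that rejected path looked like this (a sketch with schematic indices; the pointer bitcast only materializes after conversion out of the vector dialect):

// Copy lowered to element-type-preserving 16-byte transfers:
%v = vector.transfer_read %0[%i, %j], %c0_i8 : memref<128x32xi8, #map0>, vector<16xi8>
vector.transfer_write %v, %A[%i, %j] : vector<16xi8>, memref<128x32xi8, 3>
// Turning these into <4xi32> accesses requires an i8* to vector<4xi32>*
// bitcast, which SPIR-V for Vulkan disallows in the general case.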

Is there any existing operation I can use to do this kind of transformation? Would adding such an operation make sense? I realize this may not be useful for non-GPU targets. Do you see any alternative solution?

I realize there are existing cast operations and there is an RFC in flight: [RFC][Standard] Memref cast ops, but none of them seems to match what I need.

This is related to the current MLIR code review: D85058 [mlir][vector] Add experimental memref cast operation.

FYI: @nicolasvasilache, @aartbik, @mehdi_amini

@ThomasRaoux

Hi Thomas,

I can generalize memref_reinterpret_cast op proposed in [RFC][Standard] Memref cast ops. Currently, there is a constraint that the input/output element types should be the same, but it can be relaxed.

@ftynse, Alex, would you be fine with removing the requirement for matching element types? In the fully-static case we can actually verify that the total number of bytes is the same; in the dynamic case we can insert some assertions in a “debug” mode. If not, this op can be implemented outside of the Standard dialect, and I can help with that as well.
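For the example upthread, the fully-static check is simple byte accounting; a sketch using the syntax proposed in the RFC, with the element-type restriction lifted:

// 128 * 2 * sizeof(vector<4xi32>) = 128 * 2 * 16 = 4096 bytes
// 128 * 32 * sizeof(i8)           = 128 * 32 * 1 = 4096 bytes -> sizes match
%r = memref_reinterpret_cast %A to offset: [0], sizes: [128, 32], strides: [32, 1]
    : memref<128x2xvector<4xi32>, 3> to memref<128x32xi8, 3>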

Have you considered relaxing vector.type_cast to match your needs?

Not in the Standard dialect.

I think it makes sense. The question is more in which dialect it should live.

Add the “inverse-of-view” operation that transforms any memref into memref<?xi8> and then view it with a different element type. This sounds like actual reinterpret cast and may be less desirable than a more specific, vector-related transformation with stronger semantics (i.e., we can still reason about aliasing relatively easily with a special operation).
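A rough sketch of that alternative (inverse_view is a hypothetical placeholder; std.view already provides the second step for byte buffers):

// Flatten to bytes ("inverse_view" does not exist; placeholder only):
%flat = "inverse_view"(%A) : (memref<128x2xvector<4xi32>, 3>) -> memref<4096xi8, 3>
// Re-view the byte buffer with a different element type:
%c0 = constant 0 : index
%new = view %flat[%c0][] : memref<4096xi8, 3> to memref<128x32xi8, 3>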

It’s a thorny question. Do we really want to be able to bitcast, e.g., a memref<?xi64> to memref<?xf64>? The fact that SPIR-V, which this use case targets, does not allow such behavior is evidence that we may not want to allow it either. There is also the issue of dynamic offsets and strides. @ThomasRaoux, do you expect any memory reinterpretations other than splitting one memref dimension into a memref dimension and a vector dimension, and eventually joining them back together? If not, I would tend to prefer a dedicated operation that does exactly this over a big hammer that bitcasts anything to anything else.

Thanks Alexander and Alex.

That would be great if this solution works for others.

I quickly considered it. After talking with @nicolasvasilache, it seemed like adding a new op may be better since the vector-to-GPU path is still quite experimental.

No, that would be the main usage I’m expecting at this point. As I mentioned, SPIR-V for Vulkan is quite limited in the kind of pointer bitcasting it allows.

That works for me. That’s kind of what I was trying to do in the review I sent. Maybe the name of the op should be more specific.

I, too, would have a slight preference for trying to put this in an existing cast if it fits (since we already have vector.reshape/type_cast/shape_cast ops with somewhat specific behaviors, and I want to avoid having too many of those). But if you and Nicolas think a new op is better, I won’t object too strongly either.

This is actually trickier: you are not just adding a reinterpret cast, because SPIR-V does not let you do arbitrary bitcasts. What you are adding is a “transient op” that must fold away into a special memory op (whether with a cooperative type or not). The SPIR-V loads/stores have special semantics that perform “bitcast and load/store”; the bitcasts themselves cannot exist in the wild.

As I look at the implementation and the new bitcast-like op that comes out of it (with the restriction that it must fold away), I am wondering whether we should relax vector.transfer_read / write to allow changing the element type. This would more closely match the “bitcast + memory op” behavior of the SPIR-V ops.

vector.type_cast has always been an implementation detail of lowering to CPU but the semantics of transfers is higher level.

Making the vector.transfer op change the underlying vector interpretation would require relaxing the verifier here, but should otherwise be transparent. We will still have the caveat that memref<128x2xvector<4xi32>> and memref<128x32xi8> need the same per-element alignment on your target HW, but that generally seems unavoidable.
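Concretely, the earlier example would then boil down to something like this (a sketch; the exact verifier rules are what a review would need to settle):

%A = alloc() : memref<128x2xvector<4xi32>, 3>
// The transfer itself reinterprets the elements: the memref holds
// vector<4xi32> but the result has i8 elements (8x32 x i8 == 8x2 x vector<4xi32>),
// with the padding value given in the result element type.
%v = vector.transfer_read %A[%c0, %c0], %c0_i8 : memref<128x2xvector<4xi32>, 3>, vector<8x32xi8>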

I am wondering whether we should relax vector.transfer_read / write to allow changing the element type.

I actually like the idea of making vector.transfer_read/write the “Swiss Army Knife” for changing the view of memory. I am in the process of making particular memory operations directly available in the vector dialect (gather/scatter, expand load/compress store, masked load/store) with the idea that we can perhaps lower vector.transfer_read/write “progressively” at some point into more primitive operations. Other view changes could perhaps follow a similar path, with higher level dialects only bothering with generating the right vector.transfer_read/write.

I experimented with this solution and it makes the transformations I’m doing much cleaner. It avoids having a temporary cast that needs to be folded into the matrix load, and allows having the right indexing in transfer_read/write from the beginning.

I’ll send a new review with the relaxed verification of vector.transfer.

It would be good to just generalize memref_reinterpret_cast. In my fork of MLIR, I actually have a memref_shape_cast op that performs such a conversion in one direction, towards 1-D vectors, along with its LLVM lowering.

std.memref_shape_cast

The "memref_shape_cast" operation converts a memref from an non-vector
    element type to another memref of a vector elemental type while not changing
    the source memref's element type. The last dimension size of the source
    dimension is divided (floor division) by the vector size to obtain the
    corresponding dimension for target memref type.

    %MV = memref_shape_cast %M : memref<64x16xf32> to memref<64x2xvector<8xf32>>
    %AV = memref_shape_cast %A : memref<?x?xf32> to memref<?x?xvector<8xf32>>

Can you send a patch for review since you already implemented this?

Independently, note that this wouldn’t solve the issue at hand: such a memref_cast cannot be legalized by itself on certain backends and needs to be folded into the memory operations.
This folding could be done at conversion time but letting the memory op “bitcast on the fly” seems like a desirable abstraction to avoid requiring (multiple) HW-specific type conversions to be able to legalize the op.

How does this work when your ?x? is 1x1 at runtime?

The way it is, that would just be undefined behavior. We might as well add a std.assume/assert that (dim %A, 0, dim %A, 1) >= 8 in front of the memref_shape_cast op, since 0 in a shape isn’t allowed.
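A minimal sketch of such a guard, assuming std's assert op is available and checking the innermost dimension (the one the floor division applies to):

// The innermost dimension must be >= the vector size so that the result
// shape contains no 0.
%c1 = constant 1 : index
%c8 = constant 8 : index
%d1 = dim %A, %c1 : memref<?x?xf32>
%ok = cmpi "sge", %d1, %c8 : index
assert %ok, "innermost dimension smaller than vector size"
%AV = memref_shape_cast %A : memref<?x?xf32> to memref<?x?xvector<8xf32>>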

Sure - I’ll do this in a couple of days.

I sent a review for it: https://reviews.llvm.org/D85244

I’m not sure if this requires doing any extra checks during lowering, or if we want other restrictions on the transfer ops. Let me know what you think.

This would not fit our needs: we don’t want to copy the data, just reshape the lower dimensions. The transfer operations have different semantics, namely those of a copy.