[RFC][memref] New op to (re)materialize a memref

Hi,

TL;DR I’d like to add a new op, memref.construct_strided_metadata, that is the counterpart of memref.extract_strided_metadata, i.e., it materializes a memref from the provided metadata.

Context

In expand-strided-metadata we simplify “complex” memref operations by:

  1. Extracting the metadata of the source memref using memref.extract_strided_metadata
  2. Applying the effects of this complex operation to the metadata (offset, sizes, and strides) using affine.apply
  3. Rebuilding the resulting memref

E.g.,

%res = memref.subview %src, %subSizes, …

=>

%base, %offset, %sizes, %strides = memref.extract_strided_metadata(%src)
%final_sizes = %subSizes
%final_strides = <some math> %strides
%final_offset = <some math> %offset
%res = <rebuild memref descriptor> %base, %final_offset, %final_sizes, %final_strides
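To make the `<some math>` placeholders concrete, here is a minimal Python sketch of what steps 2 and 3 compute for a non-rank-reducing memref.subview. It models only the integer metadata, and it assumes the subview's offsets and strides operands (elided as "…" in the snippet above) are plain integers. `StridedMetadata` and `subview_metadata` are hypothetical names for illustration, not MLIR APIs.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class StridedMetadata:
    """Hypothetical stand-in for the offset/sizes/strides of a strided memref."""
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]


def subview_metadata(src: StridedMetadata,
                     sub_offsets: Sequence[int],
                     sub_sizes: Sequence[int],
                     sub_strides: Sequence[int]) -> StridedMetadata:
    """Apply a non-rank-reducing subview to the metadata only (no data is touched)."""
    assert len(sub_offsets) == len(sub_sizes) == len(sub_strides) == len(src.strides)
    # %final_offset: skip sub_offsets[i] elements along each dimension i.
    final_offset = src.offset + sum(o * s for o, s in zip(sub_offsets, src.strides))
    # %final_sizes: taken directly from the subview operands.
    final_sizes = tuple(sub_sizes)
    # %final_strides: each source stride is scaled by the subview's step.
    final_strides = tuple(s * t for s, t in zip(src.strides, sub_strides))
    return StridedMetadata(final_offset, final_sizes, final_strides)
```

For example, taking a 2x3 window with steps [1, 2] at offsets [1, 2] of a row-major 4x8 memref (`StridedMetadata(0, (4, 8), (8, 1))`) gives `StridedMetadata(10, (2, 3), (8, 2))`. The open question in this RFC is only which op should consume that result to rebuild the descriptor.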

The problem is that we currently don’t have any operation that can do step 3 (<rebuild memref descriptor> in the above snippet). The current implementation uses memref.reinterpret_cast, but it turns out this is not technically correct (see [MLIR] expand-strided-metadata's expansions are invalid · Issue #59896 · llvm/llvm-project · GitHub).

Proposed solution

To solve this issue, I’d like to introduce a new memref operation that would rebuild a memref descriptor from the given metadata:

memref.construct_strided_metadata base: %ZeroRanked, offset: [%offset], sizes: [%size0, …], strides: [%stride0, …] : memref<ty> -> memref<..x..xty, strided<[…], offset: …>>

The semantics are straightforward: given the base, offset, sizes, and strides, materialize a memref with exactly these values.
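A minimal Python model of these semantics, assuming the usual strided layout where element (i0, …, in-1) lives at `offset + sum(i_k * stride_k)` in the underlying buffer. The buffer is modeled as a flat Python list standing in for the zero-ranked base; `StridedMemRef` and `construct_strided_metadata` here are illustrative names, not MLIR code.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class StridedMemRef:
    """Hypothetical model of a strided memref: a view over a flat buffer."""
    base: List[float]          # the underlying allocation (the zero-ranked base)
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]

    def linear_index(self, indices: Tuple[int, ...]) -> int:
        # Strided layout: offset + sum(index_k * stride_k).
        assert len(indices) == len(self.sizes)
        assert all(0 <= i < n for i, n in zip(indices, self.sizes))
        return self.offset + sum(i * s for i, s in zip(indices, self.strides))

    def load(self, indices: Tuple[int, ...]) -> float:
        return self.base[self.linear_index(indices)]


def construct_strided_metadata(base, offset, sizes, strides) -> StridedMemRef:
    """Sketch of the proposed op: no arithmetic, just wrap the given values."""
    return StridedMemRef(base, offset, tuple(sizes), tuple(strides))
```

With `buf = [float(x) for x in range(12)]`, `construct_strided_metadata(buf, 1, (2, 3), (4, 1))` yields a 2x3 view whose element (1, 2) is `buf[1 + 4 + 2]`, i.e. `7.0`. The view aliases `buf` rather than copying it.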

The intent is to use this operation as the counterpart of memref.extract_strided_metadata.

In other words:

%base, %offset, %sizes, %strides = memref.extract_strided_metadata %src
%dst = memref.construct_strided_metadata %base, %offset, %sizes, %strides

=>

%dst = memref.cast %src
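As a sanity check on this fold, here is a small Python model (hypothetical names, not MLIR code). In it, `extract_strided_metadata` returns the four pieces of a descriptor and `construct_strided_metadata` wraps them back up without any arithmetic. Under that model the round trip yields a view equal to the source that aliases the same buffer. That is why the pair can be replaced by a plain cast, assuming the result type only differs from the source type in how much static information it carries.

```python
from dataclasses import dataclass
from itertools import product
from typing import List, Tuple


@dataclass(frozen=True)
class StridedMemRef:
    """Hypothetical model of a strided memref: a view over a flat buffer."""
    base: List[float]
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]

    def load(self, idx: Tuple[int, ...]) -> float:
        return self.base[self.offset + sum(i * s for i, s in zip(idx, self.strides))]


def extract_strided_metadata(m: StridedMemRef):
    # Split the descriptor into base, offset, sizes, strides.
    return m.base, m.offset, m.sizes, m.strides


def construct_strided_metadata(base, offset, sizes, strides) -> StridedMemRef:
    # Rebuild the descriptor verbatim (no add/override logic).
    return StridedMemRef(base, offset, tuple(sizes), tuple(strides))


def same_view(a: StridedMemRef, b: StridedMemRef) -> bool:
    # Two views are interchangeable if they have the same shape and every
    # in-bounds index reads the same element.
    if a.sizes != b.sizes:
        return False
    return all(a.load(i) == b.load(i) for i in product(*(range(n) for n in a.sizes)))
```

Usage: `dst = construct_strided_metadata(*extract_strided_metadata(src))` gives `dst == src`, and `dst.base is src.base`.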

Note: To limit the footgun potential of this operation, I’d like to enforce that the base argument must be zero-ranked and have an offset of 0.

What do people think?

Cheers,
-Quentin

CC: @ftynse, @mehdi_amini, @nicolasvasilache, @pifon2a

I was not directly involved in the original design and discussions, so take my opinion as a sideline one, but I think this makes sense. It is similar to what we have for tensor.pack and tensor.unpack: it guarantees that the semantics are valid, and it makes it super easy to elide the pair when one undoes the other.

I don’t get why the offset also needs to be zero.

Nit: I think construct_strided_metadata hints at constructing the metadata itself. I’d call it construct_from_strided_metadata.

That’s because I don’t want to bake any math, or any dropping of information, into the operation.
E.g., if we have a memref src with a non-zero offset, then when doing
dst = construct_from_strided_metadata src, offset, ...
Is dst.offset == src.offset + new.offset? (add semantics)
Is dst.offset == new.offset? (override semantics: src.offset is ignored)

By adding the “offset zero” constraint on the input, we force the user of this operation to think about this explicitly. Hence, hopefully fewer misuses :).
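A tiny Python sketch of why the constraint removes the ambiguity instead of picking a side. `offset_add`, `offset_override`, and `verify_base` are hypothetical helpers. The check mirrors the restriction proposed in the note above, not an existing MLIR verifier.

```python
from typing import Tuple


def offset_add(src_offset: int, new_offset: int) -> int:
    # "Add" reading: the new offset is relative to the source view.
    return src_offset + new_offset


def offset_override(src_offset: int, new_offset: int) -> int:
    # "Override" reading: the source view's offset is silently dropped.
    return new_offset


def verify_base(base_sizes: Tuple[int, ...], base_offset: int) -> None:
    """Hypothetical check for the proposed restriction on the base operand."""
    if len(base_sizes) != 0:
        raise ValueError("base must be zero-ranked")
    if base_offset != 0:
        raise ValueError("base must have an offset of 0")
```

With a source offset of 0, both readings agree for every new offset. With a source offset of 5 and a new offset of 3, they give 8 and 3 respectively. Rejecting non-zero base offsets up front means the op never has to choose between them.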

I like that!

I see: not that it has to be zero, but that it’s safer that way without extra info. Makes sense!

Alright, backpedaling on the whole RFC.
I misinterpreted @pifon2a’s answer in [RFC] What is the intended semantic of `memref.reinterpret_cast`?.
It turns out reinterpret_cast is not a no-op: it really creates a new memref while overriding the metadata, which is exactly what I was suggesting in this RFC.

I.e., there is no bug, just some documentation to update (and maybe an op to rename!)