Extracting dynamic offsets/strides from memref

Thanks for sending the PR. I took a deeper look at the comments on D130849 ([mlir][memref] Introduce memref.offset and memref.stride ops), as well as at the abstraction that @stellaraccident has been iterating on: iree/VMVXOps.td in iree-org/iree (GitHub).

There is one aspect, handled at the LLVM level today, for which I do not see a clear path. Inserting values into and extracting them from the llvm.struct that represents the descriptor relies on a structured type that does not exist outside of the LLVM dialect.

Our objective is to get rid of that descriptor ourselves by folding it away before reaching LLVM. We want this for several reasons, in particular because LLVM is not the one target everyone wants to lower to (e.g. SPIR-V, libxsmm, etc.).
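For reference, a ranked memref lowered to the LLVM dialect becomes an llvm.struct with five fields: an allocated pointer, an aligned pointer, an offset, and per-dimension sizes and strides. The sketch below is a minimal Python model of that layout. The class and function names are hypothetical, and pointers are plain integers. It shows what information the descriptor carries and how an element address is derived from it.

```python
from dataclasses import dataclass
from typing import List

# Hypothetical Python model (not an MLIR API) of the llvm.struct that a ranked
# memref lowers to: (allocated ptr, aligned ptr, offset, sizes[rank], strides[rank]).
# Pointers are modeled as plain integer addresses.
@dataclass
class LLVMMemRefDescriptor:
    allocated: int        # pointer returned by the allocator (used for deallocation)
    aligned: int          # aligned pointer that element accesses are based on
    offset: int           # in elements, relative to `aligned`
    sizes: List[int]      # one entry per dimension
    strides: List[int]    # one entry per dimension, in elements

def element_address(d: LLVMMemRefDescriptor, indices: List[int], elt_bytes: int) -> int:
    """Byte address of d[indices]: aligned + (offset + sum_k i_k * stride_k) * elt_bytes."""
    assert len(indices) == len(d.sizes) == len(d.strides), "rank mismatch"
    assert all(0 <= i < s for i, s in zip(indices, d.sizes)), "index out of bounds"
    linear = d.offset + sum(i * st for i, st in zip(indices, d.strides))
    return d.aligned + linear * elt_bytes

# A 4x8 row-major view of 4-byte elements, starting 2 elements into the buffer.
d = LLVMMemRefDescriptor(allocated=1000, aligned=1008, offset=2, sizes=[4, 8], strides=[8, 1])
print(element_address(d, [1, 3], elt_bytes=4))  # 1060
```

Apart from the two pointers, all of these fields (offset, sizes, strides) are what the options below would need to expose without going through an LLVM struct.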

My current analysis of the possibilities is:

  • separate ops: memref.dim, memref.stride and memref.offset to read individual fields, memref.reinterpret_cast to create a new descriptor, and some other memref abstraction to extract the base memref<T> (this is where D130849 [mlir][memref] Introduce memref.offset and memref.stride ops seems to be heading);
  • a single memref.descriptor op that returns a 4-element result, like the one @stellaraccident has;
  • adding a hypothetical memref_descriptor type specifically for this purpose;
  • adding a hypothetical struct type, at a higher level of abstraction than LLVM, with its own insert / extract ops.

Of the 4, the struct seems appealing in practice and would be useful for other ongoing work.

However, the struct type is likely overkill at the moment. We don't need multiple levels of type nesting because the descriptor is essentially flat. It is not strictly flat, but it can be represented with 4 results and indexed with %sizes#0, %strides#1, etc.
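To make the "flat" point concrete, here is a minimal Python sketch. It assumes a hypothetical `StridedMemRef` model and a hypothetical `descriptor` function standing in for the proposed `memref.descriptor` op. For a rank-N memref, the descriptor flattens into 2 + 2N scalar results (base, offset, N sizes, N strides), so no nested aggregate type is required.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class StridedMemRef:
    """Hypothetical model of a strided memref: the base buffer plus the
    strided layout (offset, sizes, strides), all counted in elements."""
    base: Tuple[float, ...]     # stands in for the base memref<T>
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]

def descriptor(m: StridedMemRef):
    """Mirrors the proposed `%base, %offset, %sizes:N, %strides:N = memref.descriptor %m`:
    every field is a flat result, so no nested aggregate type is needed."""
    return (m.base, m.offset, *m.sizes, *m.strides)

m = StridedMemRef(base=tuple(float(i) for i in range(12)), offset=0, sizes=(3, 4), strides=(4, 1))
results = descriptor(m)
rank = len(m.sizes)
base, offset = results[0], results[1]
sizes = results[2:2 + rank]            # %sizes#0, %sizes#1
strides = results[2 + rank:]           # %strides#0, %strides#1
print(len(results), sizes, strides)    # 6 (3, 4) (4, 1)
```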

The multi-op approach also seems overkill, as well as asymmetrical. I think manipulating the following IR will be more beneficial for what we are after right now:

// (New): Extract descriptor-related information.
%base, %offset, %sizes:2, %strides:2 = memref.descriptor %m : memref<?x?xT>

// (Existing): Construct a new memref from descriptor pieces.
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<T> to memref<?x?xT, offset: ?, strides: [?, ?]>
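The pairing above can be sanity-checked with a small Python model. It is a sketch under stated assumptions: `StridedMemRef`, `descriptor` and `reinterpret_cast` are hypothetical stand-ins for the memref value, the proposed op and memref.reinterpret_cast. Feeding the unmodified pieces back in rebuilds an equivalent view. Changing them yields a new view over the same buffer without copying; here, swapping sizes and strides produces a transposed view.

```python
from dataclasses import dataclass
from itertools import product
from typing import Tuple

@dataclass(frozen=True)
class StridedMemRef:
    """Hypothetical model: base buffer + offset/sizes/strides, in elements."""
    base: Tuple[float, ...]
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]

    def load(self, *idx: int) -> float:
        assert len(idx) == len(self.sizes), "rank mismatch"
        assert all(0 <= i < s for i, s in zip(idx, self.sizes)), "index out of bounds"
        return self.base[self.offset + sum(i * st for i, st in zip(idx, self.strides))]

def descriptor(m: StridedMemRef):
    # Stand-in for the proposed memref.descriptor op.
    return m.base, m.offset, m.sizes, m.strides

def reinterpret_cast(base, offset, sizes, strides) -> StridedMemRef:
    # Stand-in for memref.reinterpret_cast: builds a new view, copies no data.
    return StridedMemRef(base, offset, tuple(sizes), tuple(strides))

m = StridedMemRef(tuple(float(i) for i in range(12)), 0, (3, 4), (4, 1))
base, offset, sizes, strides = descriptor(m)

# Round trip: the unmodified pieces rebuild an equivalent view.
same = reinterpret_cast(base, offset, sizes, strides)
assert all(same.load(i, j) == m.load(i, j) for i, j in product(range(3), range(4)))

# Swapping the sizes and the strides gives a transposed 4x3 view of the same buffer.
t = reinterpret_cast(base, offset, (sizes[1], sizes[0]), (strides[1], strides[0]))
print(t.load(3, 2), m.load(2, 3))  # 11.0 11.0
```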

I do not have a strong opinion on whether we need a specific memref_descriptor type to hide away the multiple results; I suspect it may be unnecessary.

As it stands I think I’d prefer to see @stellaraccident’s abstraction be ported to MLIR.

One potential drawback is that extracting all of %base, %offset, %sizes:2, %strides:2 is overkill if we only need to manipulate, say, %strides#0. However, I don't think this is a real concern in practice for now.

Sizes are definitely used in many other places in codegen, so memref.dim continues to make sense. Still, we want to discourage accessing and manipulating individual strides / offsets outside of structured ops such as memref.reshape, memref.subview and friends.
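As an illustration of why structured ops can own the stride/offset arithmetic, consider memref.subview. The result layout is fully determined by the source layout and the subview parameters:

- result offset = source offset + Σ offset_k · stride_k;
- result stride_k = source stride_k · step_k;
- result sizes = the requested sizes.

The sketch below is a minimal Python model of that rule. The names `Layout` and `subview_layout` are hypothetical, and this is not MLIR's implementation.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple

@dataclass(frozen=True)
class Layout:
    """Hypothetical model of a strided layout, in elements."""
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]

def subview_layout(src: Layout, offsets: Sequence[int], sizes: Sequence[int],
                   steps: Sequence[int]) -> Layout:
    """Layout of a subview of `src`: the new offset and strides are derived
    from the source ones, so user code never touches them individually."""
    rank = len(src.sizes)
    assert len(offsets) == len(sizes) == len(steps) == rank, "rank mismatch"
    for o, n, st, dim in zip(offsets, sizes, steps, src.sizes):
        # The last element selected along this dimension must be in bounds.
        assert n == 0 or o + (n - 1) * st < dim, "subview out of bounds"
    new_offset = src.offset + sum(o * s for o, s in zip(offsets, src.strides))
    new_strides = tuple(s * st for s, st in zip(src.strides, steps))
    return Layout(new_offset, tuple(sizes), new_strides)

src = Layout(offset=0, sizes=(8, 16), strides=(16, 1))
# Rows 2..5, and every other column starting at column 4.
sv = subview_layout(src, offsets=(2, 4), sizes=(4, 6), steps=(1, 2))
print(sv)  # Layout(offset=36, sizes=(4, 6), strides=(16, 2))
```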

To that effect, having a single op to get the descriptor seems more appealing.
It would also make it clear that a non-strided memref has no descriptor, so applying the op to one is a verification error.

It seems less intuitive to see a verification error on:

%0 = memref.stride %m, %c0 : some_memref_type_that_is_not_a_strided_memref
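The verification argument can be modeled in a few lines of Python. In this sketch, `MemRefType`, `StridedLayout`, `MapLayout`, `verify_dim` and `verify_descriptor` are all hypothetical. The dim check applies to any ranked memref regardless of layout. A single "get descriptor" check has exactly one rule: the layout must be strided. A non-strided layout (modeled here as an arbitrary index function) is rejected as a whole.

```python
from dataclasses import dataclass
from typing import Callable, Tuple, Union

@dataclass(frozen=True)
class StridedLayout:
    offset: int
    strides: Tuple[int, ...]

@dataclass(frozen=True)
class MapLayout:
    # Arbitrary index -> linear-position function; not representable with strides.
    fn: Callable[..., int]

@dataclass(frozen=True)
class MemRefType:
    shape: Tuple[int, ...]
    layout: Union[StridedLayout, MapLayout]

class VerificationError(Exception):
    pass

def verify_dim(t: MemRefType, d: int) -> int:
    # Sizes exist for every ranked memref, whatever its layout.
    if not 0 <= d < len(t.shape):
        raise VerificationError(f"dim index {d} out of range for rank {len(t.shape)}")
    return t.shape[d]

def verify_descriptor(t: MemRefType):
    # A single "get descriptor" op has exactly one rule: the layout must be strided.
    if not isinstance(t.layout, StridedLayout):
        raise VerificationError("operand must be a strided memref")
    return t.layout.offset, t.shape, t.layout.strides

strided = MemRefType((4, 8), StridedLayout(0, (8, 1)))
non_strided = MemRefType((4, 8), MapLayout(lambda i, j: i * j))

print(verify_descriptor(strided))    # (0, (4, 8), (8, 1))
print(verify_dim(non_strided, 1))    # 8
try:
    verify_descriptor(non_strided)
except VerificationError as e:
    print("error:", e)               # error: operand must be a strided memref
```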

@ftynse, in case he has a different reading of or sensibility on this topic than I do.