Extracting dynamic offsets/strides from memref

There is a memref.reinterpret_cast op, which allows constructing a memref with arbitrary offsets/strides, but there isn’t any opposite operation, i.e. one that extracts a dynamic offset/stride from a memref as an SSA value. We have a custom ExtractMemrefMetadata op for this purpose, and we would like to upstream it (not in this exact form) if there is interest from the MLIR community.

Proposed form:

  • %0 = memref.offset %memref : memref<...>
  • %0 = memref.stride %memref, %index : memref<...>
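To make the proposed semantics concrete, here is a minimal C++ model. It assumes the field layout of the descriptor produced by the existing MemRef-to-LLVM lowering: allocated pointer, aligned pointer, offset, sizes, strides. StridedDesc, memrefOffset and memrefStride are hypothetical stand-ins for the runtime values that memref.offset and memref.stride would return. They are not MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical runtime view of a ranked strided memref, mirroring the fields
// of the descriptor produced by the MemRef-to-LLVM lowering.
struct StridedDesc {
  float *allocated = nullptr;    // pointer returned by the allocator
  float *aligned = nullptr;      // pointer used for element accesses
  int64_t offset = 0;            // in elements, relative to `aligned`
  std::vector<int64_t> sizes;    // one entry per dimension
  std::vector<int64_t> strides;  // one entry per dimension, in elements
};

// Models `%0 = memref.offset %memref`.
int64_t memrefOffset(const StridedDesc &d) { return d.offset; }

// Models `%0 = memref.stride %memref, %index`; `index` must be < rank.
int64_t memrefStride(const StridedDesc &d, std::size_t index) {
  assert(index < d.strides.size() && "stride index out of range");
  return d.strides[index];
}
```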

@chelini @nicolasvasilache were just discussing this as well

Thanks, Stella; yes, I spoke with @Hardcode84 last Friday about this because I saw something in the internal repo that was similar to what I discussed with Nicolas.

@Hardcode84 I think there is interest in having these operations. Do you have a patch I can play with?

https://reviews.llvm.org/D130849

Thanks @Hardcode84 for the patch. I took a pass at reviewing it and left some comments. I’ll let others weigh in on the impact of such ops and whether they have other implications. To me they seem like good additions.
That leads to a question about how to resolve the strides in a whole program. It seems like stride resolution should probably happen the same way tensor.dim/memref.dim resolution happens, through the use of ReifyRankedShapedTypeOpInterface. This interface, along with this pattern, allows resolving all shapes in the program by using fixed-point iteration. I suspect something similar should work for stride/offset resolution as well.

Thanks for the review. We didn’t have much stride resolution in our code, just a few simple foldings (e.g. when getStridesAndOffset returns static strides, or when the input is a memref.reinterpret_cast). After that, the ops go directly to LLVM and access the corresponding descriptor fields.
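The first kind of folding mentioned here can be sketched as follows. For an identity (contiguous, row-major) layout, the offset is 0 and each stride is the product of the sizes to its right. If any of those sizes is dynamic, the stride cannot be folded to a constant; for example, memref<4x?x8xf32> has strides [?, 8, 1]. The function name is hypothetical, and std::nullopt is only an illustrative encoding of "dynamic" (MLIR uses a sentinel value instead).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using MaybeInt = std::optional<int64_t>;  // std::nullopt stands for a dynamic value

// Strides of a contiguous row-major (identity) layout: stride[i] is the product
// of sizes[i+1..rank-1]. Once a size to the right of dimension i is dynamic,
// stride[i] cannot be folded to a constant.
std::vector<MaybeInt> identityLayoutStrides(const std::vector<MaybeInt> &sizes) {
  std::vector<MaybeInt> strides(sizes.size());
  MaybeInt running = 1;  // the innermost stride is always 1
  for (std::size_t i = sizes.size(); i-- > 0;) {
    strides[i] = running;
    if (running && sizes[i])
      running = *running * *sizes[i];
    else
      running = std::nullopt;
  }
  return strides;
}
```

A memref.stride whose result here is a known value could be folded to a constant. Otherwise, the op stays and is lowered to a descriptor field read.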

Slowly ramping back up after vacation, what is the status/consensus here?

Hi, I’m still waiting for opinions/review

Thanks for sending the PR. I took a deeper look at the comments on D130849 ([mlir][memref] Introduce memref.offset and memref.stride ops), as well as at the abstraction that @stellaraccident has been iterating on: iree/VMVXOps.td in the iree-org/iree repository on GitHub.

There is one aspect that happens at the LLVM level today for which I do not see a clear path. Inserting values into, and extracting them from, the llvm.struct that represents the descriptor relies on a structured type that does not exist outside of the LLVM dialect.

Our objective is to get rid of that descriptor ourselves by folding it away before LLVM. There are several reasons, in particular that LLVM is not the only target everyone wants to lower to (e.g. SPIR-V, libxsmm, etc.).

My current analysis of the possibilities is:

  • memref.dim, memref.stride and memref.offset to extract the pieces, memref.reinterpret_cast to create a new descriptor, and some other memref abstraction to extract the base memref<T> (i.e. where D130849 seems likely to evolve).
  • a single memref.descriptor op that returns a 4-element result, like the one @stellaraccident has.
  • adding a hypothetical memref_descriptor type specifically for this purpose.
  • adding a hypothetical struct type at a higher level of abstraction than LLVM, with its own insert/extract ops.

Of the 4, the struct seems appealing in practice and would be useful for other ongoing work.

However, the struct type is likely overkill at the moment. We don’t need multiple levels of type nesting, because the descriptor is flat: not truly flat, but it can be represented with 4 results and indexed with %sizes#0, %strides#1, etc.

The multi-op approach also seems like overkill, and it is asymmetrical. I think manipulating the following IR will be more beneficial for our current purpose.

// (New): Extract descriptor-related information.
%base, %offset, %sizes:2, %strides:2 = memref.descriptor %m: memref<?x?xT>

// (Existing): Construct a new memref from descriptor pieces.
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<?x?xT> to memref<?x?xT, strided<[?, ?], offset: ?>>

I do not have a strong opinion on whether we need a specific memref_descriptor type to hide away the multiple results. I suspect it may be unnecessary.

As it stands I think I’d prefer to see @stellaraccident’s abstraction be ported to MLIR.

One potential drawback is that it is overkill to extract all of %base, %offset, %sizes:2 and %strides:2 if we only need to manipulate, say, %strides#0. However, I don’t think this is a real concern in practice for now.

Sizes are definitely used in many other places in codegen, and memref.dim continues to make sense. However, we still want to discourage accessing and manipulating individual strides/offsets outside of structured ops such as memref.reshape, memref.subview and friends.

To that effect, having a single op to get the descriptor seems more appealing.
It would also be clear that wherever we see a non-strided memref, the descriptor does not exist and verification errors ensue.

It seems less intuitive to see a verification error on:

%0 = memref.stride %m, %c0 : some_memref_type_that_is_not_a_strided_memref

@ftynse, in case he has a different reading of or sensibility on this topic than I do.

IMO

  • On option 2: I think it can lead to uglier C++ code. Instead of val = builder.create<StrideOp>(memref, i);, some code will look like val = builder.create<DescriptorOp>(memref)[1 + i];. It is also very prone to breakage on API changes. For example, if we decide to add the allocated pointer to the descriptor, this code will still compile but fail at runtime.
  • On having the descriptor as a first-class citizen: I do not understand what exactly the benefit is. We still need ops to extract individual values from it, either separate ops specific to offsets/strides or extraction by global index. In the latter case, we have the same issue as option 2.

It’s not clear to me why, because you’re supposed to write code that looks like:

auto descriptor = builder.create<DescriptorOp>(memref);
ValueRange sizes = descriptor.getStrides();
ValueRange strides = descriptor.getSizes();

This is resilient to changes in the descriptor, or it triggers build-time failures.
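The robustness argument can be illustrated outside MLIR with a small C++ sketch. The results live in one flat list, as they would for a multi-result op. Callers use named accessors, which compute where each group starts from the rank. If a field were later added in front, such as an allocated pointer, only the accessors would change. Call sites that hard-code results[1 + i] would instead silently read the wrong value. DescriptorResults and its methods are hypothetical, not MLIR’s generated op API; each int64_t stands in for one SSA result.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for the results of a descriptor op, laid out as:
//   base, offset, sizes[0..rank), strides[0..rank)
class DescriptorResults {
 public:
  DescriptorResults(std::vector<int64_t> values, std::size_t rank)
      : values_(std::move(values)), rank_(rank) {
    assert(values_.size() == 2 + 2 * rank_ && "unexpected number of results");
  }

  int64_t base() const { return values_[0]; }
  int64_t offset() const { return values_[1]; }
  // Only these accessors encode where each group starts, so a layout change
  // (e.g. inserting an allocated pointer before `base`) is fixed in one place.
  std::vector<int64_t> sizes() const { return slice(2); }
  std::vector<int64_t> strides() const { return slice(2 + rank_); }

 private:
  std::vector<int64_t> slice(std::size_t start) const {
    auto first = values_.begin() + static_cast<std::ptrdiff_t>(start);
    return std::vector<int64_t>(first, first + static_cast<std::ptrdiff_t>(rank_));
  }

  std::vector<int64_t> values_;
  std::size_t rank_;
};
```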

+1, not a concern for me.

Nit: typo in the accessor names (I don’t have edit permissions to fix it. Also, maybe grant me powers too 🙂)

True regarding the “still need ops” part.

The value is in having a symmetric pair of ops, one to get a descriptor from a memref and one to create a memref from a descriptor. Such a pair is more idiomatic and can’t be mistaken for something else. reinterpret_cast is also used for other things, and using it in this context is somewhat of an abuse.

Having pairs of insert / extract that can’t be mistaken for something else is a recurrent pattern everywhere.

A struct would be better and more general, but it will take much longer to get through.

My vote is for alternative 2, and if we insist on alternative 3, then we should go directly to 4 😛

OK, I understand. Option 2 is fine with me then. Options 3 and 4 would be total overkill for my case, as I just need to explicitly calculate a flat index for a multidimensional memref using the extracted offset/strides.
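For reference, the flat index computation is the strided-layout linearization. The element at indices (i_0, ..., i_{n-1}) lives offset + Σ_k i_k · stride_k elements past the aligned base pointer. Below is a minimal C++ sketch that assumes the offset and strides have already been extracted as plain integers; flatIndex is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Linearize a multi-dimensional index into an element offset from the aligned
// base pointer, using the memref's (possibly dynamic) offset and strides.
int64_t flatIndex(int64_t offset, const std::vector<int64_t> &strides,
                  const std::vector<int64_t> &indices) {
  assert(strides.size() == indices.size() && "rank mismatch");
  int64_t flat = offset;
  for (std::size_t k = 0; k < strides.size(); ++k)
    flat += indices[k] * strides[k];
  return flat;
}
```

For example, take a subview of a contiguous 4x8 buffer that starts at row 1, column 2. Its offset is 1·8 + 2 = 10 and its strides are [8, 1]. Element (1, 2) of the subview therefore maps to 10 + 8 + 2 = 20, which is row 2, column 4 of the parent.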

We’re all just calculating flat indices 🙂

Nice, thanks for considering/flexibility!

I don’t really have an opinion here, that’s why I haven’t commented so far.

The main reason for using a struct was that one-to-many type conversions were impossible. The common wisdom was that something would do SROA and DCE at a lower level (the LLVM dialect or LLVM IR) anyway, so this is not a performance issue.
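To put a number on that one-to-many expansion: the MemRef-to-LLVM descriptor of a rank-N memref holds 3 + 2N scalars. These are the allocated pointer, the aligned pointer, the offset, N sizes and N strides. A 0-D memref keeps only the two pointers and the offset. Replacing one memref value with its scalars would therefore be a 1-to-(3 + 2N) type conversion. The helper below is only an illustrative count, not an MLIR API.

```cpp
#include <cassert>
#include <cstddef>

// Number of scalar fields in the MemRef-to-LLVM descriptor of a ranked memref:
// allocated ptr + aligned ptr + offset + `rank` sizes + `rank` strides.
constexpr std::size_t descriptorScalarCount(std::size_t rank) {
  return 3 + 2 * rank;
}

static_assert(descriptorScalarCount(0) == 3, "0-D memref: two pointers + offset");
static_assert(descriptorScalarCount(2) == 7, "2-D memref");
```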

AFAIK, the descriptor is a concept specific to the MemRef->LLVM lowering. Since you mentioned the importance of non-LLVM lowering targets, you probably don’t want to lift this concept into memref itself.

Do we practically need access to the base pointer? If not, I would consider adding a !shape.strided type to the shape dialect to represent the (sizes + strides + offset) information, which is more of a memref-level concept.

In any case, I would suggest phrasing this in terms of “strides/offsets” or “strided format” to avoid lifting the implementation detail into the memref dialect/type definition.

We already have the built-in tuple type, just no operations to insert/extract. A struct may come with data layout considerations that nobody wants to handle at this level.

This shouldn’t be a runtime performance problem as long as there is DCE at some later stage, when the memref is transformed into a descriptor. Alternatively, some clever lowering/canonicalization could avoid emitting spurious extractvalue equivalents. There is some compile-time cost, though.

I think the rationale for that was to support aliasing analysis. With all sorts of casts that I lost track of, we are likely past the point where we could still have a happy path by construction, so this may be less of a concern.

I don’t see a difference between “error: memref.get_strided_shape at file.mlir:123:45: op requires operand #0 to satisfy strided memref” and “error: memref.get_stride at file.mlir:123:45: op requires operand #0 to satisfy strided memref”. The latter is the error message we would get by having the AnyStridedMemRef type constraint in the op definition, and there are multiple other such ops in the dialect.

It is today. When I was reasoning through it, I found the concept useful if lifted to a higher level, which is why I preserved the name, but I don’t care much about that. As I see it, the memref descriptor is the concrete runtime representation of the type. Expressing it at the higher level lets us reason about it at “higher-level dialect compile time”. In a large majority of cases, that could eliminate it entirely at that high level, while still letting it persist as a runtime struct when necessary on targets that support that, such as LLVM.

In practice, when it is possible to resolve statically in this way, we lower the memref dialect completely to 1-D byte buffers at the MLIR level (we have a dedicated byte buffer type and supporting ops). Since that does the same kind of thing as the LLVM lowering, just higher up, I found it reasonable to lift the existing concept rather than create a new one.

At least for that use case, I do think we need the “base memref”. This shows up in the buffer lowering I describe above. But even expressed purely at the memref level, I think it should always be possible to go back from a descriptor to an aliased equivalent of the original via a subsequent reinterpret_cast, and you need the base for that to be well defined. The descriptor of a 0-D memref always folds to the memref itself plus constants. That is what makes it hold together to treat the 0-D memref as the base type.
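Here is a minimal C++ model of the round trip described above, assuming a hypothetical 2-D strided view. View2D, Pieces, extract and reinterpretCast are illustrative names, not MLIR API. The check is that a view rebuilt from the extracted base, offset, sizes and strides addresses the same element at every index. That aliasing property holds only because the base is part of what gets extracted.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical 2-D strided view: a base pointer plus offset/sizes/strides.
struct View2D {
  const float *base;
  int64_t offset;
  std::array<int64_t, 2> sizes, strides;

  const float *at(int64_t i, int64_t j) const {
    return base + offset + i * strides[0] + j * strides[1];
  }
};

// Models extracting %base, %offset, %sizes:2, %strides:2 from a memref.
struct Pieces {
  const float *base;
  int64_t offset;
  std::array<int64_t, 2> sizes, strides;
};
Pieces extract(const View2D &m) { return {m.base, m.offset, m.sizes, m.strides}; }

// Models memref.reinterpret_cast %base to offset: [...], sizes: [...], strides: [...].
View2D reinterpretCast(const Pieces &p) { return {p.base, p.offset, p.sizes, p.strides}; }

// True if both views have the same shape and address the same element everywhere.
bool aliases(const View2D &a, const View2D &b) {
  if (a.sizes != b.sizes) return false;
  for (int64_t i = 0; i < a.sizes[0]; ++i)
    for (int64_t j = 0; j < a.sizes[1]; ++j)
      if (a.at(i, j) != b.at(i, j)) return false;
  return true;
}
```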

I don’t have strong opinions on this, or a deep history with the darker corners of the memref dialect. Still, lifting the descriptor and making it an explicit, high-level thing seemed to compose nicely when I implemented that approach for a non-LLVM target. I did not also apply it to LLVM. Looking at it, though, I thought it might work well there too. It could give us a higher-level dialect approach that avoids emitting a runtime struct at all in many cases, potentially allowing more high-level optimization in common cases than is possible now.

One of the things that has always tripped me up is that, to many people, the name “memref” implies a fairly low-level type, but it really models a fairly high-level type. Thinking of it in terms of a lower-level “byte buffer” primitive, rather than in isolation, has helped me place it better in the layer stack of concepts in my mind.

So, what is the best name for this op if we don’t want to use ‘descriptor’? Our custom op was named ‘ExtractMemrefMetadata’.

memref.extract_metadata sgtm

I was hoping to copy Stella’s impl today with memref<T> instead of the buffer result, but things got in the way…

If anyone has cycles for this (copy-paste + add a test), I’d welcome it; otherwise, I’ll get to it next week.

Thanks!

One more potential issue with a descriptor op that returns multiple results is that the fold API cannot be used with it, because partial folding is not supported. For our specific case, we need folding for statically determined offsets/strides and for inputs that come from a reinterpret_cast. It is still possible to do this via canonicalization patterns, but that is more involved. Also, APIs like createOrFold won’t work, and the user will need to run an actual canonicalization pass.
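Here is a hypothetical C++ model of this concern; it does not use MLIR’s actual fold signature. It compares two approaches. The first is an all-or-nothing hook that must produce a constant for every result or fail. The second is a pattern-style rewrite that replaces only the results that are statically known. For a layout like strided<[?, 1], offset: 0>, the first folds nothing, while the second can still replace the offset and the inner stride.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using MaybeInt = std::optional<int64_t>;  // std::nullopt == not statically known

// All-or-nothing hook: succeeds only if every result is a known constant,
// mirroring the constraint described above for multi-result folding.
std::optional<std::vector<int64_t>> foldAllOrNothing(
    const std::vector<MaybeInt> &results) {
  std::vector<int64_t> constants;
  for (const MaybeInt &r : results) {
    if (!r) return std::nullopt;  // a single dynamic result blocks the whole fold
    constants.push_back(*r);
  }
  return constants;
}

// Pattern-style rewrite: count the individual results that could be replaced
// by constants, leaving the dynamic ones as uses of the original op.
std::size_t replaceableResults(const std::vector<MaybeInt> &results) {
  std::size_t n = 0;
  for (const MaybeInt &r : results)
    if (r) ++n;
  return n;
}
```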