[RFC] What is the intended semantic of `memref.reinterpret_cast`?

Hey!

I’d like to check my understanding of the semantics of memref.reinterpret_cast.

TL;DR: Is %dst = memref.reinterpret_cast %src to offset: [%offset], … meant to be a no-op with respect to the produced base address (base + offset)?

If I am correct, we have a bug in the memref expansion pass (expand-strided-metadata) that I recently added (although the lowering to LLVM makes the final codegen do what we want in this case).

Context

In my mind, memref.reinterpret_cast is something we use to make the type system happy. In particular, in terms of codegen, it would be fine for %dst = memref.reinterpret_cast %src to offset: [%offset] to always lower the related memref descriptors to:

dst.base = src.base
dst.offset = src.offset
assert(dst.offset == %offset && "reinterpreting to something invalid")

In practice the lowering propagates the passed information (at least the LLVM lowering) like so:

dst.base = src.base
# propagate %offset, we’ve been told this is what this memref points to:
dst.offset = %offset

This is fine, but it may imply different intended semantics.
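
To make the difference between the two lowerings concrete, here is a minimal Python model of a memref descriptor. The names Descriptor, lower_checked, and lower_propagating are hypothetical and only mirror the pseudocode above; this is a sketch of the two readings, not the actual LLVM lowering.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Descriptor:
    """Toy memref descriptor: a base pointer plus an element offset."""
    base: int
    offset: int

    def address(self) -> int:
        # Position of the first element (base + offset), in elements.
        return self.base + self.offset


def lower_checked(src: Descriptor, offset: int) -> Descriptor:
    """First reading: the cast is a no-op on base + offset and only asserts."""
    assert src.offset == offset, "reinterpreting to something invalid"
    return replace(src)


def lower_propagating(src: Descriptor, offset: int) -> Descriptor:
    """Second reading: the requested offset overrides the source offset."""
    return replace(src, offset=offset)


src = Descriptor(base=1000, offset=4)

# Both lowerings agree when the requested offset matches the source offset.
assert lower_checked(src, 4) == lower_propagating(src, 4)

# They diverge otherwise: lower_checked(src, 8) would trap, while the
# propagating lowering silently yields a memref at a different address.
moved = lower_propagating(src, 8)
```

The two readings are indistinguishable whenever %offset equals src.offset at runtime, which is why the question only matters for the mismatching case.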

Indeed:

  • Is my understanding correct that reinterpret_casts are noops with respect to base + offset? (I.e., if %offset is not equal to src.offset at runtime, then all bets are off), or
  • Do reinterpret_casts actually support producing a different base + offset memref than src? Put differently, is it supported to have %offset != src.offset at runtime?

About the (possible) bug in expand-strided-metadata

expand-strided-metadata currently generates something like:

# Break the %src memref into its different components
%base, %offset, … = memref.extract_strided_metadata %src
# Do something finicky!
%final = memref.reinterpret_cast %base to offset: [%offset], …

To rebuild a memref that represents %src.

By definition the %base returned by memref.extract_strided_metadata has an offset of 0 (the actual offset of %src is in the result named %offset).

Currently this “works” because the lowering of memref.reinterpret_cast to LLVM effectively propagates the offset we asked for, so we get %src back.

I don’t think that’s the intended semantics of memref.reinterpret_cast though. I believe that technically we could lower memref.reinterpret_cast with some runtime assertions (who’s up to build an mlir-ubsan? :)), and those assertions should flag this reinterpret_cast as invalid.
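
As a rough illustration of why the expansion only “works” under the propagating lowering, here is a small Python sketch. Memref, reinterpret_cast_propagating, and reinterpret_cast_checked are hypothetical names for a toy model (sizes and strides are left out); the checked variant stands in for the runtime assertion imagined above.

```python
from typing import NamedTuple


class Memref(NamedTuple):
    """Toy descriptor; sizes and strides are omitted for brevity."""
    base: int
    offset: int


def extract_strided_metadata(src: Memref) -> tuple[Memref, int]:
    # The returned base memref has offset 0; the real offset is a separate result.
    return Memref(src.base, 0), src.offset


def reinterpret_cast_propagating(src: Memref, offset: int) -> Memref:
    # Keep the base pointer and override the offset with the requested one.
    return Memref(src.base, offset)


def reinterpret_cast_checked(src: Memref, offset: int) -> Memref:
    # Hypothetical strict lowering: the cast may not change base + offset.
    if src.offset != offset:
        raise ValueError(f"offset {offset} does not match source offset {src.offset}")
    return src


src = Memref(base=1000, offset=4)
base, offset = extract_strided_metadata(src)

# The expansion round-trips under the propagating lowering...
assert reinterpret_cast_propagating(base, offset) == src

# ...but a strict lowering would reject it, because base.offset is 0, not 4.
try:
    reinterpret_cast_checked(base, offset)
    rejected = False
except ValueError:
    rejected = True
```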

If I’m right regarding the semantics, I’ll post another RFC with a proposed fix for rebuilding a memref from the results of memref.extract_strided_metadata operations.

CC: @ftynse, @mehdi_amini, @nicolasvasilache, @pifon2a

This is also my understanding. reinterpret_cast can be used to add or remove shape or rank dynamicity to memrefs, but nothing else. This is also what the verifier checks when it can: https://github.com/llvm/llvm-project/blob/142aa1bdd1dd1db9a7fecf9d157228019c794c94/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp#L1839-L1879.

Isn’t this done by [RFC] Runtime Op Verification ?

Having such an op gets quite scary. It allows one to effectively break out of the memref structure and create arbitrary, potentially invalid memrefs. Most memref-based analyses become invalid because they can no longer assume memref well-formedness in presence of such an op.

:+1:

Ah great!

That’s a good point. At the same time, I’m in a dead end with respect to what we can express with the current memref ops.
To some extent, we already have the problem that we can produce invalid memrefs (e.g., when dynamic shapes are involved), but these may be okay because we could check them at runtime. (Whereas building a memref out of the blue would be difficult to check.)

Anyhow, back to the problem at hand.
The expand-strided-metadata pass aims at removing view-like operations by explicitly modeling their effects on the offset, sizes, and strides. We achieve that with a combination of memref.extract_strided_metadata and affine.apply.

The issue is that after we break down the memref, we have no way to put it back together.
E.g.,

%dst = memref.subview %src, ...

=>

base, offset, ... = memref.extract_strided_metadata %src
finalOffset = affine.apply ...
...
dst = ... <-- how do we rebuild dst at this point?
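
For reference, the arithmetic that the affine.apply ops would encode for a subview can be sketched in Python as follows, assuming the usual strided subview semantics (offsets and strides of the subview are scaled by the source strides). expand_subview is a hypothetical helper, not code from the pass, and it only computes the metadata; it does not answer the question of how to rebuild dst.

```python
def expand_subview(src_offset, src_strides, sub_offsets, sub_sizes, sub_strides):
    """Compute the offset, sizes, and strides of a subview, in source elements."""
    # Each subview offset is scaled by the source stride of its dimension.
    final_offset = src_offset + sum(o * s for o, s in zip(sub_offsets, src_strides))
    # Subview strides compose multiplicatively with the source strides.
    final_strides = [sub * src for sub, src in zip(sub_strides, src_strides)]
    # Sizes are taken directly from the subview.
    return final_offset, list(sub_sizes), final_strides


# A 2x2 subview at [1, 2] with unit steps, taken from a row-major 4x8 memref
# whose own offset is 0 (strides [8, 1]).
offset, sizes, strides = expand_subview(
    src_offset=0, src_strides=[8, 1],
    sub_offsets=[1, 2], sub_sizes=[2, 2], sub_strides=[1, 1],
)
# offset == 10, sizes == [2, 2], strides == [8, 1]
```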

@pifon2a, after holiday break ping.

You are right about the semantics. The op does not change the base pointer, but overrides offsets-sizes-strides. The users are fully allowed to shoot themselves in the foot.

Thanks for the confirmation @pifon2a .

Filed [MLIR] expand-strided-metadata's expansions are invalid · Issue #59896 · llvm/llvm-project · GitHub

Hmm, then the verifier code for this op is incorrect. It assumes that any statically known size must match.

That’s what I thought too after you pointed it out in [RFC] What is the intended semantic of `memref.reinterpret_cast`? - #2 by ftynse, but looking closer, I believe the verifier only checks that the offset, sizes, and strides arguments are consistent with the resulting memref type, not that the input memref is consistent with the arguments or the result.
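
To illustrate the distinction, here is a toy Python model of that kind of check. It is only a sketch of my reading of the verifier, not a transcription of it: verify_cast and mismatches are hypothetical names, None stands for a dynamic value, and the point is simply that the source memref type is not an input.

```python
from typing import Optional, Sequence

Dim = Optional[int]  # None stands for a dynamic value


def mismatches(kind: str, args: Sequence[Dim], result: Sequence[Dim]) -> list[str]:
    """Report static argument values that disagree with the result type."""
    return [
        f"{kind}[{i}]: argument {a} vs result type {r}"
        for i, (a, r) in enumerate(zip(args, result))
        if a is not None and r is not None and a != r
    ]


def verify_cast(arg_offset, arg_sizes, arg_strides, res_offset, res_sizes, res_strides):
    # Only the op's arguments and the result type take part in the check;
    # nothing about the source memref is consulted.
    return (
        mismatches("offset", [arg_offset], [res_offset])
        + mismatches("size", arg_sizes, res_sizes)
        + mismatches("stride", arg_strides, res_strides)
    )


# Arguments consistent with the result type: accepted, whatever the source was.
ok = verify_cast(0, [4, 8], [8, 1], 0, [4, 8], [8, 1])
# A static size that contradicts the result type: rejected.
bad = verify_cast(0, [4, 8], [8, 1], 0, [4, 4], [8, 1])
```

Under this model, a cast whose arguments contradict the source memref still verifies, as long as the arguments and the result type agree with each other.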