Hey!
I’d like to check my understanding of the semantics of memref.reinterpret_cast.
TL;DR: Is %dst = memref.reinterpret_cast %src to offset: [%offset], … a no-op with respect to the produced base address (base + offset)?
If I am correct, we have a bug in the memref expansion pass (expand-strided-metadata) that I recently added, although the lowering to LLVM makes the final codegen do what we want in this case.
Context
In my mind, memref.reinterpret_casts are something we use to make the type system happy. In particular, in terms of codegen, it would be fine for %dst = memref.reinterpret_cast %src to offset: [%offset] to always lower the related memref descriptors to:
dst.base = src.base
dst.offset = src.offset
assert(dst.offset == %offset && "reinterpreting to something invalid")
In practice, the lowering (at least the LLVM one) propagates the passed information like so:
dst.base = src.base
# propagate %offset, we’ve been told this is what this memref points to:
dst.offset = %offset
This works, but it may reflect a different intended semantics.
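To make the two readings concrete, here is a minimal Python model of a memref descriptor under both lowerings. `MemRefDesc`, `lower_as_noop` and `lower_by_propagation` are hypothetical names invented for this sketch, not MLIR APIs. Addresses are counted in elements, and the LLVM descriptor's separate allocated/aligned pointers are collapsed into a single `base`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MemRefDesc:
    """Hypothetical, simplified memref descriptor; addresses are counted in elements."""
    base: int                # start of the underlying buffer
    offset: int              # distance (in elements) from base to the first element
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]

    def address_of(self, *idx: int) -> int:
        # Strided layout: base + offset + sum(idx[k] * strides[k]).
        return self.base + self.offset + sum(i * s for i, s in zip(idx, self.strides))


def lower_as_noop(src: MemRefDesc, offset: int, sizes, strides) -> MemRefDesc:
    """Reading 1: the cast never changes base + offset; a mismatch is an error."""
    if src.offset != offset:  # stands in for the assert in the pseudocode above
        raise ValueError("reinterpreting to something invalid")
    return MemRefDesc(src.base, src.offset, tuple(sizes), tuple(strides))


def lower_by_propagation(src: MemRefDesc, offset: int, sizes, strides) -> MemRefDesc:
    """Reading 2 (what the LLVM lowering effectively does): trust %offset as given."""
    return MemRefDesc(src.base, offset, tuple(sizes), tuple(strides))


src = MemRefDesc(base=1000, offset=4, sizes=(2, 3), strides=(3, 1))

# Matching offsets: both readings produce the same descriptor.
assert lower_as_noop(src, 4, (2, 3), (3, 1)) == lower_by_propagation(src, 4, (2, 3), (3, 1))

# Mismatching offsets: propagation silently shifts the start address by 8 - 4 elements...
moved = lower_by_propagation(src, 8, (2, 3), (3, 1))
print(moved.address_of(0, 0) - src.address_of(0, 0))  # 4

# ...whereas the no-op reading rejects the cast.
try:
    lower_as_noop(src, 8, (2, 3), (3, 1))
except ValueError as err:
    print("flagged:", err)
```

When the offsets match, the two readings cannot be told apart. They only diverge when %offset differs from src.offset at runtime.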
Hence my question:
- Is my understanding correct that reinterpret_casts are no-ops with respect to base + offset? (I.e., if %offset is not equal to src.offset at runtime, then all bets are off.) Or,
- Do reinterpret_casts actually support producing a memref with a different base + offset than src? Put differently, is it supported to have %offset != src.offset at runtime?
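The two options can be written out in strided-layout arithmetic. This restatement assumes both memrefs use the same strides and counts addresses in elements. The function name addr is introduced here for illustration only.

$$
\begin{aligned}
\operatorname{addr}(m; i_0,\dots,i_{n-1}) &= m.\mathrm{base} + m.\mathrm{offset} + \sum_{k=0}^{n-1} i_k \, m.\mathrm{stride}_k \\
\text{option 1: } \mathit{dst}.\mathrm{offset} = \mathit{src}.\mathrm{offset} &\;\Rightarrow\; \operatorname{addr}(\mathit{dst}; i) = \operatorname{addr}(\mathit{src}; i) \\
\text{option 2: } \mathit{dst}.\mathrm{offset} = \%\mathit{offset} &\;\Rightarrow\; \operatorname{addr}(\mathit{dst}; i) - \operatorname{addr}(\mathit{src}; i) = \%\mathit{offset} - \mathit{src}.\mathrm{offset}
\end{aligned}
$$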
About the (possible) bug in expand-strided-metadata
expand-strided-metadata currently generates something like:
# Break the %src memref into its different components
%base, %offset, … = memref.extract_strided_metadata %src
# Do something finicky!
%final = memref.reinterpret_cast %base to offset: [%offset], …
The goal is to rebuild a memref that represents %src.
By definition, the %base returned by memref.extract_strided_metadata has an offset of 0 (the actual offset of %src is in the result named %offset).
Currently this “works” because the lowering of memref.reinterpret_cast to LLVM effectively propagates the offset we asked for, so we get %src back.
I don’t think that’s the intended semantics of memref.reinterpret_cast, though. I believe that, technically, we could lower memref.reinterpret_cast with some runtime assertions (who’s up for building an mlir-ubsan? :)), and we should flag this reinterpret_cast as invalid.
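The following minimal Python model shows the expansion pattern under both lowerings. `extract_strided_metadata`, `reinterpret_cast_propagating`, `reinterpret_cast_checked` and `rebuild` are hypothetical stand-ins for the ops and lowerings discussed above, not MLIR APIs. The checked variant plays the role of the imagined mlir-ubsan assertion. %base is modelled as a rank-0 view with offset 0, matching what memref.extract_strided_metadata returns.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MemRefDesc:
    """Hypothetical, simplified memref descriptor; addresses are counted in elements."""
    base: int
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]


def extract_strided_metadata(src: MemRefDesc):
    """Models memref.extract_strided_metadata: %base views the same buffer with
    offset 0, and the real offset is returned as a separate value."""
    base = MemRefDesc(src.base, 0, (), ())  # rank-0 view, offset 0
    return base, src.offset, src.sizes, src.strides


def reinterpret_cast_propagating(src, offset, sizes, strides):
    """LLVM-style lowering: take %offset at face value."""
    return MemRefDesc(src.base, offset, tuple(sizes), tuple(strides))


def reinterpret_cast_checked(src, offset, sizes, strides):
    """Hypothetical checked lowering: the cast must not change base + offset."""
    if src.offset != offset:
        raise ValueError(f"reinterpreting to something invalid: {src.offset} != {offset}")
    return MemRefDesc(src.base, src.offset, tuple(sizes), tuple(strides))


def rebuild(src, cast):
    """The pattern expand-strided-metadata emits: extract, then reinterpret_cast %base."""
    base, offset, sizes, strides = extract_strided_metadata(src)
    return cast(base, offset, sizes, strides)


src = MemRefDesc(base=1000, offset=4, sizes=(2, 3), strides=(3, 1))

# Propagating lowering: the round trip "works" and yields %src again.
print(rebuild(src, reinterpret_cast_propagating) == src)  # True

# Checked lowering: %base has offset 0, so any non-zero %offset is flagged.
try:
    rebuild(src, reinterpret_cast_checked)
except ValueError as err:
    print("flagged:", err)
```

Under the checked reading, the rebuilt memref is only valid when %src happens to have a zero offset.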
If I’m right about the semantics, I’ll post another RFC proposing a fix for rebuilding a memref from the results of memref.extract_strided_metadata operations.