Understanding the semantics of reinterpret_cast

I’d like to understand the semantics of reinterpret_cast, and currently I’m struggling over the following example, which I’ve formatted as a lit test. In summary, what I’d like to express is:

  • Define a 4xi32 alloc
  • Fill it with some data so that it looks like %alloc_0 = [0, 1, 2, 3]
  • Reinterpret-cast it into a new memref that has an offset of 1 (effectively, it should be a view of alloc_0 like %alloc_1 = [1, 2, 3]).
  • Load index 1 of alloc_1
  • Confirm the result is 2.
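
The view semantics I expect can be written out in plain arithmetic (a hypothetical Python sketch, not MLIR): element i of the view reads element offset + i * stride of the underlying buffer.

```python
# Sketch of the view semantics expected from reinterpret_cast
# (plain Python, not MLIR): the result memref addresses the same
# underlying buffer as base[offset + i * stride].

def view_load(base, offset, stride, i):
    """Load element i of a strided 1-D view over `base`."""
    return base[offset + i * stride]

alloc_0 = [0, 1, 2, 3]
# alloc_1 is a view with offset 1, stride 1, i.e. [1, 2, 3]
result = view_load(alloc_0, offset=1, stride=1, i=1)
print(result)  # 2
```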
// RUN: mlir-opt %s -pass-pipeline=" \
// RUN:  builtin.module(lower-affine, \
// RUN:                 normalize-memrefs, \
// RUN:                 finalize-memref-to-llvm,  \
// RUN:                 func.func(convert-scf-to-cf, convert-arith-to-llvm), \
// RUN:                 convert-func-to-llvm, \
// RUN:                 reconcile-unrealized-casts)" \
// RUN: | mlir-cpu-runner -e main -entry-point-result=i32 > %t
// RUN: FileCheck %s < %t

// CHECK: 2
func.func @main() -> i32  {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c3_i32 = arith.constant 3 : i32

  %alloc_0 = memref.alloc() : memref<4xi32>
  affine.store %c0_i32, %alloc_0[%c0] : memref<4xi32>
  affine.store %c1_i32, %alloc_0[%c1] : memref<4xi32>
  affine.store %c2_i32, %alloc_0[%c2] : memref<4xi32>
  affine.store %c3_i32, %alloc_0[%c3] : memref<4xi32>

  %alloc_1 = memref.reinterpret_cast
           %alloc_0 to offset: [1], sizes: [4], strides: [1]
           : memref<4xi32> to memref<4xi32, strided<[1], offset: 1>>

  %arg0 = arith.constant 1 : index
  %1 = affine.load %alloc_1[%arg0] : memref<4xi32, strided<[1], offset: 1>>

  return %1 : i32
}

When I run this, I get the following error, which I assume comes from the verifier:

error: expected result type with size = 4 instead of 5 in dim = 0
  %alloc_1 = memref.reinterpret_cast 
  %alloc_0 to offset: [1], sizes: [4], strides: [1] 
  : memref<4xi32> to memref<4xi32, strided<[1], offset: 1>>

If I change the offset to zero, the above compiles and runs without a problem, outputting 1 as expected, since the reinterpret_cast is then a no-op. I have also tried a variety of other settings for the sizes and result memref type, and they all produce similar errors.

For example, I thought maybe the output memref type should be 3xi32, since the offset means there are only three i32’s in its view. That produces this error:

error: expected result type with size = 3 instead of 4 in dim = 0
%alloc_1 = memref.reinterpret_cast 
  %alloc_0 to offset: [1], sizes: [3], strides: [1] 
  : memref<4xi32> to memref<3xi32, strided<[1], offset: 1>>

Similar problems occur when setting sizes: [4] while the output type is 3xi32, and vice versa.

So now I’m at a loss. It seems like I just completely misunderstand the semantics of this operation, or else there is a bug. For evidence of the latter, I noticed that if I leave sizes: [4] and change the output memref size to something weird, like 9xi32, it errors with expected result type with size = 4 instead of 9 in dim = 0, and likewise for 3xi32 it says “instead of 3”. But for BOTH 4xi32 and 5xi32 it reports the actual dimension as 5.

Could someone help me understand this situation?

Let me also add some context as to why I’m looking at this.

I am working on a project involving lowering a TOSA model (with static dimensions throughout) to (effectively) arith and func, where the input to the func is allowed to be a memref, but otherwise all the memrefs are removed. Then we have an existing translate tool that emits Verilog, which we use to do other fun stuff. This has worked with some simple TOSA models, but we want to do some more complex models that have convolution layers. Obstacles to lowering one such model (specifically, affine-scalrep taking days to run) led me and my colleagues to write a custom affine-scalrep pass (similar to what Maksim wrote in this CIRCT PR). In brief, our pass starts from the function body, fully unrolls the first for loop, and then (before moving on to the next for loop) forwards stores to loads in one of three situations:

  • The source is the function argument (a memref with static dims)
  • The source is an alloc (in which case the relevant store op has its (constant) value forwarded to the load)
  • The source is the result of a reinterpret_cast of an alloc in one of the two previous cases

The reinterpret_cast ops are introduced by lowering tensor reshapes/subviews (via the --expand-strided-metadata pass) that show up in the original convolution ops.

So in addition to helping me understand this acute confusion around reinterpret_cast, I’d love to hear any advice about how else to tackle this. For example, would it make more sense to work with the subviews and reshapes directly, or else to try to lower reinterpret_cast even further before emitting Verilog? I think lowering to LLVM might be too onerous? Or else, since we heavily rely on the existing affine utility functions for extracting “constified” indexes to the stores/loads, is there a means to get that same constified information in a lower dialect, say, where the reinterpret cast is replaced by a few extra steps of arithmetic, but the constified-value-extraction can infer the constant values through the arithmetic ops?

Have you tried running your example without the normalize-memrefs pass? It is possible that this pass simply does not handle reinterpret_cast well.

It appears that it does attempt to handle the reinterpret_cast, and it shouldn’t. The normalize-memrefs pass was introduced to hoist (affine) layout maps from the type to the operation. So it changes the result type of the reinterpret_cast to something the operation does not expect, because the operation specifically produces a memref with a strided layout.

Note that strided layouts can be converted to an affine representation of the form “offset + d0*stride0 + d1*stride1 + …” that the normalize-memrefs pass will use to estimate the linear size of the underlying data. Hence, the normalize pass will change the type to be “memref<1+original-size x i32>” when the stride is set to 1. The verifier of the reinterpret cast op will construct the expected type from the sizes provided in the op itself. You can see this by running only the first two passes from your example. This gives something like:
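
To make the arithmetic concrete, here is a hypothetical Python sketch of that linear-size estimate (the largest reachable linear index under a strided layout is offset + Σ (size_i − 1) · stride_i, so the buffer must hold one more element than that; `linear_size` is an illustrative name, not an MLIR API):

```python
# Hypothetical sketch of how a normalization pass might estimate the
# linear size of the underlying data for a strided layout
# "offset + d0*stride0 + d1*stride1 + ...": the largest reachable
# linear index is offset + sum((size_i - 1) * stride_i), so the
# buffer needs one more element than that.

def linear_size(offset, sizes, strides):
    return offset + sum((s - 1) * st for s, st in zip(sizes, strides)) + 1

# offset 1, sizes [4], strides [1] -> 5, hence the rewritten memref<5xi32>
print(linear_size(1, [4], [1]))  # 5
# offset 1, sizes [3], strides [1] -> 4, matching "size = 3 instead of 4"
print(linear_size(1, [3], [1]))  # 4
```

This reproduces both error messages in the question: the op’s own sizes say 4 (or 3), while the normalized result type says 5 (or 4).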

foo.mlir:17:14: error: expected result type with size = 4 instead of 5 in dim = 0
  %alloc_1 = memref.reinterpret_cast
foo.mlir:17:14: note: see current operation: %9 = "memref.reinterpret_cast"(%8) <{operand_segment_sizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 1>, static_sizes = array<i64: 4>, static_strides = array<i64: 1>}> : (memref<4xi32>) -> memref<5xi32>

where the note shows you what the operation looks like after the cast.

I cannot reproduce this. As expected, I get off-by-one numbers in both cases. I hypothesize that in one of the cases you could have forgotten to update both the sizes: in the op and the resulting type. Having the following (note the difference):

  %alloc_1 = memref.reinterpret_cast
           %alloc_0 to offset: [1], sizes: [4], strides: [1]
           : memref<4xi32> to memref<5xi32, strided<[1], offset: 1>>

is invalid and is reported even without running any passes.

There’s a brand new mem2reg somewhere (RFC: Generic mem2reg in MLIR, ping @Moxinilian ). Have you tried using that instead?

Why do you need to run strided expansion that early? This is a part of the lowering process, and it loses information. I have been a long-term opponent of the introduction of reinterpret_cast because it creates a lot of opportunity for misuse, adds unstructured aliasing, and is overall poorly analyzable. The status quo is that it is used in lowering, but it should be avoided as much as possible otherwise.

Do you even need affine operations? Or just the canonicalization capability of affine maps, which gives you more aggressive constant folding than regular arithmetic?

It may be possible to extract at a lower level, provided that canonicalization / constant folding / CSE is effective enough. However, MLIR philosophy leans toward doing things at a higher level of abstraction. I would recommend exploring whether what you need can be done with some cheap analyses and foldings on strided memrefs. The load/store operations won’t necessarily have constant address offsets in them, but they may be easy to find by stepping through use-def chains. Conceptually, strided memrefs are hyperrectangular data spaces, a subset of affine that is much easier to handle and probably more suitable for hardware anyway.
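
As an illustration of the kind of cheap folding meant here (a hypothetical Python sketch, not an existing MLIR utility): when the indices and the strided layout are all known constants, e.g. after full unrolling, the linear element offset of a load folds to a constant even through a reinterpret_cast.

```python
# Hypothetical sketch: folding a load through a strided view down to a
# constant linear offset into the underlying allocation, assuming the
# layout offset, strides, and indices are all known constants.

def fold_load_offset(layout_offset, strides, indices):
    """Constant linear element offset of load view[indices]."""
    return layout_offset + sum(i * st for i, st in zip(indices, strides))

# load %alloc_1[1], where %alloc_1 = reinterpret_cast with offset 1,
# stride 1: folds to element 2 of the original allocation.
print(fold_load_offset(1, [1], [1]))  # 2
```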

Removing normalize-memrefs does indeed fix the problem! Thanks.

I’ll go explore those alternatives now to see if I can avoid reinterpret_cast altogether by doing higher level analyses.

I also tried using --mem2reg but it doesn’t seem to do anything. It appears it doesn’t handle affine directly (“This pass only supports unstructured control-flow”) and my feeble attempts to lower to cf before running it still don’t seem to produce anything. I will look around for more examples of what mem2reg can do.

Hey! So mem2reg on memref types right now is implemented so it only operates on scalar memrefs; it is the final step to get rid of a memref allocation. If you have high-dimensional memrefs, it might be interesting to apply SROA first, which will break larger memrefs into scalars. But this is not intended as general store-to-load forwarding (because that would require more complex aliasing analysis in the general case), so if you do not intend to remove allocations entirely it probably will not be useful.
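
Conceptually (a hypothetical sketch of the transformation, not the actual MLIR implementation), SROA followed by mem2reg does something like this:

```python
# Hypothetical sketch of SROA + mem2reg on a small fixed-size buffer:
# SROA splits the aggregate into independent scalar slots, then
# mem2reg forwards each scalar store to its loads, removing the
# allocation entirely.

# Before: one 4-element allocation with stores and a load.
buf = [None] * 4
buf[0], buf[1], buf[2], buf[3] = 0, 1, 2, 3
value = buf[2]

# After SROA: four independent scalar slots.
slot0, slot1, slot2, slot3 = 0, 1, 2, 3
# After mem2reg: the load is replaced by the forwarded value.
value_after = slot2

print(value, value_after)  # 2 2
```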

You are also right that it does not support structured control flow right now (simply because I do not have time to do that). It’s not an impossible thing to implement though, if somebody is interested :slight_smile: (it will require some interface design)

Because SCF is not implemented, I did not bother implementing the interfaces for the affine dialect (only memref), so in any case in its current state it would unfortunately not have worked for your use case. Again, this is just a bandwidth problem though, the mem2reg infra is fairly new!
