Dynamic MemRef Cast

The LHLO dialect in TensorFlow has a dynamic_memref_cast operation that might be useful to a wider audience. The operation changes the sizes and strides of a memref using values computed at runtime:

    %buf_transformed =
        xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
        : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>

The result of the op is a type-erased memref with shape [%size_X, %size_Y] and strides [%step_X, %step_Y]. The offset is inherited from the input.
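To make these semantics concrete, here is a minimal Python sketch that models a memref as a flat buffer plus an offset, sizes, and strides. The class StridedMemRef and the function dynamic_memref_cast are hypothetical illustrations, not the LHLO or MLIR implementation. They only show that the cast builds a new view over the same buffer with runtime sizes and strides while keeping the source offset.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class StridedMemRef:
    """Hypothetical toy model of a strided memref: flat buffer + offset, sizes, strides."""
    buffer: List[float]
    offset: int
    sizes: Sequence[int]
    strides: Sequence[int]

    def load(self, *indices: int) -> float:
        # Linearize the index: offset + sum(index_i * stride_i), checking bounds per dimension.
        assert len(indices) == len(self.sizes), "rank mismatch"
        linear = self.offset
        for idx, size, stride in zip(indices, self.sizes, self.strides):
            assert 0 <= idx < size, "index out of bounds"
            linear += idx * stride
        return self.buffer[linear]


def dynamic_memref_cast(src: StridedMemRef, sizes: Sequence[int],
                        strides: Sequence[int]) -> StridedMemRef:
    # Hypothetical helper: reinterpret the same buffer with runtime sizes and strides.
    # The offset is inherited from the source, as described for the LHLO op.
    return StridedMemRef(src.buffer, src.offset, tuple(sizes), tuple(strides))


# Usage: view a 4x4 row-major buffer as a 2x2 memref that skips every other row and column.
src = StridedMemRef([float(i) for i in range(16)], offset=0, sizes=(4, 4), strides=(4, 1))
view = dynamic_memref_cast(src, sizes=(2, 2), strides=(8, 2))
print(view.load(1, 1))  # linear index 0 + 1*8 + 1*2 = 10, prints 10.0
```

No data is copied: the cast only changes how indices map onto the existing buffer.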

If this operation seems useful enough to be in StandardOps, I can start writing the code for it. Before doing that, I would like to hear your opinions on:

  • name of the operation: is dynamic_memref_cast a suitable name for this op? Should the memref_cast op be renamed?
  • should there be an argument for the offset, or should we just copy the offset of the original memref?

This looks useful to have.

Perhaps not. If you keep memref_cast intact for its existing purposes, you may need a different name. dynamic doesn’t capture everything that’s happening; basically, you are reinterpreting the shape, strides, and offset information. Perhaps memref_reinterpret_cast?

I think so. If you are reinterpreting shapes and strides, you might as well add the offset too. I can’t recall, though, whether there is a way to extract the stride and offset information (when they are dynamic) from a memref. An extract_symbol operation, which would return the SSA values bound to the symbols of the memref’s affine layout map, has been missing since the beginning (I think it has never been needed). When the map is in strided form, these symbols are exactly the strides and the offset. The dim op gets you the sizes but not the other symbols, and you’ll need them, for example, to copy over the offset from the memref you are casting. A 2-D strided layout map looks like this:

    affine_map<(d0, d1) [s0, s1, s2] -> (s0 * d0 + s1 * d1 + s2)>
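In the map above, s0 and s1 play the role of strides and s2 is the offset. The following Python sketch evaluates a map of that shape for any rank; the function name strided_layout_map is hypothetical and only meant to show how the symbol values turn an index into a linear position.

```python
from typing import Sequence


def strided_layout_map(dims: Sequence[int], symbols: Sequence[int]) -> int:
    """Hypothetical evaluator for (d0, ..., d{r-1})[s0, ..., s{r-1}, s{r}] -> (sum(s_i * d_i) + s{r}).

    For r = 2 this is affine_map<(d0, d1) [s0, s1, s2] -> (s0 * d0 + s1 * d1 + s2)>:
    the first r symbols are the strides and the last one is the offset.
    """
    rank = len(dims)
    assert len(symbols) == rank + 1, "expected one symbol per stride plus one for the offset"
    *strides, offset = symbols
    return sum(s * d for s, d in zip(strides, dims)) + offset


# Usage: element (2, 4) of a view with row stride 10, unit column stride, and offset 7.
print(strided_layout_map(dims=(2, 4), symbols=(10, 1, 7)))  # 10*2 + 1*4 + 7, prints 31
```

Since the offset is just another symbol, a cast that wants to preserve it needs some way to read s2's value, which is the gap described above.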

This is an interesting point. In the most general form, symbols can be referenced the same way dimensions are, i.e., by providing an index. For strided memrefs, I assume some helper functions that identify which symbol carries which meaning (a particular stride or the offset) should be good enough.
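One possible shape for such helpers is sketched below in Python. It assumes the symbol order of the example map (dynamic strides in dimension order, then the offset) and that static strides or offsets are folded into the map as constants rather than occupying a symbol. Both are assumptions for illustration, and the function names are hypothetical; MLIR's own strided-layout utilities may order the symbols differently (for instance, offset first), so a real helper should follow whatever convention builds the layout map.

```python
from typing import Optional, Sequence

# In this sketch, None marks a dynamic ('?') stride or offset in the memref type.


def stride_symbol_index(strides: Sequence[Optional[int]], dim: int) -> Optional[int]:
    """Hypothetical: symbol position holding the stride of `dim`, or None if it is static."""
    if strides[dim] is not None:
        return None  # a static stride is a constant in the map and has no symbol
    # Each dynamic stride in an earlier dimension consumes one symbol.
    return sum(1 for s in strides[:dim] if s is None)


def offset_symbol_index(strides: Sequence[Optional[int]],
                        offset: Optional[int]) -> Optional[int]:
    """Hypothetical: symbol position holding the offset, or None if it is static.

    Assumes the offset symbol comes after all dynamic stride symbols.
    """
    if offset is not None:
        return None
    return sum(1 for s in strides if s is None)


# Usage: strides [?, ?], offset ? gives s0, s1 for the strides and s2 for the offset.
print(stride_symbol_index([None, None], 1), offset_symbol_index([None, None], None))  # prints 1 2
```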

So far, I think, this has been modeled only at the LLVM level, where we can access these symbols by reading them out of the memref descriptor.

I am all for moving ahead with this. We only need to bikeshed a name for the operation: get_symbol? extract_symbol? Passing the index as an attribute should be fine, as this only makes sense in the ranked case.
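At the LLVM level, as noted above, these symbols are read out of the memref descriptor. The Python sketch below mirrors that: MemRefDescriptor is a toy stand-in for the ranked descriptor (the real one also carries the allocated and aligned pointers), and extract_symbol and verify_symbol_index are hypothetical names for the op under discussion. It assumes a fully dynamic strided layout with the thread's symbol order (strides, then offset) and illustrates why a static index attribute is enough in the ranked case: the rank comes from the type, so the index can be checked during verification.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MemRefDescriptor:
    """Toy version of the ranked memref descriptor used when lowering to LLVM.

    The allocated and aligned pointers are omitted; only the offset, sizes,
    and strides matter for the layout symbols.
    """
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]


def verify_symbol_index(rank: int, index: int) -> None:
    # For a ranked memref the rank is part of the type, so a static index
    # attribute can be validated when the op is verified.
    if not 0 <= index <= rank:
        raise IndexError(f"symbol index {index} out of range for rank {rank}")


def extract_symbol(desc: MemRefDescriptor, index: int) -> int:
    """Hypothetical extract_symbol: read layout symbol `index` from the descriptor.

    Assumes s0..s{r-1} are the strides and s{r} is the offset.
    """
    rank = len(desc.sizes)
    verify_symbol_index(rank, index)
    return desc.strides[index] if index < rank else desc.offset


# Usage: copy the source offset into a reinterpreted (transposed) view.
src = MemRefDescriptor(offset=7, sizes=(3, 5), strides=(10, 1))
dst = MemRefDescriptor(offset=extract_symbol(src, 2), sizes=(5, 3), strides=(1, 10))
print(dst.offset)  # prints 7
```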
