Hi!
I ran into test failures in my application after updating to the latest MLIR and did some investigation. I found several issues with the memref.extract_aligned_pointer_as_index, memref.extract_strided_metadata and memref.view operations that caused those failures, and I’d like to discuss them.
Here is my scenario. I have an IR with multiple constant 1D memrefs:
%cstA = memref.get_global @A: memref<4xi64>
%cstB = memref.get_global @B: memref<4xi64>
I’d like to get their data addresses and pass them to an external function. To stay generic with respect to offsets, I use extract_aligned_pointer_as_index and extract_strided_metadata for that:
%cstA_base_ptr = memref.extract_aligned_pointer_as_index %cstA : memref<4xi64> -> index
%base, %cstA_offset, %size, %stride =
memref.extract_strided_metadata %cstA :
memref<4xi64> -> memref<i64>, index, index, index
%cstA_final_ptr = index.add %cstA_base_ptr, %cstA_offset
func.call @foo(%cstA_final_ptr)
// same for cstB
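As a rough illustration of the address computation described above, here is a toy Python model. It is only a sketch under stated assumptions: `Descriptor` and `first_element_address` are hypothetical names, not MLIR APIs, and the model assumes the offset reported by extract_strided_metadata is counted in elements (as in the strided layout), so it is scaled by the element size to obtain a byte address.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Descriptor:
    """Toy stand-in for a 1D memref descriptor (hypothetical, not an MLIR API)."""
    aligned_ptr: int  # address of the aligned base buffer, in bytes
    offset: int       # offset of the first element, in elements (assumption)
    elem_bytes: int   # size of one element in bytes (8 for i64)


def first_element_address(d: Descriptor) -> int:
    """Byte address of element 0: base pointer plus the scaled element offset."""
    return d.aligned_ptr + d.offset * d.elem_bytes


# A standalone constant: the offset is zero, so the base pointer is the answer.
cst_a = Descriptor(aligned_ptr=0x1000, offset=0, elem_bytes=8)
# The same buffer seen with an offset of 4 elements (4 * 8 = 32 bytes).
shifted = Descriptor(aligned_ptr=0x1000, offset=4, elem_bytes=8)

print(hex(first_element_address(cst_a)))    # 0x1000
print(hex(first_element_address(shifted)))  # 0x1020
```

With a zero offset the scaling makes no difference, which is why adding the raw offset to the base pointer looks correct for the standalone constants; it would only start to matter once the offset is non-zero.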
Finally, I have a separate pass that merges all constants into a single one of type i8 and replaces the separate memref.get_global operations with memref.view operations into it:
%merged_cst = memref.get_global @merged: memref<64xi8>
%cstA = memref.view %merged_cst [%offset_0][] : memref<64xi8> to memref<4xi64>
%cstB = memref.view %merged_cst [%offset_32][] : memref<64xi8> to memref<4xi64>
// code with extract_aligned_pointer_as_index and func.call
But what happens between these passes? memref.extract_strided_metadata has a folder that replaces its results with constant values when the offsets and/or strides are compile-time constants. That is the case for the original IR, so folding removes extract_strided_metadata and the pointer arithmetic with the offset, leaving only:
%cstA_base_ptr = memref.extract_aligned_pointer_as_index %cstA : memref<4xi64> -> index
func.call @foo(%cstA_base_ptr)
After the constant merging pass I’ve got the following IR:
%merged_cst = memref.get_global @merged: memref<64xi8>
%cstA = memref.view %merged_cst [%offset_0][] : memref<64xi8> to memref<4xi64>
%cstB = memref.view %merged_cst [%offset_32][] : memref<64xi8> to memref<4xi64>
%cstA_base_ptr = memref.extract_aligned_pointer_as_index %cstA : memref<4xi64> -> index
func.call @foo(%cstA_base_ptr)
%cstB_base_ptr = memref.extract_aligned_pointer_as_index %cstB : memref<4xi64> -> index
func.call @foo(%cstB_base_ptr)
Now memref.extract_aligned_pointer_as_index is called on the result of memref.view, and by definition it returns not the pointer to the start of the particular constant’s data but the base pointer of %merged_cst. This becomes more visible after the -expand-strided-metadata pass, which simply swaps memref.extract_aligned_pointer_as_index and memref.view, so both function calls get the same pointer.
Thus the first issue, from my point of view: memref.extract_strided_metadata shouldn’t perform implicit constant folding via the common folding mechanism, since later passes might modify the IR and those modifications might introduce different offsets and strides.
The second issue: memref.view produces a result memref with an empty layout map, which is equivalent to compact strides and a zero offset. But that’s not actually true, because we do have an offset from the base buffer. It should look like:
%cstB = memref.view %merged_cst [%offset_32][]
: memref<64xi8> to memref<4xi64, strided<[1], offset: 4>>
// ^^^^^^^^^
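The `offset: 4` in the suggested type follows from simple arithmetic, sketched below in Python. This assumes that %offset_32 holds the byte shift 32 and that an i64 element occupies 8 bytes; `byte_shift_to_element_offset` is a hypothetical helper written for illustration, not an MLIR function.

```python
def byte_shift_to_element_offset(byte_shift: int, elem_bytes: int) -> int:
    """Convert a byte shift into an offset counted in elements.

    Raises ValueError if the shift does not land on an element boundary,
    since such a shift cannot be expressed as a whole element offset.
    """
    if elem_bytes <= 0:
        raise ValueError("element size must be positive")
    if byte_shift % elem_bytes != 0:
        raise ValueError("byte shift is not a multiple of the element size")
    return byte_shift // elem_bytes


I64_BYTES = 8  # assumed size of one i64 element

offset_a = byte_shift_to_element_offset(0, I64_BYTES)   # view at byte 0
offset_b = byte_shift_to_element_offset(32, I64_BYTES)  # view at byte 32

print(offset_a, offset_b)  # 0 4
```

The divisibility check is there because a byte shift that is not a multiple of the element size has no representation as a strided element offset at all.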
The funny thing is why this worked for me before updating to the latest upstream. I had a --convert-linalg-to-llvm pass in my pipeline, which ran before -expand-strided-metadata. That pass internally used the finalize-memref-to-llvm patterns, so this part was converted to the LLVM dialect there and not in the -expand-strided-metadata/--finalize-memref-to-llvm calls further down the pipeline. By accident, the finalize-memref-to-llvm patterns lower the memref.view -> memref.extract_aligned_pointer_as_index pair into correct code, so I got valid pointers in the function calls. After updating to the latest upstream I removed the --convert-linalg-to-llvm call and the issue appeared.
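The difference between the two pipelines can be sketched with a toy Python model. This is an assumption-laden illustration of the behaviour described in this post, not the actual implementation of either pass: `Buffer`, `View` and both functions are hypothetical names, and the base address 0x2000 is made up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Buffer:
    """Toy stand-in for the merged i8 constant (hypothetical)."""
    aligned_ptr: int  # byte address of the buffer start


@dataclass(frozen=True)
class View:
    """Toy stand-in for a memref.view result: a source plus a byte shift."""
    source: Buffer
    byte_shift: int


def pointer_with_view_lowered_first(v: View) -> int:
    # Models the old pipeline: the view is lowered first, so the byte
    # shift is already baked into the pointer that is read afterwards.
    return v.source.aligned_ptr + v.byte_shift


def pointer_forwarded_to_source(v: View) -> int:
    # Models the new pipeline: the pointer query is forwarded to the
    # view's source, so the byte shift is dropped.
    return v.source.aligned_ptr


merged = Buffer(aligned_ptr=0x2000)
cst_a = View(merged, byte_shift=0)
cst_b = View(merged, byte_shift=32)

# Old pipeline: two distinct pointers, 32 bytes apart.
print(hex(pointer_with_view_lowered_first(cst_a)))  # 0x2000
print(hex(pointer_with_view_lowered_first(cst_b)))  # 0x2020

# New pipeline: both calls receive the same pointer.
print(hex(pointer_forwarded_to_source(cst_a)))  # 0x2000
print(hex(pointer_forwarded_to_source(cst_b)))  # 0x2000
```

In this model the result depends only on the order in which the view and the pointer query are resolved, which matches the observation that removing one pass from the pipeline was enough to change the pointers.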
All these discoveries bring me to a more general question: why don’t we have a simpler way to get a pointer to the first element of a memref value? It would be much simpler in cases where we want to pass memref data to some external function without having to deal with base pointers and offsets.
Sorry for such a long topic.