Hey Nicolas!
Thanks for bringing this up! Great analysis! I'm afraid I can't be very helpful on this topic, but this problem is also blocking our vectorization efforts so, please, let me know if I can help somehow.
**Alternative 3: Don't allow casts that change the memref type.**
`vector.type_cast` and `std.view` are very powerful, but they also introduce strong constraints. The latter, for example, requires the source memref to be 1D and i8. I think this would have important implications all over the place, as we discussed in [1]. Any function using `std.view` on input memrefs would require them to be 1D and i8, we would lose the original shape information in the function, etc. That sounds concerning to me.
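To make that constraint concrete, here is a minimal sketch. The function name is made up, and the exact `std.view` offset/size syntax may differ slightly from whichever MLIR revision you are on; treat it as illustrative:

```mlir
// Any function that wants to build a view must take the flat 1D i8
// buffer, losing the original shape at the function boundary.
func @consumer(%raw : memref<2048xi8>) {
  %c0 = constant 0 : index
  // Reinterpret the flat byte buffer as an 8x64 f32 memref.
  %view = view %raw[%c0][] : memref<2048xi8> to memref<8x64xf32>
  return
}
```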
[1] Understanding the vector abstraction in MLIR - #2 by mehdi_amini
In this regard, I have a naïve question that has been around for a while about what a `vector` is actually modeling when it is used in a memref. It was briefly discussed in [1] but I think it's somehow also relevant here. Some questions that come to mind: Is `%a = alloc() : memref<8xf32>` actually different from `%b = alloc() : memref<2xvector<4xf32>>`? Does it mean that we are not allowed to read elements [2, 5] from `%b` with a single packed `vector<4xf32>` load, or to read the 8 elements in `%b` with a single packed `vector<8xf32>`? Someone may answer that a `vector` memref is describing contiguous elements in memory but, don't we have memref strides to model that?
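To make the question concrete, a small sketch using the same ops as elsewhere in this thread (whether the unaligned read of `%b` *should* be expressible is exactly the open question):

```mlir
%c0 = constant 0 : index
%c2 = constant 2 : index
%cst = constant 0.0 : f32

// Scalar-element memref: any packed 4-wide read is expressible,
// including one starting at element 2.
%a = alloc() : memref<8xf32>
%v0 = vector.transfer_read %a[%c2], %cst : memref<8xf32>, vector<4xf32>

// Vector-element memref: loads naturally operate on whole vector
// elements, so elements [2, 5] have no direct single-load expression.
%b = alloc() : memref<2xvector<4xf32>>
%v1 = load %b[%c0] : memref<2xvector<4xf32>>  // yields elements [0, 3]
```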
I wonder if we could simplify this problem by separating memory representation aspects (memrefs) from how elements are read or written from/to memory (scalar/vector load/store ops, gather/scatter ops, …). This would imply that `vector` wouldn't be allowed as a memref element type, and memory load/store operations would encode the scalar/vector information related to the memory operation. We actually already have operations that take a "non-vector" memref and perform a vector load/store on it, where the vector information is encoded in the operation and not in the memref:
```mlir
%0 = alloc() : memref<8xf32>
%1 = vector.transfer_read %0[%c0], %cst : memref<8xf32>, vector<4xf32>
%2 = affine.vector_load %0[%c0] : memref<8xf32>, vector<4xf32>
```
There would probably be many other implications that I'm missing, but maybe it's another alternative to consider.