Vector.transfer split pattern bug?

Hi,

It looks like I’ve hit a bug in the VectorTransferFullPartialRewriter pattern when using the VectorTransferSplit::VectorTransfer option.

I applied a CodegenStrategy consisting of (a rough configuration sketch follows the list):

  • tiling with tile sizes {3, 0, 0}
  • vectorization with the mlir::vector::VectorTransferSplit::VectorTransfer option
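
For reference, this is roughly how such a strategy is assembled (a hedged sketch, not my exact code; the CodegenStrategy / option spellings may differ between MLIR revisions):

// Request the VectorTransfer flavor of the transfer-split rewrite.
vector::VectorTransformsOptions vectorOptions;
vectorOptions.vectorTransferSplit = vector::VectorTransferSplit::VectorTransfer;

linalg::CodegenStrategy strategy;
strategy
    // Tile linalg.matmul with tile sizes {3, 0, 0}.
    .tile<linalg::MatmulOp>(linalg::LinalgTilingOptions().setTileSizes({3, 0, 0}))
    // Vectorize the tiled op and apply the vector transform options.
    .vectorize<linalg::MatmulOp>()
    .setVectorTransformsOptions(vectorOptions);
strategy.transform(funcOp);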

The input is the following matmul:

func @MatMul(%arg0: memref<5x1xf32>, %arg1: memref<1x6xf32>, %arg2: memref<5x6xf32>) {
  %cst = constant 0.000000e+00 : f32
  linalg.fill(%arg2, %cst) : memref<5x6xf32>, f32
  linalg.matmul ins(%arg0, %arg1 : memref<5x1xf32>, memref<1x6xf32>) outs(%arg2 : memref<5x6xf32>)
  return
}

This produces the following MLIR at the end of the unrolled loop (generated by the createScopedFullPartialVectorTransferWrite function). Note that the non-vector memref is cast to a vector memref (memref<3x6xf32> to memref<vector<3x6xf32>>):

....
%120 = xor %117, %true : i1
scf.if %120 {
  %121 = vector.type_cast %9 : memref<3x6xf32> to memref<vector<3x6xf32>>
  %122 = memref.load %121[] : memref<vector<3x6xf32>>
  %123 = vector.type_cast %6 : memref<3xvector<6xf32>> to memref<vector<3x6xf32>>
  memref.store %122, %123[] : memref<vector<3x6xf32>>
....

The issue is that memref<3x6xf32> holds a contiguous sequence of floats, while memref<vector<3x6xf32>> is lowered to the LLVM-IR type [3 x <6 x float>], where each LLVM vector is padded to 8 elements (the nearest power of 2). As a result, memref.load %121[] : memref<vector<3x6xf32>> is lowered to load [3 x <6 x float>], [3 x <6 x float>]* %702, which reads the floats at the wrong offsets (the data layouts of memref<3x6xf32> and memref<vector<3x6xf32>> differ).
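
To make the size mismatch concrete, here is a small standalone C++ sketch of the two row strides (PackedRow / PaddedRow are made-up illustrative types, not anything from MLIR):

#include <cstdio>

// memref<3x6xf32>: 3 rows of 6 packed floats, row stride 6 * 4 = 24 bytes (0x18).
struct PackedRow { float data[6]; };
// [3 x <6 x float>]: each <6 x float> is allocated with the size/alignment of an
// 8-element vector, so the row stride becomes 32 bytes (0x20).
struct alignas(32) PaddedRow { float data[6]; };

int main() {
  std::printf("packed row stride: %zu bytes\n", sizeof(PackedRow)); // 24
  std::printf("padded row stride: %zu bytes\n", sizeof(PaddedRow)); // 32
  // Row i starts at i * 24 bytes in the memref buffer but is read at i * 32
  // bytes through the [3 x <6 x float>] view, so rows 1 and 2 come from the
  // wrong offsets.
  return 0;
}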
Here is the LLVM-IR for the first basic block of this scf.if %120:

  %695 = extractvalue { float*, float*, i64, [2 x i64], [2 x i64] } %108, 0
  %696 = bitcast float* %695 to [3 x <6 x float>]*
  %697 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } undef, [3 x <6 x float>]* %696, 0
  %698 = extractvalue { float*, float*, i64, [2 x i64], [2 x i64] } %108, 1
  %699 = bitcast float* %698 to [3 x <6 x float>]*
  %700 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %697, [3 x <6 x float>]* %699, 1
  %701 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %700, i64 0, 2
  %702 = extractvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %701, 1
  %703 = load [3 x <6 x float>], [3 x <6 x float>]* %702, align 32
  %704 = extractvalue { <6 x float>*, <6 x float>*, i64, [1 x i64], [1 x i64] } %84, 0
  %705 = bitcast <6 x float>* %704 to [3 x <6 x float>]*
  %706 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } undef, [3 x <6 x float>]* %705, 0
  %707 = extractvalue { <6 x float>*, <6 x float>*, i64, [1 x i64], [1 x i64] } %84, 1
  %708 = bitcast <6 x float>* %707 to [3 x <6 x float>]*
  %709 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %706, [3 x <6 x float>]* %708, 1
  %710 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %709, i64 0, 2
  %711 = extractvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %710, 1
  store [3 x <6 x float>] %703, [3 x <6 x float>]* %711, align 32

And here is the corresponding x86-64 assembly. (%rdi) is the address of the source memref and (%rax) is the address of the destination memref; the offsets would have to be different to load the floats correctly (0x20 should be 0x18, 0x30 should be 0x28, etc.):

vmovaps (%rdi),%xmm0
mov    0x10(%rdi),%rcx
vmovaps 0x20(%rdi),%xmm1
mov    0x30(%rdi),%rdx
vmovaps 0x40(%rdi),%xmm2
mov    0x50(%rdi),%rsi
movl   $0x0,(%rdi)
mov    %rsi,0x50(%rax)
vmovaps %xmm2,0x40(%rax)
mov    %rdx,0x30(%rax)
vmovaps %xmm1,0x20(%rax)
mov    %rcx,0x10(%rax)
vmovaps %xmm0,(%rax)

What is the proper way to solve this issue: fix the VectorTransferFullPartialRewriter pattern, or fix the vector.type_cast lowering?

@Buyduck thanks for reporting!

You’re right that vector.type_cast is unsafe in this context, for the reasons you mention, and the lowering jumps the gun here.

Basically, the following pattern:

b.create<memref::StoreOp>(
  loc, vector,
    b.create<vector::TypeCastOp>(
      loc, MemRefType::get({}, vector.getType()), alloc));

is unsafe and should be replaced by a vector::TransferWriteOp.
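
Roughly, the write side would turn into something like the following (a sketch only; indices / in_bounds handling is elided and exact builder overloads may vary):

// Write the in-register vector straight into the scalar memref with a
// vector.transfer_write, instead of reinterpreting the memref through
// vector.type_cast and doing a memref.store.
Value zero = b.create<ConstantIndexOp>(loc, 0);
SmallVector<Value, 4> indices(alloc.getType().cast<MemRefType>().getRank(), zero);
b.create<vector::TransferWriteOp>(loc, vector, alloc, indices);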

Same story for:

Value load = b.create<memref::LoadOp>(
  loc, b.create<vector::TypeCastOp>(
    loc, MemRefType::get({}, xferOp.vector().getType()), alloc));

which is likewise unsafe and should be replaced by a vector::TransferReadOp.
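
The read side would become something along these lines (same caveats as above):

// Read the vector back out of the scalar memref with a vector.transfer_read
// instead of vector.type_cast + memref.load.
Value zero = b.create<ConstantIndexOp>(loc, 0);
SmallVector<Value, 4> indices(alloc.getType().cast<MemRefType>().getRank(), zero);
Value load = b.create<vector::TransferReadOp>(
    loc, xferOp.vector().getType().cast<VectorType>(), alloc, indices);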

Additionally, we should update the definition / verifier of vector::TypeCastOp: it is generally illegal to vector.type_cast a scalar memref without DataLayout considerations:

%VA = vector.type_cast %A : memref<5x4x3xf32> to memref<vector<5x4x3xf32>>

Instead, the intended use case is to type-cast between memrefs of vector types:

%VA = vector.type_cast %A : memref<5x4xvector<3xf32>> to memref<vector<5x4x3xf32>>

and delegate the 1-D case to lower levels that can perform a bitcast / consult the DataLayout support that was recently introduced in MLIR.

It seems from your wording that you’d be open to fixing this yourself?

If so, a patch would be most welcome; otherwise let me know and I’ll fix it.

Thanks!


Thank you, I would be glad to fix it myself :S


@nicolasvasilache I’ve started the implementation and ran into some more issues, so could you please do it yourself? :S