Hi,
It looks like I've hit a bug in the VectorTransferFullPartialRewriter pattern with the VectorTransferSplit::VectorTransfer option.
I've applied the following CodegenStrategy:
- tiling with tile sizes {3, 0, 0}
- vectorization with the mlir::vector::VectorTransferSplit::VectorTransfer option

to this matmul:
func @MatMul(%arg0: memref<5x1xf32>, %arg1: memref<1x6xf32>, %arg2: memref<5x6xf32>) {
%cst = constant 0.000000e+00 : f32
linalg.fill(%arg2, %cst) : memref<5x6xf32>, f32
linalg.matmul ins(%arg0, %arg1 : memref<5x1xf32>, memref<1x6xf32>) outs(%arg2 : memref<5x6xf32>)
return
}
This produces the following MLIR at the end of the unrolled loop (the code is generated by the createScopedFullPartialVectorTransferWrite function). Note that the non-vector memref is cast to a vector memref (memref<3x6xf32> to memref<vector<3x6xf32>>):
....
%120 = xor %117, %true : i1
scf.if %120 {
%121 = vector.type_cast %9 : memref<3x6xf32> to memref<vector<3x6xf32>>
%122 = memref.load %121[] : memref<vector<3x6xf32>>
%123 = vector.type_cast %6 : memref<3xvector<6xf32>> to memref<vector<3x6xf32>>
memref.store %122, %123[] : memref<vector<3x6xf32>>
....
The issue is that memref<3x6xf32> holds a contiguous sequence of floats, while memref<vector<3x6xf32>> is lowered to the LLVM-IR type [3 x <6 x float>], whose <6 x float> elements are padded to 8 floats (the nearest power of 2). So memref.load %121[] : memref<vector<3x6xf32>> is lowered to load [3 x <6 x float>], [3 x <6 x float>]* %702, which loads the floats incorrectly, because the data layouts behind memref<3x6xf32> and memref<vector<3x6xf32>> differ.
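To make the mismatch concrete, here is a small Python sketch (illustrative only, not part of the reproducer) that prints the byte offset at which each of the 3 rows starts under the two layouts:

```python
FLOAT_SIZE = 4          # bytes per f32
ROWS, COLS = 3, 6

# memref<3x6xf32>: rows packed back to back -> row stride 6 * 4 = 24 (0x18) bytes.
packed_offsets = [COLS * FLOAT_SIZE * r for r in range(ROWS)]

# [3 x <6 x float>]: each <6 x float> padded to 8 floats -> row stride 32 (0x20) bytes.
padded_offsets = [8 * FLOAT_SIZE * r for r in range(ROWS)]

print([hex(o) for o in packed_offsets])  # ['0x0', '0x18', '0x30']
print([hex(o) for o in padded_offsets])  # ['0x0', '0x20', '0x40']
```

The padded offsets (0x20, 0x40) are exactly the ones that show up in the assembly below, even though the source buffer is laid out with the packed offsets (0x18, 0x30).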
This is the LLVM-IR of the first basic block of this scf.if %120:
%695 = extractvalue { float*, float*, i64, [2 x i64], [2 x i64] } %108, 0
%696 = bitcast float* %695 to [3 x <6 x float>]*
%697 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } undef, [3 x <6 x float>]* %696, 0
%698 = extractvalue { float*, float*, i64, [2 x i64], [2 x i64] } %108, 1
%699 = bitcast float* %698 to [3 x <6 x float>]*
%700 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %697, [3 x <6 x float>]* %699, 1
%701 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %700, i64 0, 2
%702 = extractvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %701, 1
%703 = load [3 x <6 x float>], [3 x <6 x float>]* %702, align 32
%704 = extractvalue { <6 x float>*, <6 x float>*, i64, [1 x i64], [1 x i64] } %84, 0
%705 = bitcast <6 x float>* %704 to [3 x <6 x float>]*
%706 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } undef, [3 x <6 x float>]* %705, 0
%707 = extractvalue { <6 x float>*, <6 x float>*, i64, [1 x i64], [1 x i64] } %84, 1
%708 = bitcast <6 x float>* %707 to [3 x <6 x float>]*
%709 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %706, [3 x <6 x float>]* %708, 1
%710 = insertvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %709, i64 0, 2
%711 = extractvalue { [3 x <6 x float>]*, [3 x <6 x float>]*, i64 } %710, 1
store [3 x <6 x float>] %703, [3 x <6 x float>]* %711, align 32
And this is the related x86-64 assembly. (%rdi) is the address of the source memref and (%rax) is the address of the destination memref. The address offsets should be different to load the floats properly: 0x20 should be 0x18, 0x30 should be 0x28, etc.:
vmovaps (%rdi),%xmm0
mov 0x10(%rdi),%rcx
vmovaps 0x20(%rdi),%xmm1
mov 0x30(%rdi),%rdx
vmovaps 0x40(%rdi),%xmm2
mov 0x50(%rdi),%rsi
movl $0x0,(%rdi)
mov %rsi,0x50(%rax)
vmovaps %xmm2,0x40(%rax)
mov %rdx,0x30(%rax)
vmovaps %xmm1,0x20(%rax)
mov %rcx,0x10(%rax)
vmovaps %xmm0,(%rax)
What is the proper way to solve this issue: fix the VectorTransferFullPartialRewriter pattern, or fix the vector.type_cast lowering?