Introduction & Motivation
This pass is proposed so that we can use `tile vector` and vectorize `for` loops according to the hardware SIMD instructions. A `tile vector` is a vector abstraction that does not care about hardware details, such as how large a vector can be stored in a CPU register. We operate directly on tile vectors because some passes do not need hardware information, which reduces the difficulty of pass transformations. Meanwhile, other passes may want the tile vector to fit within the L1 cache size for a better performance configuration. After the `tile vector` pass, we vectorize the innermost `for` loop according to the data length supported by the SIMD instructions; we call this the physical register vector pass.
Why we need this
- There is no pass that does this today.
- This pass allows us to control the actual `scf.for` loop transformation and vectorization optimization instead of relying solely on LLVM's vectorization mechanism. It gives us better control over `for` loop optimization.
Related Work
The `tile vector` abstraction does not exist in the current MLIR community, and there are mainly three functions or passes that can do `for` loop vectorization.
Tile Vector
The linalg vectorization function can convert operations on tensors to operations on vectors. However, we may need to support more operations, such as `tensor::BitcastOp`, `tensor::ExpandShapeOp`, `tensor::ConcatOp`, etc. This should also become an independent pass that converts operations on tensors into operations on vectors.
Loop Vectorization
Affine
The `affine` dialect implements a super-vectorization class to do vectorization; see SuperVectorize.cpp for more details. However, the methods in `affine` only transform `affine.for` and memref operations, while we need to transform `scf.for` and operations on vectors (such as arith/math). We can refer to its transformation ideas, but it cannot be used directly for the pass we designed. At the same time, its approach focuses on the vectorization of scalar data in `affine`, while we focus on how to maximize the use of SIMD support for operations in the `scf.for` loop body.
SCF
`scf` mainly uses tiling to do vectorization, e.g. `tileUsingSCF`. We may need to use this function to transform the `scf.for` bound and step, but it only supports operations that implement `TilingInterface`, and almost no operations in the `vector` dialect have this interface.
Another related work is VectorToSCF, but it only lowers `transfer_read` and `transfer_write`.
Vector Distribution For GPU
This is the first Vector Distribution RFC. But this work only works for gpu
now.
Implementation
Lower to tile vector pass
(This abstraction layer can preserve the size of the upper-level tile tensor
 for different hardware characteristics, such as the L1 cache size on CPU.)
        |
        v
Lower to physical register pass
(Vectorize the innermost `for loop` according to the data length supported
 by SIMD instructions on CPU.)
Convert tensor to `tile vector` pass
We convert the `tensor` type to `vector`, mainly using `vector.transfer_read`, `vector.transfer_write`, etc.
Example:

```mlir
func.func @add_tensor_test0(%arg0: tensor<4x8x1024xf32>, %arg1: tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32> {
  %0 = tensor.empty() : tensor<4x8x1024xf32>
  %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x1024xf32>, tensor<4x8x1024xf32>) outs(%0: tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32>
  return %1 : tensor<4x8x1024xf32>
}
```
After the `tile vector` pass:

```mlir
func.func @add_tensor_test0(%arg0: tensor<4x8x1024xf32>, %arg1: tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<4x8x1024xf32>
  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : tensor<4x8x1024xf32>, vector<4x8x1024xf32>
  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : tensor<4x8x1024xf32>, vector<4x8x1024xf32>
  %3 = arith.addf %1, %2 : vector<4x8x1024xf32>
  %4 = vector.transfer_write %3, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<4x8x1024xf32>, tensor<4x8x1024xf32>
  return %4 : tensor<4x8x1024xf32>
}
```
Physical register vector pass
Example:
```mlir
func.func @add_tensor_test0(%arg0: tensor<4x8x1024xf32>, %arg1: tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<4x8x1024xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c8 = arith.constant 8 : index
  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %0) -> (tensor<4x8x1024xf32>) {
    %2 = scf.for %arg4 = %c0 to %c8 step %c1 iter_args(%arg5 = %arg3) -> (tensor<4x8x1024xf32>) {
      %c16 = arith.constant 16 : index
      %c1024 = arith.constant 1024 : index
      %3 = scf.for %arg6 = %c0 to %c1024 step %c16 iter_args(%arg7 = %arg5) -> (tensor<4x8x1024xf32>) {
        %4 = vector.transfer_read %arg0[%arg2, %arg4, %arg6], %cst {in_bounds = [true]} : tensor<4x8x1024xf32>, vector<16xf32>
        %5 = vector.transfer_read %arg1[%arg2, %arg4, %arg6], %cst {in_bounds = [true]} : tensor<4x8x1024xf32>, vector<16xf32>
        %6 = arith.addf %4, %5 : vector<16xf32>
        %7 = vector.transfer_write %6, %arg7[%arg2, %arg4, %arg6] {in_bounds = [true]} : vector<16xf32>, tensor<4x8x1024xf32>
        scf.yield %7 : tensor<4x8x1024xf32>
      }
      scf.yield %3 : tensor<4x8x1024xf32>
    }
    scf.yield %2 : tensor<4x8x1024xf32>
  }
  return %1 : tensor<4x8x1024xf32>
}
```
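The example above assumes the innermost dimension (1024) divides evenly by the vector length (16). When it does not, the trailing iteration needs a mask. A minimal sketch, assuming a hypothetical innermost dimension of 1000 and reusing the value names from the example above (the exact lowering is an open design point):

```mlir
// Hypothetical tail handling for an innermost dim of size 1000, step 16.
// %arg6 is the innermost induction variable.
%remaining = arith.subi %c1000, %arg6 : index
%vl = arith.minsi %remaining, %c16 : index
%mask = vector.create_mask %vl : vector<16xi1>
%4 = vector.transfer_read %arg0[%arg2, %arg4, %arg6], %cst, %mask : tensor<4x8x1000xf32>, vector<16xf32>
```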
- We need to determine how many values the hardware's SIMD instructions can operate on concurrently. This gives us the legal vector length for vectorizing the `for loop`.
- Convert the indexing map for the current `for loop` form. We also need to mask operations properly according to the indexing map.
- In order to avoid data dependency issues, we need to classify all operations. For example, a `transpose` operation may not be able to fuse into a common outer `for` loop with other elementwise operations. For example:
```
// assume we have vector = [100x100x100xf32]
scf.for 0...100:
  A[100, 100] = arith.add
  B[100, 100] = vector.transpose A permutation[1, 0]
  C[100, 100] = arith.add A, B

// if we generate a common outer for loop, that will be:
scf.for 0...100:
  scf.for 0...100:
    scf.for 0...100, step = 16:
      A_vector[16xf32] = add
      B_vector[16, 1] = transpose A_vector
      // the result will be wrong: C_vector should add B_vector[1, 16],
      // but B_vector is [16, 1] here, and [16, 1] is also invalid IR here.
      C_slice[1, 16] = add A_slice, B_slice

// but it should be:
scf.for 0...100:
  scf.for 0...100:
    scf.for 0...100, step = 16:
      A_vector[16xf32] = add
scf.for 0...100:
  scf.for 0...100:
    B_vector[16x16] = vector.transpose A_vector[16x16] permutation [1, 0]
scf.for 0...100:
  scf.for 0...100, step = 16:
    C_slice[16] = add A_vector, B_vector
```
So, we divide operations into three categories to check whether a common outer `for loop` can be generated between different operations:
- Complex operations: reorder, transpose, typecast, and other complex operations. There may be data dependencies between such operations, so it is necessary to determine whether the dimension the current operation works on is a data-dependent dimension. If it is not, the operation can reuse the current `for` loop.
- Elementwise operations: add, sub, mul, div, max, min, etc. A `for loop` can be reused between such operations.
- Other operations:
  - Reduce: it is necessary to determine whether the previous `for` loop can be reused based on the reduction axis.
  - Broadcast: it is necessary to determine whether the previous `for` loop can be reused based on the broadcast axis.
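The classification above can be sketched as a simple predicate. This is a hypothetical helper (`can_reuse_loop` and its arguments are illustrative, not an existing API), assuming each operation reports the axes on which it has cross-lane data dependencies:

```python
def can_reuse_loop(op_kind, loop_axis, dependent_axes=()):
    """Decide whether an operation can reuse the enclosing for loop.

    op_kind        : "elementwise", "reduce", "broadcast", or "complex"
    loop_axis      : the axis the current loop iterates over
    dependent_axes : axes on which the op carries data dependencies
                     (reduction axis, broadcast axis, permuted axes, ...)
    """
    if op_kind == "elementwise":
        # Elementwise ops never block loop reuse.
        return True
    if op_kind in ("reduce", "broadcast", "complex"):
        # Reuse only when the loop axis is not a data-dependent axis.
        return loop_axis not in dependent_axes
    # Unknown operations: be conservative and generate a fresh loop.
    return False
```

For instance, a reduction over axis 1 cannot reuse a loop over axis 1, but can reuse a loop over axis 0.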
At the same time, in order to achieve the best performance, during the phase of generating the loop body, when the tile sizes of two operations are inconsistent, the operation with the smaller tile size may need a `loop unroll`. For example:
```
arith.add : vector<32xf32>
arith.add : vector<32xu8>

// No loop unroll form:
scf.for 0...32 step = 16:
  arith.add : vector<16xf32>
scf.for 0...32 step = 32:
  arith.add : vector<32xu8>

// With loop unroll, the two operations can fuse into a common for loop:
scf.for 0...32 step = 32:
  arith.add : vector<16xf32>
  arith.add : vector<16xf32>
  arith.add : vector<32xu8>
```
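The unroll factor in this example can be derived from the natural vector step of each operation. A minimal sketch (`fuse_with_unroll` is a hypothetical helper, not part of any existing pass):

```python
from math import lcm

def fuse_with_unroll(steps):
    """Given the natural vector step of each operation, compute the
    common loop step and the per-operation unroll factor so that all
    operations can share one fused loop."""
    common = lcm(*steps)                         # shared loop step
    return common, [common // s for s in steps]  # copies of each op per iteration
```

For the example above, `fuse_with_unroll([16, 32])` yields a common step of 32 with unroll factors `[2, 1]`: the f32 add is emitted twice per iteration while the u8 add appears once.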
What features will be provided
- As mentioned above, simple vector-based operation fusion: operations without data dependencies can use the same `for` loop.
- Vectorization adapted to the hardware for operations without data dependencies.

Users can choose to use the fusion we provide to vectorize the current operations, or provide a set of operations without data dependencies and directly call the method provided in 2 to perform vectorization.
Thank you for your attention. Feel free to join the discussion for better design or code implementation.