(Authors: Diego Caballero & Nicolas Vasilache)
This proposal aims to introduce generic support for vector masking in MLIR. A generic way to represent masked operations will pave the way for representing and optimizing more complex vector scenarios. The proposal is also a step towards closing the gap between the vector representations of MLIR and LLVM. In the future, we plan to extend this approach to also support vector operations with a dynamic vector length.
Background and Motivation
Dealing with out-of-bounds memory accesses, side-effecting instructions and control flow that may diverge within the lanes of a vector is one of the biggest challenges of vectorization. Most recent SIMD and vector architectures address this problem by introducing dedicated hardware support to selectively enable/disable the execution of each vector lane using masks. However, compiler support is needed to effectively take advantage of these hardware features.
Modeling masked operations in LLVM has been a hairy topic. The community attempted to model masked computation several times and discussions on the matter went on for years. Luckily, the Vector Predication proposal gained support and landed recently. This approach separates the concern of masking from the rest of the IR by introducing a dedicated set of vector operations (~dialect?) that can take both a vector mask and a dynamic vector length as arguments. The approach presented in this RFC relies on the Vector Predication proposal as a solid foundation to build a hardware-independent vector/predication model in MLIR.
Nonetheless, modeling masking/predication in MLIR is even more challenging than in LLVM. MLIR has a growing, open-ended IR with multiple levels of abstraction. Most of these levels are vectorizable and, therefore, subject to masking. Creating a dedicated set of operations with masking/dynamic vector length support for every potentially maskable operation would lead to a prohibitive amount of duplication. On the other hand, adding native mask/vector length arguments to every maskable operation would be extremely invasive and would leak vectorization details into each and every vectorizable operation in MLIR.
The lack of a common vector layer with masking and dynamic vector length support is also a concern. Modeling these vector concepts in a generic way will allow us to share implementations across targets with similar vector/predication features (e.g., SVE, RVV), making some target-specific dialects unnecessary.
We propose a generic and less invasive approach to masking in MLIR that scales with the open-ended, multi-level nature of this intermediate representation.
Proposal
This proposal comprises three parts: i) a new vector.mask operation with a region to model masked computation, ii) new interfaces to characterize operations that can mask and be masked, and iii) a set of masked elementary vector operations that will serve as a low-level landing pad for masked operations and facilitate the conversion to LLVM and other backends with masking support.
The vector.mask Operation
We propose the vector.mask operation as a generic vehicle to apply masking to any operation in MLIR, regardless of its level of abstraction. The following snippet shows how an arith.addf can be masked using this approach:
```mlir
%0 = vector.mask %mask, %passthru : vector<2x8xi1> -> vector<2x8xf32> {
  // '%a' and '%b' are captured from above.
  %1 = arith.addf %a, %b : vector<2x8xf32>
  vector.yield %1 : vector<2x8xf32>
}
```
The vector.mask operation takes a vector mask and an optional pass-through vector as arguments. A vector.yield-terminated region encloses the operation to be masked (the nested operation). Other values used within the region are captured from above. The vector mask holds one bit per vector lane and determines which lanes execute the nested operation and which ones do not. The vector.mask operation returns the value produced by the masked execution of the nested operation, if any. The masked-off lanes in the result vector are taken from the corresponding lanes of the pass-through value, if provided, or left unmodified otherwise.
The vector.mask operation does not prescribe how a specific operation should be masked or how a masked operation should be lowered. Masking requirements will be provided by each maskable operation through the MaskableOp interface (see next section). Lowering of masked operations is implementation defined. Multiple lowering strategies for masked operations are expected.
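To make these semantics concrete, here is a small Python reference model. It is an illustration only, not part of the proposal or of MLIR: vectors are plain Python lists, the function names are hypothetical, and masked-off lanes without a pass-through value are modeled as None to stand in for "left unmodified". The second function sketches one possible lowering strategy for a side-effect-free operation: compute every lane unmasked, then pick the pass-through value on masked-off lanes (e.g., with arith.select).

```python
import operator
from typing import Callable, List, Optional, Sequence

Vector = Sequence[float]


def vector_mask_semantics(op: Callable[[float, float], float], mask: Sequence[bool],
                          a: Vector, b: Vector,
                          passthru: Optional[Vector] = None) -> List[Optional[float]]:
    """Lane-by-lane reference semantics of vector.mask wrapping a binary elementwise op."""
    result: List[Optional[float]] = []
    for lane, enabled in enumerate(mask):
        if enabled:
            result.append(op(a[lane], b[lane]))  # active lane: run the nested op
        elif passthru is not None:
            result.append(passthru[lane])        # masked-off lane: pass-through value
        else:
            result.append(None)                  # masked-off lane, no pass-through (hypothetical None)
    return result


def lower_to_select(op: Callable[[float, float], float], mask: Sequence[bool],
                    a: Vector, b: Vector, passthru: Vector) -> List[float]:
    """One possible lowering for side-effect-free ops: compute all lanes, then select."""
    unmasked = [op(x, y) for x, y in zip(a, b)]                          # plain arith.addf
    return [u if m else p for u, m, p in zip(unmasked, mask, passthru)]  # arith.select


mask = [True, False, True, False]
a, b, passthru = [1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0], [-1.0] * 4
print(vector_mask_semantics(operator.add, mask, a, b, passthru))  # [11.0, -1.0, 33.0, -1.0]
print(lower_to_select(operator.add, mask, a, b, passthru))        # [11.0, -1.0, 33.0, -1.0]
```

The compute-then-select strategy is only valid when executing the masked-off lanes is harmless. For operations that may trap or access memory, such as a transfer read whose masked-off lanes fall out of bounds, a lowering has to keep those lanes from executing at all.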
Alternative Assembly Format
The following snippet shows a more compact assembly format for the vector.mask operation, which we implemented in our prototype. We think this alternative could be improved further by removing more redundant tokens. We will use this textual form in the rest of this RFC.
```mlir
%0 = vector.mask %mask, %passthru { %1 = arith.addf %a, %b : vector<2x8xf32> } : vector<2x8xi1>, vector<2x8xf32>
```
Example
Let’s take a look at how vector.mask can be used to vectorize the following dynamically-shaped linalg.matmul and how the mask can be propagated through progressive lowering and levels of abstraction.
```mlir
func.func @masked_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                         %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```
We assume that linalg.matmul is vectorized to a vector.contract operation with vector sizes {2, 4, 8} for dimensions i, j and k, respectively. We also assume that the dynamic sizes of dimensions i, j and k are bounded by 2, 4 and 8, respectively. This can be achieved by tiling (not shown in this example for the sake of simplicity). The resulting masked vector code is as follows:
```mlir
func.func @masked_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %3 = vector.create_mask %0, %2 : vector<2x8xi1>
  %4 = vector.mask %3 { %12 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x8xf32> } : vector<2x8xi1>, vector<2x8xf32>
  %5 = vector.create_mask %2, %1 : vector<8x4xi1>
  %6 = vector.mask %5 { %12 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x4xf32> } : vector<8x4xi1>, vector<8x4xf32>
  %7 = vector.create_mask %0, %1 : vector<2x4xi1>
  %8 = vector.mask %7 { %12 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32> } : vector<2x4xi1>, vector<2x4xf32>
  %9 = vector.create_mask %0, %1, %2 : vector<2x4x8xi1>
  %10 = vector.mask %9 { %12 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %6, %8 : vector<2x8xf32>, vector<8x4xf32> into vector<2x4xf32> } : vector<2x4x8xi1>, vector<2x4xf32>
  %11 = vector.mask %7 { %12 = vector.transfer_write %10, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32> } : vector<2x4xi1>, tensor<?x?xf32>
  return %11 : tensor<?x?xf32>
}
```
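The masks in the listing above are built with vector.create_mask, which sets a lane only if every index is below the corresponding dynamic size, and each masked vector.transfer_read fills masked-off lanes with the padding value %cst. The Python sketch below models the 2-D case of %3 and %4. The helper names and the 2x5 tensor (M = 2, K = 5) are hypothetical, and the read is restricted to indices [0, 0] with an identity permutation map.

```python
from typing import List, Sequence

Matrix = List[List[float]]
Mask = List[List[bool]]


def create_mask(dim0: int, dim1: int, shape: Sequence[int]) -> Mask:
    """2-D model of vector.create_mask: lane (i, j) is set iff i < dim0 and j < dim1."""
    rows, cols = shape
    return [[i < dim0 and j < dim1 for j in range(cols)] for i in range(rows)]


def masked_transfer_read(source: Matrix, mask: Mask, padding: float) -> Matrix:
    """Masked 2-D vector.transfer_read at indices [0, 0] with an identity permutation map."""
    # Only enabled lanes touch `source`; masked-off lanes receive the padding value,
    # so lanes beyond the dynamic tensor sizes are never accessed.
    return [[source[i][j] if enabled else padding
             for j, enabled in enumerate(mask_row)]
            for i, mask_row in enumerate(mask)]


# Hypothetical dynamic sizes: %arg0 is a 2x5 tensor (M = 2, K = 5).
arg0 = [[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]
mask = create_mask(2, 5, (2, 8))  # %3 = vector.create_mask %0, %2 : vector<2x8xi1>
vec = masked_transfer_read(arg0, mask, 0.0)
print(vec[0])  # [1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0, 0.0]
```

Because the mask rules out lanes 5 to 7, the model never indexes past the end of each row, which is what allows the listing to mark the reads as in_bounds even though the tensor is smaller than the vector.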
Note that the vector.contract operation is masked with a canonical mask that spans all the vector dimensions of the iteration space. The masked vector.contract is then lowered to masked lower-level operations, where the canonical mask is decomposed into multiple masks depending on how the canonical vector iteration space is projected/permuted for each lower-level operation:
```mlir
func.func @masked_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %3 = vector.create_mask %0, %2 : vector<2x8xi1>
  %4 = vector.mask %3 { %13 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (d0, 0, d1)>} : tensor<?x?xf32>, vector<2x4x8xf32> } : vector<2x8xi1>, vector<2x4x8xf32>
  %5 = vector.create_mask %2, %1 : vector<8x4xi1>
  %6 = vector.mask %5 { %13 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>} : tensor<?x?xf32>, vector<2x4x8xf32> } : vector<8x4xi1>, vector<2x4x8xf32>
  %7 = vector.create_mask %0, %1 : vector<2x4xi1>
  %8 = vector.mask %7 { %13 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32> } : vector<2x4xi1>, vector<2x4xf32>
  %9 = vector.create_mask %0, %1, %2 : vector<2x4x8xi1>
  %10 = vector.mask %9 { %13 = arith.mulf %4, %6 : vector<2x4x8xf32> } : vector<2x4x8xi1>, vector<2x4x8xf32>
  %11 = vector.mask %9 { %13 = vector.multi_reduction <add>, %10, %8 [2] : vector<2x4x8xf32> to vector<2x4xf32> } : vector<2x4x8xi1>, vector<2x4xf32>
  %12 = vector.mask %7 { %13 = vector.transfer_write %11, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32> } : vector<2x4xi1>, tensor<?x?xf32>
  return %12 : tensor<?x?xf32>
}
```
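The following Python model replays the lowered listing lane by lane. It derives the masks %3, %5 and %7 by projecting the bounds of the canonical mask %9 through each operand's indexing map, applies the masked mulf and multi_reduction, and checks the result against a plain matrix multiplication. It is a sketch under stated assumptions: the helper names are hypothetical, dicts keyed by index tuples stand in for n-D vectors, masks are create_mask-style (box-shaped), so projecting a mask reduces to selecting the bounds of the dimensions an operand uses, and masked-off lanes of the reduction are assumed to simply not contribute to the sum.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple

Index = Tuple[int, ...]
Matrix = List[List[float]]


def create_mask(bounds: Sequence[int], shape: Sequence[int]) -> Dict[Index, bool]:
    """n-D model of vector.create_mask: lane idx is set iff idx[d] < bounds[d] for all d."""
    return {idx: all(i < b for i, b in zip(idx, bounds))
            for idx in product(*(range(s) for s in shape))}


def project_bounds(canonical: Sequence[int], dims: Sequence[int]) -> List[int]:
    """Project/permute canonical (i, j, k) bounds through an operand's indexing-map dims."""
    return [canonical[d] for d in dims]


def lowered_masked_matmul(A: Matrix, B: Matrix, C: Matrix, M: int, N: int, K: int,
                          vec: Tuple[int, int, int] = (2, 4, 8), pad: float = 0.0) -> Matrix:
    vi, vj, vk = vec
    canonical = (M, N, K)                                               # bounds of %9
    mask_a = create_mask(project_bounds(canonical, (0, 2)), (vi, vk))  # %3, map (d0, d2)
    mask_b = create_mask(project_bounds(canonical, (2, 1)), (vk, vj))  # %5, map (d2, d1)
    mask_c = create_mask(project_bounds(canonical, (0, 1)), (vi, vj))  # %7, map (d0, d1)
    mask_ijk = create_mask(canonical, vec)                              # %9
    # %4 and %6: masked reads broadcast into the 3-D (i, j, k) space.
    a3 = {(i, j, k): A[i][k] if mask_a[(i, k)] else pad for (i, j, k) in mask_ijk}
    b3 = {(i, j, k): B[k][j] if mask_b[(k, j)] else pad for (i, j, k) in mask_ijk}
    # %8: masked read of the accumulator.
    acc = {(i, j): C[i][j] if mask_c[(i, j)] else pad for (i, j) in mask_c}
    # %10: masked arith.mulf; masked-off lanes have no pass-through (None here).
    prod = {idx: a3[idx] * b3[idx] if mask_ijk[idx] else None for idx in mask_ijk}
    # %11: masked multi_reduction <add> over dim 2; masked-off lanes do not contribute.
    red = {(i, j): acc[(i, j)] + sum(prod[(i, j, k)] for k in range(vk) if mask_ijk[(i, j, k)])
           for (i, j) in acc}
    # %12: masked transfer_write; only lanes enabled in %7 update the tensor.
    out = [row[:] for row in C]
    for (i, j), enabled in mask_c.items():
        if enabled:
            out[i][j] = red[(i, j)]
    return out


def reference_matmul(A: Matrix, B: Matrix, C: Matrix) -> Matrix:
    """Plain C + A @ B on the dynamically-sized tensors."""
    return [[C[i][j] + sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(C[0]))] for i in range(len(C))]


M, N, K = 2, 3, 5  # hypothetical dynamic sizes, bounded by the vector sizes {2, 4, 8}
A = [[float(i * K + k + 1) for k in range(K)] for i in range(M)]
B = [[float(k * N + j + 1) for j in range(N)] for k in range(K)]
C = [[1.0] * N for _ in range(M)]
assert lowered_masked_matmul(A, B, C, M, N, K) == reference_matmul(A, B, C)
```

Note how the three operand masks fall out of the same canonical bounds: the model never builds %3, %5 or %7 independently, which mirrors the decomposition described above.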
The resulting code could then be lowered further and eventually be converted to LLVM using the Vector Predication intrinsics.
This example demonstrates how masking throughout the compiler pipeline is independent of the actual operations being masked, which separates concerns and prevents an explosion of new vector/masked operations throughout MLIR. Even existing operations with native masking or similar semantics, such as vector.transfer_read, vector.transfer_write and vector.warp_execute_on_lane_0, could be simplified in the future by deferring masking semantics to the new vector.mask operation.
The MaskableOp and MaskingOp Interfaces
We propose two new interfaces. The MaskableOp interface characterizes operations that can be masked. It will help vectorization algorithms rule out operations for which masking wouldn’t make sense, such as scf.if, scf.for, affine.if, affine.for, func.func, etc. The interface will also provide methods to determine whether an operation is masked and how a canonical mask should be projected/permuted to mask that operation. Further details will be decided at implementation time.
The MaskingOp interface defines operations that can mask other operations (e.g., vector.mask) or operations with native support for masking (e.g., vector.transfer_read, vector.transfer_write). This interface will help with the incremental transition of masking responsibilities from operations with native masking support to the vector.mask operation. It will also help reduce dependencies between dialects with ops implementing the MaskableOp interface and dialects with ops implementing the MaskingOp interface (e.g., the Vector dialect).
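Since the interfaces have not been designed yet, the following is only a Python sketch of the responsibilities described above, not MLIR's C++ op-interface API. All class and method names are hypothetical, and a mask is represented by its create_mask-style bounds for brevity. A maskable op can report whether it is masked and project the canonical mask for its operands; a masking op exposes the mask it applies; an op that does not implement MaskableOp (here, an scf.for stand-in) is rejected by the vectorizer's check.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence, Tuple


class MaskingOp(ABC):
    """Hypothetical analogue of MaskingOp: an op that applies a mask to another op."""

    @abstractmethod
    def mask_bounds(self) -> List[int]:
        """Bounds of the (create_mask-style) mask this op applies."""


class MaskableOp(ABC):
    """Hypothetical analogue of MaskableOp: an op that vector.mask may wrap."""

    def __init__(self) -> None:
        self.masking_op: Optional[MaskingOp] = None

    def is_masked(self) -> bool:
        return self.masking_op is not None

    @abstractmethod
    def operand_mask_bounds(self, canonical: Sequence[int]) -> List[List[int]]:
        """Project/permute the canonical mask bounds for each operand."""


class Contract(MaskableOp):
    """vector.contract-like op, described only by each operand's indexing-map dims."""

    def __init__(self, operand_dims: Sequence[Tuple[int, ...]]) -> None:
        super().__init__()
        self.operand_dims = list(operand_dims)

    def operand_mask_bounds(self, canonical: Sequence[int]) -> List[List[int]]:
        return [[canonical[d] for d in dims] for dims in self.operand_dims]


class Mask(MaskingOp):
    """vector.mask-like op wrapping exactly one maskable op."""

    def __init__(self, bounds: Sequence[int], nested: MaskableOp) -> None:
        self.bounds = list(bounds)
        self.nested = nested
        nested.masking_op = self  # lets the nested op find its masking op

    def mask_bounds(self) -> List[int]:
        return self.bounds


class ForLoop:
    """scf.for-like op: intentionally does not implement MaskableOp."""


def can_be_masked(op: object) -> bool:
    """What a vectorizer could check before wrapping an op in a Mask."""
    return isinstance(op, MaskableOp)


contract = Contract([(0, 2), (2, 1), (0, 1)])  # maps of A, B and C in the matmul example
masked = Mask([2, 3, 5], contract)             # canonical (i, j, k) bounds
print(contract.is_masked())                                # True
print(contract.operand_mask_bounds(masked.mask_bounds()))  # [[2, 5], [5, 3], [2, 3]]
print(can_be_masked(ForLoop()))                            # False
```

Keeping the projection logic on the maskable op, rather than on the masking op, is what lets a generic vector.mask stay ignorant of the semantics of the operation it wraps.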
Vector Predication Operations
In the long term, we plan to introduce a new set of low-level operations with a native mask (and vector length) argument. These operations will serve as a low-level landing pad for higher-level operations masked with vector.mask and will align MLIR with the Vector Predication intrinsics in LLVM. We don’t plan to reinvent the wheel here beyond adapting the Vector Predication instructions to the vector ecosystem in MLIR. We don’t plan to work on this in the short term either, so contributions are welcome!
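As a rough illustration of what such a landing pad would look like, here is a Python sketch modeled on the semantics of LLVM's Vector Predication intrinsics (e.g., llvm.vp.fadd): a lane is enabled only if its mask bit is set and its index is below the explicit vector length, and the result on disabled lanes is poison. The function names are hypothetical, None stands in for poison, and this sketch simply rejects out-of-range vector lengths. The second function shows how a vector.mask-wrapped arith.addf with a pass-through could land on such an operation (full vector length) followed by a blend; in LLVM, the blend could be a select or the llvm.vp.merge intrinsic.

```python
from typing import List, Optional, Sequence


def vp_fadd(a: Sequence[float], b: Sequence[float], mask: Sequence[bool],
            evl: int) -> List[Optional[float]]:
    """VP-style add: lane i is enabled iff mask[i] is set and i < evl.

    Disabled lanes yield None, standing in for LLVM's poison value.
    """
    if not 0 <= evl <= len(a):
        raise ValueError("explicit vector length out of range")  # rejected in this sketch
    return [x + y if (m and i < evl) else None
            for i, (x, y, m) in enumerate(zip(a, b, mask))]


def lower_masked_addf(a: Sequence[float], b: Sequence[float], mask: Sequence[bool],
                      passthru: Sequence[float]) -> List[float]:
    """Land a vector.mask-wrapped arith.addf on the VP-style op, then blend in passthru."""
    partial = vp_fadd(a, b, mask, evl=len(a))
    return [p if m else pt for p, m, pt in zip(partial, mask, passthru)]


print(vp_fadd([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [True, True, False, True], evl=3))
# [2.0, 3.0, None, None]
print(lower_masked_addf([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [True, False, True, False], [0.0] * 4))
# [2.0, 0.0, 4.0, 0.0]
```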
Op Naming and Dialect Logistics
Currently, the vector.mask operation provides an abstraction for masking. However, this operation should be extended in the future to also model dynamic vector length information. In that context, the name mask might not be the most appropriate, and we should consider predicate or a more generic term.
The vector.mask operation has been prototyped in the Vector dialect. However, we also plan to introduce Vector Predication operations, which are quite numerous. An alternative location for all the masking/predication related operations could be a new Predication/Masking dialect. That would keep vectorization and predication/masking concerns separate. The Predication/Masking dialect would be an optional extension of the Vector dialect, which won’t be loaded by targets without predication/masking support.
Suggestions are welcome!
Future Work
- Extend the proposal to support scalable/dynamic vector length vectorization.
- Implement full-masking lowering.
- Extend masking to tensor land.
@nicolasvasilache, @zhanghb97, @javiersetoain, @giuseros, @MaheshRavishankar, @ThomasRaoux