I’m trying to figure out the best way to generate masks for scalable vectors. Right now, something like:

```mlir
%0 = vector.create_mask %n : vector<4xi1>
```

is lowered to:

```mlir
%cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
%0 = arith.index_cast %n : index to i32
%1 = splat %0 : vector<4xi32>
%2 = arith.cmpi slt, %cst, %1 : vector<4xi32>
```
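The comparison-based lowering above can be modeled in Python (a sketch for illustration; `arith.cmpi slt` is a signed comparison of each step-vector element against the splatted bound):

```python
def create_mask_fixed(n: int, vl: int = 4) -> list[int]:
    """Model of the fixed-length create_mask lowering:
    dense<[0, 1, 2, 3]> compared signed-less-than against splat(n)."""
    step = list(range(vl))           # the constant step vector [0, 1, 2, 3]
    return [int(i < n) for i in step]  # per-lane signed compare, as in arith.cmpi slt

print(create_mask_fixed(2))  # [1, 1, 0, 0]
print(create_mask_fixed(4))  # [1, 1, 1, 1]
```

Note how the step-vector constant is what hard-codes the vector length here.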
This is very much vector-length-dependent. I’ve created a patch ([mlir][Vector] Enable create_mask for scalable vectors) that adds an alternative lowering, based on the llvm.get.active.lane.mask.* intrinsics, for when the mask needs to be scalable:
```mlir
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = arith.index_cast %n : index to i32
%2 = llvm.intr.get.active.lane.mask %0, %1 : i32, i32 to vector<[4]xi1>
```
I’m finding a couple of issues with this approach, and I’d like some feedback and/or ideas.
First, this creates a dependency between the Vector and LLVMIR dialects. All the other lowerings that depend on LLVMIR live in the VectorToLLVM conversions, which makes sense. I see two solutions to this problem:
- Move VectorCreateMaskOpConversion alongside all the other lowerings that target LLVM specifically. I don’t quite like that because, as it stands, this conversion doesn’t need to depend on LLVMIR for fixed-length vectors. That said, I’m not sure whether there is any current or planned consumer of that behaviour.
- Create a different lowering for CreateMaskOp in VectorToLLVM to address scalable vectors specifically.
The second issue is a bit more problematic: the current lowering and the scalable-compatible one do not have exactly the same semantics. If the index operand is negative, get.active.lane.mask wraps around (it compares unsigned), while the current lowering effectively clamps to 0 (it compares signed). So `vector.create_mask -1` will lower to an all-0s mask for fixed-length vectors and an all-1s mask for scalable vectors. I'm not sure this is a problem in practice, but the discrepancy is definitely there.
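To make the discrepancy concrete, here is a small Python model of both lowerings (an illustration only, not the actual conversion code). llvm.get.active.lane.mask sets lane `i` when `base + i < n` using an *unsigned* comparison, so a negative index reinterpreted as an unsigned i32 enables every lane, while the signed `cmpi slt` lowering enables none:

```python
VL = 4  # lanes modeled; for scalable vectors this would be vscale-dependent

def mask_fixed(n: int) -> list[int]:
    # Fixed-length lowering: signed compare of [0..VL) against splat(n).
    return [int(i < n) for i in range(VL)]

def mask_active_lane(n: int, bits: int = 32) -> list[int]:
    # get.active.lane.mask model: base + i <u n, with base = 0.
    n_unsigned = n % (1 << bits)  # reinterpret the i32 index as unsigned
    return [int(i < n_unsigned) for i in range(VL)]

print(mask_fixed(-1))        # [0, 0, 0, 0] -- all lanes off
print(mask_active_lane(-1))  # [1, 1, 1, 1] -- all lanes on
print(mask_fixed(2), mask_active_lane(2))  # agree for non-negative n
```

For any non-negative in-range index the two models agree; the divergence is confined to negative (or otherwise out-of-range) index values.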
One way I might be able to avoid the inconsistency would be to use llvm.experimental.stepvector and follow an approach similar to the fixed-length one, e.g.:

```mlir
%cst = llvm.intr.experimental.stepvector : vector<[4]xi32>
%0 = arith.index_cast %n : index to i32
%1 = splat %0 : vector<[4]xi32>
%2 = arith.cmpi slt, %cst, %1 : vector<[4]xi32>
```
But stepvector is an experimental intrinsic right now, and it’s not even in the LLVM dialect yet, so I’d rather avoid it.
Any ideas or suggestions are greatly appreciated.