It looks like you are decomposing the vector.transfer_read into simpler operations. You should run these patterns (or the pass on top of them) to simplify the vector.mask before decomposing vector.transfer_* ops.
If canonicalize is run between scalable vectorize and reductionToContract, the masked vector.multi_reduction <add> on 1D is converted into an arith.add, and FMOPA instructions are not generated. The issue with this is potential missed optimizations since, for instance, in the matmul use case, we have to run bufferize before lowering to outer products.
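For reference, the canonicalization in question (ElideUnitDimsInMultiDimReduction, mentioned further down the thread) rewrites a reduction over a unit dimension into an elementwise add. A rough sketch of the before/after, with made-up operand names and mask handling elided:

```mlir
// Before: masked reduction over the trailing unit dim.
%r = vector.mask %m {
  vector.multi_reduction <add>, %v, %acc [2]
    : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
} : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>

// After: the unit dim is dropped and the reduction degenerates into a
// plain elementwise add, which the FMOPA lowering can no longer match.
%cast = vector.shape_cast %v : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
%sum = arith.addf %cast, %acc : vector<[4]x[4]xf32>
```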
Hey Hugo,
Great to see this used, and there are lots of good points here, so thanks for creating the thread.
I revisited where we bufferize in matmul.mlir and bufferizing after lowering to outer products (but before -test-lower-to-arm-sme) does now work FWIW. As the test mentions, when this was originally added bufferization was necessary at that point for TransferReadDropUnitDimsPattern to kick in, but that no longer seems to be the case.
That said, I don’t quite see the connection between the multi_reduction canonicalization and the missed optimizations resulting from bufferization?
A bit unrelated but whilst looking into this I’ve just discovered the ArmSME pipeline still lowers after that canonicalization which is surprising given we have no support to lower those Arith ops. It appears they become dead after -convert-arm-sme-to-llvm and get incorrectly DCE’d. We’ll look into fixing that!
I don’t think we have any examples upstream… A simple example in this scenario, reusing your example from above, would be to go from
func.func @matmul_transpose_a(
%arg0: tensor<?x?xf32>,
%arg1: tensor<?x?xf32>) {
// unsimplified code
}
to something like
func.func @matmul_transpose_a(
%arg0: tensor<?x?xf32>,
%arg1: tensor<?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim_1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%rem = affine.apply affine_map<(d0) -> (d0 mod 8)>(%dim_1)
%is_divisible_by_8 = arith.cmpi eq, %rem, %c0 : index
scf.if %is_divisible_by_8 {
// here we can now assume %dim_1 is divisible by 8
// and trigger simplifications described in my post above
} else {
// unsimplified code for the general case
}
automatically. We will compile 2+ versions of the code and trade-off binary size and some additional dispatch for better-optimized versions in the happy case.
Hi, I got started on this and am looking into options to drop unit dimensions. There are a few patterns with similar ideas as part of populateFlattenVectorTransferPatterns, but most of them are limited to trailing/leading dimensions. Is there a reason for this limitation?
Can I extend this to broadcast, on the grounds that these are equivalent?
%lhsBcast = vector.broadcast %lhsCast : vector<[4]x1xf32> to vector<[4]x[4]x1xf32>
%rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>
//////// Equivalent to
%lhssc = vector.shape_cast %lhsCast : vector<[4]x1xf32> to vector<[4]xf32>
%lhsBcast = vector.broadcast %lhssc : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsOut = vector.shape_cast %lhsBcast : vector<[4]x[4]xf32> to vector<[4]x[4]x1xf32>
%rhssc = vector.shape_cast %rhsCast : vector<1x[4]xf32> to vector<[4]xf32>
%rhsBcast = vector.broadcast %rhssc : vector<[4]xf32> to vector<[4]x[4]xf32>
%rhsOut = vector.shape_cast %rhsBcast : vector<[4]x[4]xf32> to vector<[4]x1x[4]xf32>
I am concerned about the rhs case, when the input leading unit dimension ends in the middle of the result.
ConvertIllegalShapeCastOpsToTransposes shows limitations of shapeCasts in scalable vectors. Do you see illegality in doing such a transform for transpose too ?
%lhsT = vector.transpose %lhsBcast, [1, 0, 2] : vector<[4]x[4]x1xf32> to vector<[4]x[4]x1xf32>
%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32>
//////// Equivalent to
%lhssc = vector.shape_cast %lhsBcast : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhssc, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%lhsOut = vector.shape_cast %lhsT : vector<[4]x[4]xf32> to vector<[4]x[4]x1xf32>
%rhssc = vector.shape_cast %rhsBcast : vector<[4]x1x[4]xf32> to vector<[4]x[4]xf32>
// The transpose is now the identity and can be folded away.
%rhs = vector.transpose %rhssc, [0, 1] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsOut = vector.shape_cast %rhs : vector<[4]x[4]xf32> to vector<[4]x[4]x1xf32>
Solution proposal:
If you don’t see any legality issue, I’d like to propose the following:
- For standard ops, have a templated pattern, template <typename ConcreteOp> struct DropUnitDimFromOp : public OpRewritePattern<ConcreteOp>, quite similar to DropUnitDimFromElementwiseOps, which will drop trailing/leading (or all?) unit dims on all operands and the result, generating shape_casts accordingly.
- For transpose ops, add a DropUnitDimFromTransposeOp rewrite pattern which will additionally take care of dropping the dimension from the transposition map.
- Add DropUnitDimFromOp<BroadcastOp> and DropUnitDimFromTransposeOp to populateDropUnitDimWithShapeCastPatterns.
All these shape casts would be ‘illegal’ if they still existed in the final IR, but the goal should be to get them to cancel out (same with the extra transpose). The only operations remaining should be those that can be converted to a vector or ArmSME outer product operation. I think that could work with your proposed solution (another similar idea would be to bubble up the shape_casts until they cancel out, which I think would happen if you started at the shape_cast before the arith.addf).
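To illustrate the cancelling-out: a shape_cast immediately followed by its inverse folds back to the original value via the shape_cast canonicalizer, so pairs like this (hypothetical %x) should leave no illegal scalable shape_casts behind:

```mlir
// A shape_cast and its inverse cancel out; %b folds to %x.
%a = vector.shape_cast %x : vector<[4]x[4]xf32> to vector<[4]x[4]x1xf32>
%b = vector.shape_cast %a : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
```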
Possibly just nobody needed it before? Your example with the RHS seems correct to me.
Note: I think you should try to rewrite this to a vector.outerproduct (which can then be lowered via the ArmSME pipeline), rather than going directly to arm_sme.outerproduct, which is more restrictive with respect to vector sizes.
Also, I wonder if it would be simpler to update ElideUnitDimsInMultiDimReduction (which is what produces this lowering) to simply emit a vector.outerproduct in this case, rather than lowering it to arith, which would seem more canonical to me.
Edit: Ah the multi_reduction is just:
vector.mask %16 { vector.multi_reduction <add>, %15, %14 [2] : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>
The broadcasts/transposes are done via the transfer_reads (from transfer_permutation_patterns).
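If I follow, the vector.outerproduct form of that reduction would look roughly like this (operand names are assumed, not taken from the actual IR):

```mlir
// Sketch: the same masked accumulation expressed as an outer product,
// which the ArmSME pipeline knows how to lower.
%res = vector.mask %mask {
  vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
    : vector<[4]xf32>, vector<[4]xf32>
} : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```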