I took a quick look at this, and I don’t think it’s too bad (provided you can get rid of some unit dims).
Here’s the ‘canonical’ form of a masked outerproduct (via multi-reduction):
Canonical form
%lhsCast = vector.shape_cast %inputLHS : vector<[4]xf32> to vector<[4]x1xf32>
%lhsBcast = vector.broadcast %lhsCast : vector<[4]x1xf32> to vector<[4]x[4]x1xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0, 2] : vector<[4]x[4]x1xf32> to vector<[4]x[4]x1xf32>
%rhsCast = vector.shape_cast %inputRHS : vector<[4]xf32> to vector<1x[4]xf32>
%rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>
%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32>
%mul = arith.mulf %lhsT, %rhs : vector<[4]x[4]x1xf32>
%tileMask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
%dropDim = vector.shape_cast %mul : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
%addAcc = arith.addf %acc, %dropDim : vector<[4]x[4]xf32>
%applyMask = arith.select %tileMask, %addAcc, %acc : vector<[4]x[4]xi1>, vector<[4]x[4]xf32>
I’m not sure why it shape-casts to add the unit dims (does the broadcast need that?). It also results in a semi-pointless transpose for the RHS, which only moves the unit dim.
If we got rid of all the unit dims, I don’t think this is too bad:
No unit dims
%lhsBcast = vector.broadcast %lhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32>
%tileMask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
%addAcc = arith.addf %acc, %mul : vector<[4]x[4]xf32>
%applyMask = arith.select %tileMask, %addAcc, %acc : vector<[4]x[4]xi1>, vector<[4]x[4]xf32>
Then this can be lowered with two fairly easy patterns:
Step 1
%lhsBcast = vector.broadcast %lhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32>
This can be rewritten as:
%mul = arm_sme.outerproduct %lhs, %rhs : vector<[4]xf32>, vector<[4]xf32>
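As a sanity check on Step 1, here is a minimal plain-Python model (not MLIR) of the matched ops. It assumes `vector.broadcast` from 1-D to 2-D repeats the source as every row, and that a fixed-length list can stand in for a scalable `vector<[4]xf32>`. The helper names are hypothetical and exist only for this sketch.

```python
# Plain-Python model of the Step 1 pattern. Lists stand in for vectors;
# the list length stands in for the runtime length of vector<[4]xf32>.

def broadcast(vec, n):
    # Assumed vector.broadcast 1-D -> 2-D semantics: every row is a copy of vec.
    return [list(vec) for _ in range(n)]

def transpose(mat):
    # vector.transpose with permutation [1, 0].
    return [list(col) for col in zip(*mat)]

def mulf(a, b):
    # Elementwise arith.mulf on 2-D values.
    return [[x * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]

def outerproduct(lhs, rhs):
    # Reference semantics: result[i][j] = lhs[i] * rhs[j].
    return [[l * r for r in rhs] for l in lhs]

def step1_pattern(lhs, rhs):
    # broadcast + transpose on the LHS, broadcast on the RHS, then multiply.
    n = len(lhs)
    lhs_t = transpose(broadcast(lhs, n))
    rhs_bcast = broadcast(rhs, n)
    return mulf(lhs_t, rhs_bcast)
```

Under those assumptions, `step1_pattern(lhs, rhs)` and `outerproduct(lhs, rhs)` agree element for element, which is what makes the rewrite a straightforward match.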
Step 2
%mul = arm_sme.outerproduct %lhs, %rhs : vector<[4]xf32>, vector<[4]xf32>
%addAcc = arith.addf %acc, %mul : vector<[4]x[4]xf32>
%applyMask = arith.select %tileMask, %addAcc, %acc : vector<[4]x[4]xi1>, vector<[4]x[4]xf32>
This can be rewritten as:
%lhsMask = vector.create_mask %lhsDim : vector<[4]xi1>
%rhsMask = vector.create_mask %rhsDim : vector<[4]xi1>
%mul = arm_sme.outerproduct %lhs, %rhs acc(%acc) masks(%lhsMask, %rhsMask) : vector<[4]xf32>, vector<[4]xf32>
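The Step 2 rewrite relies on the 2-D tile mask factoring into two 1-D masks. Below is a rough plain-Python model (not MLIR) of that equivalence. It assumes active (true) mask lanes take the accumulated value while inactive lanes keep the accumulator, and that the masked outer product updates element (i, j) only when both `lhsMask[i]` and `rhsMask[j]` are set. The function names are hypothetical.

```python
# Plain-Python model of Step 2: outerproduct + addf + select versus a
# single outer product that takes an accumulator and two 1-D masks.

def create_mask_1d(dim, n):
    # vector.create_mask for 1-D: lanes [0, dim) are active.
    return [i < dim for i in range(n)]

def create_mask_2d(dim0, dim1, n):
    # vector.create_mask for 2-D: (i, j) is active when i < dim0 and j < dim1.
    return [[i < dim0 and j < dim1 for j in range(n)] for i in range(n)]

def unfused(lhs, rhs, acc, lhs_dim, rhs_dim):
    # Outer product, add to the accumulator, then select with the tile mask.
    n = len(lhs)
    tile_mask = create_mask_2d(lhs_dim, rhs_dim, n)
    mul = [[l * r for r in rhs] for l in lhs]
    add_acc = [[a + m for a, m in zip(row_a, row_m)] for row_a, row_m in zip(acc, mul)]
    # Assumption: active lanes take add_acc, inactive lanes keep acc.
    return [[s if m else a for m, s, a in zip(row_mask, row_sum, row_acc)]
            for row_mask, row_sum, row_acc in zip(tile_mask, add_acc, acc)]

def fused(lhs, rhs, acc, lhs_dim, rhs_dim):
    # Assumed semantics of the masked, accumulating outer product.
    n = len(lhs)
    lhs_mask = create_mask_1d(lhs_dim, n)
    rhs_mask = create_mask_1d(rhs_dim, n)
    return [[acc[i][j] + lhs[i] * rhs[j] if lhs_mask[i] and rhs_mask[j] else acc[i][j]
             for j in range(n)] for i in range(n)]
```

The two functions agree for any `lhs_dim` and `rhs_dim` because `i < dim0 and j < dim1` is exactly the outer "and" of the two 1-D masks, so no 2-D mask needs to be materialised after the rewrite.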