On Improving Arm SME Lowering Resilience in MLIR

Hi, I have started on this and am looking into options for dropping unit dimensions. A few patterns with similar ideas already exist as part of populateFlattenVectorTransferPatterns, but most of them are limited to trailing / leading dimensions. Is there a reason for that limitation?

Can I extend this to vector.broadcast, on the basis that the following are equivalent?

%lhsBcast = vector.broadcast %lhsCast : vector<[4]x1xf32> to vector<[4]x[4]x1xf32>
%rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>

//////// Equivalent to

%lhssc = vector.shape_cast %lhsCast : vector<[4]x1xf32> to vector<[4]xf32>
%lhsBcast = vector.broadcast %lhssc : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsOut = vector.shape_cast %lhsBcast : vector<[4]x[4]xf32> to vector<[4]x[4]x1xf32>

%rhssc = vector.shape_cast %rhsCast : vector<1x[4]xf32> to vector<[4]xf32>
%rhsBcast = vector.broadcast %rhssc : vector<[4]xf32> to vector<[4]x[4]xf32>
%rhsOut = vector.shape_cast %rhsBcast : vector<[4]x[4]xf32> to vector<[4]x1x[4]xf32>

I am concerned about the rhs case, where the input's leading unit dimension ends up in the middle of the result.
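
One way to reason about the rhs case is to look only at the shapes. The sketch below is a standalone C++ model, not MLIR code: Dim, Shape, isFixedUnit and dropUnitDimsFromBroadcast are hypothetical names introduced for illustration. It assumes vector.broadcast aligns the source with the trailing dims of the result, and it only drops a dim when the result dim is a fixed (non-scalable) unit dim, in which case the aligned source dim has to be 1 as well. Source unit dims that get stretched are left alone, since removing them could change which result dim the source lines up with.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical model of one vector dimension: [4] is {4, true}, 1 is {1, false}.
struct Dim {
  int64_t size;
  bool scalable;
};
using Shape = std::vector<Dim>;

// A unit dim is only considered droppable when it is not scalable.
static bool isFixedUnit(const Dim &d) { return d.size == 1 && !d.scalable; }

// Returns the {source, result} shapes of the reduced broadcast.
// Assumption: the source is aligned with the trailing dims of the result.
std::pair<Shape, Shape> dropUnitDimsFromBroadcast(const Shape &src,
                                                  const Shape &res) {
  if (src.size() > res.size())
    return {src, res}; // Not a valid broadcast in this model; leave untouched.

  const std::size_t offset = res.size() - src.size();
  Shape newSrc, newRes;
  for (std::size_t i = 0; i < res.size(); ++i) {
    // A fixed unit dim in the result implies the aligned source dim (if any)
    // is also 1, so both can be dropped together.
    if (isFixedUnit(res[i]))
      continue;
    newRes.push_back(res[i]);
    if (i >= offset)
      newSrc.push_back(src[i - offset]);
  }
  return {newSrc, newRes};
}
```

With the shapes from the example above, both the lhs ([4]x1 to [4]x[4]x1) and the rhs (1x[4] to [4]x1x[4]) reduce to a [4] to [4]x[4] broadcast in this model, regardless of where the unit dim sits in the result.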

ConvertIllegalShapeCastOpsToTransposes shows the limitations of shape_casts on scalable vectors. Do you see anything illegal about applying the same transform to transposes as well?


%lhsT = vector.transpose %lhsBcast, [1, 0, 2] : vector<[4]x[4]x1xf32> to vector<[4]x[4]x1xf32>

%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32>

//////// Equivalent to

%lhssc = vector.shape_cast %lhsBcast : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhssc, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%lhsOut = vector.shape_cast %lhsT : vector<[4]x[4]xf32> to vector<[4]x[4]x1xf32>

%rhssc = vector.shape_cast %rhsBcast : vector<[4]x1x[4]xf32> to vector<[4]x[4]xf32>
// The transpose is now an identity permutation and can be folded away.
%rhs = vector.transpose %rhssc, [0, 1] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsOut = vector.shape_cast %rhs : vector<[4]x[4]xf32> to vector<[4]x[4]x1xf32>
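
To make the permutation rewrite concrete, here is a minimal standalone C++ sketch of how the transposition map could be remapped once fixed unit dims are dropped from the source. It is a model of the index arithmetic only, not MLIR code: Dim, Shape, isFixedUnit, dropUnitDimsFromPermutation and isIdentity are hypothetical names, and it assumes the vector.transpose convention that result dim i takes source dim perm[i].

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of one vector dimension: [4] is {4, true}, 1 is {1, false}.
struct Dim {
  int64_t size;
  bool scalable;
};
using Shape = std::vector<Dim>;

static bool isFixedUnit(const Dim &d) { return d.size == 1 && !d.scalable; }

// Returns the permutation to use after fixed unit dims are removed from `src`.
// Assumption: result dim i of the transpose is source dim perm[i].
std::vector<int64_t>
dropUnitDimsFromPermutation(const Shape &src,
                            const std::vector<int64_t> &perm) {
  // newIndex[d] is the position of source dim d after dropping unit dims,
  // or -1 when d itself is dropped.
  std::vector<int64_t> newIndex(src.size(), -1);
  int64_t next = 0;
  for (std::size_t d = 0; d < src.size(); ++d)
    if (!isFixedUnit(src[d]))
      newIndex[d] = next++;

  // Keep the surviving entries in their original order, renumbered.
  std::vector<int64_t> newPerm;
  for (int64_t d : perm)
    if (newIndex[d] != -1)
      newPerm.push_back(newIndex[d]);
  return newPerm;
}

// An identity permutation means the reduced transpose can be folded away.
bool isIdentity(const std::vector<int64_t> &perm) {
  for (std::size_t i = 0; i < perm.size(); ++i)
    if (perm[i] != static_cast<int64_t>(i))
      return false;
  return true;
}
```

For the two transposes above, this model maps [1, 0, 2] on [4]x[4]x1 to [1, 0], and [0, 2, 1] on [4]x1x[4] to [0, 1], which is the identity case.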

Solution proposal:

If you don’t see any legality issue, I’d like to propose the following:

  • For standard ops, add a templated pattern, template <typename ConcreteOp> struct DropUnitDimFromOp : public OpRewritePattern<ConcreteOp>, which, much like DropUnitDimFromElementwiseOps, drops trailing / leading (or all?) unit dims on all operands and results and generates shape_casts accordingly.
  • For transpose ops, add a DropUnitDimFromTransposeOp rewrite pattern that additionally drops the corresponding dimension from the transposition map.
  • Add DropUnitDimFromOp<BroadcastOp> and DropUnitDimFromTransposeOp to populateDropUnitDimWithShapeCastPatterns.