This RFC proposes adding a trait, AllRanksMatchIfKnown, to TOSA broadcastable operators.
The trait AllRanksMatchIfKnown specifies that all operands and results must have matching ranks, except for those that are unranked (e.g., with type tensor<*xf32>).
Motivation:
TOSA broadcasting requires operands and results to have matching ranks, whereas the verifier for the ResultsBroadcastableShape trait allows mismatched ranks.
Until now, a broadcastable TOSA operator could be created with operands of mismatching ranks, relying on the TosaMakeBroadcastable pass to insert Reshape operations that equalize the ranks of these operands. However, no pass explicitly verifies that broadcastable operators (eventually) have operands and results with matching ranks. For example, any TOSA operation created after the TosaMakeBroadcastable pass has run can violate the TOSA broadcast requirements without being caught.
This has been reported as llvm/llvm-project issue #61822 on GitHub: "[mlir][TOSA] TOSA MLIR dialect does not properly verify broadcast compatibility".
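To make the gap concrete, here is a minimal self-contained C++ sketch. It is not the MLIR verifier: it models shapes as plain vectors of static sizes and assumes NumPy-style, right-aligned broadcasting, which is what lets shapes of different rank pass. The names isBroadcastCompatible and isTosaBroadcastCompatible are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Static dimension sizes, outermost first. An empty vector models rank 0.
using Shape = std::vector<int64_t>;

// Hypothetical model of a permissive broadcast check: align the shapes at
// the trailing dimension; each aligned pair must be equal or contain a 1.
// Missing leading dimensions are ignored, so ranks may differ.
bool isBroadcastCompatible(const Shape &a, const Shape &b) {
  std::size_t n = a.size() < b.size() ? a.size() : b.size();
  for (std::size_t i = 1; i <= n; ++i) {
    int64_t da = a[a.size() - i];
    int64_t db = b[b.size() - i];
    if (da != db && da != 1 && db != 1)
      return false;
  }
  return true;
}

// Hypothetical model of the stricter TOSA rule: ranks must match as well.
bool isTosaBroadcastCompatible(const Shape &a, const Shape &b) {
  return a.size() == b.size() && isBroadcastCompatible(a, b);
}
```

Under this model, shapes {4} and {} (as in tensor<4xf32> and tensor<f32>) pass the permissive check but fail the rank-matching one, while {4} and {1} pass both.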
Proposal:
We propose adding a trait to require, on construction, that broadcastable TOSA operators have operands and results with matching ranks. This puts some burden on legalization code, but it means transforms no longer have to worry about possible rank mismatches.
We tried using the existing trait AllRanksMatch, but it does not work because TOSA allows unranked operands and results, whose shapes and ranks are determined later by shape propagation across operations. For example:
// CHECK-LABEL: @test_multiple
func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
  // CHECK: [[ADD:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
  // CHECK: [[LOG:%.+]] = "tosa.log"(%0) : (tensor<4xf32>) -> tensor<4xf32>
  %1 = "tosa.log"(%0) : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK: [[SUB:%.+]] = "tosa.sub"(%0, %arg2) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
So, we propose adding a TOSA-specific trait, AllRanksMatchIfKnown, which requires operands and results to have matching ranks, except for unranked operands/results.
We will add this trait to all broadcastable TOSA operators.
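The difference between the two rules can be sketched in a few lines of self-contained C++. This is an illustrative model only, not the trait implementation: ranks are represented as std::optional<int>, with std::nullopt standing in for an unranked tensor such as tensor<*xf32>, and the strict rule is assumed here to reject unranked values outright. The function names are hypothetical.

```cpp
#include <optional>
#include <vector>

// Rank of one operand or result. std::nullopt models an unranked tensor.
using Rank = std::optional<int>;

// Hypothetical model of the strict rule: every value must be ranked and
// all ranks must be equal. (Rejecting unranked values is an assumption.)
bool allRanksMatch(const std::vector<Rank> &ranks) {
  Rank first;
  for (const Rank &r : ranks) {
    if (!r)
      return false;  // unranked value: the strict rule cannot be satisfied
    if (!first)
      first = r;     // remember the first rank seen
    else if (*first != *r)
      return false;
  }
  return true;
}

// Hypothetical model of the proposed rule: unranked values are skipped and
// only the known ranks are compared with each other.
bool allRanksMatchIfKnown(const std::vector<Rank> &ranks) {
  Rank first;
  for (const Rank &r : ranks) {
    if (!r)
      continue;      // unranked: nothing to compare yet
    if (!first)
      first = r;
    else if (*first != *r)
      return false;
  }
  return true;
}
```

For the tosa.add in the example above, the ranks (operands, then result) are {1, 1, unranked}: the relaxed rule accepts this and the strict one does not. Once every rank is known, the two rules agree, so ranks {1, 0, 1} are rejected by both.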
Because the TosaInferShapes pass replaces unranked tensor types with ranked ones, we will explicitly invoke the AllRanksMatchIfKnown verifier checks to double-check that the operands and results still have matching ranks.
This will also render the TosaMakeBroadcastable pass obsolete and we will remove this pass later on.
Existing tf/tfl legalization
We will fix up existing tf/tfl legalization and tests to pass the verifier for the new trait.
Helper Function EqualizeRanks
We have refactored the helper function EqualizeRanks out of the TosaMakeBroadcastable pass into mlir/Dialect/Tosa/Utils/ConversionUtils.h:
/// Common code to create the reshape op where necessary to make the rank of two
/// values equal. input1 and input2 will be updated when the rank has
/// changed. The caller is expected to use these to rewrite the original
/// operator with the RESHAPE now in the graph.
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
                            Value &input1, Value &input2);
This helper is useful for constructing valid broadcastable TOSA operations. For example, the following snippet could previously create an op with mismatched operand ranks:
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias);
Adding an EqualizeRanks call ensures matching ranks (“outputValue” and “bias” are Value variables):
if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
  return failure();
}
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias);
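The effect of rank equalization on shapes can be sketched without MLIR. The following self-contained C++ model is an assumption-laden illustration, not the helper itself: the real EqualizeRanks inserts a reshape op and updates the Values, whereas this hypothetical equalizeShapeRanks only rewrites shape vectors, and it assumes the lower-rank shape is padded with leading 1s.

```cpp
#include <cstdint>
#include <vector>

// Static dimension sizes, outermost first.
using Shape = std::vector<int64_t>;

// Hypothetical shape-level model of rank equalization: pad the lower-rank
// shape with leading 1s (an assumption) until both ranks are equal.
// Returns true if one of the shapes was changed.
bool equalizeShapeRanks(Shape &a, Shape &b) {
  if (a.size() == b.size())
    return false;  // ranks already match, nothing to do
  Shape &lower = a.size() < b.size() ? a : b;
  const Shape &higher = a.size() < b.size() ? b : a;
  // Insert (rankDifference) copies of 1 at the front of the lower-rank shape.
  lower.insert(lower.begin(), higher.size() - lower.size(), 1);
  return true;
}
```

With an output of shape {2, 3, 4} and a bias of shape {4}, this model leaves the output untouched and turns the bias shape into {1, 1, 4}, after which both operands have rank 3.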