[RFC] Adding AllRanksMatchIfKnown trait to TOSA broadcastable Operators

This RFC proposes adding a trait, AllRanksMatchIfKnown, to TOSA broadcastable operators.

The trait AllRanksMatchIfKnown specifies that all operands and results must have matching ranks, except for those that are unranked (e.g., with type tensor<*xf32>).
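
For illustration, here is a minimal C++ sketch of the check the trait performs. It is not the verifier from the actual patch. To stay self-contained, it models each operand/result type as a std::optional rank (std::nullopt standing for an unranked tensor such as tensor<*xf32>), and the names MaybeRank and allRanksMatchIfKnown are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for an MLIR tensor type: std::nullopt models an
// unranked tensor (e.g. tensor<*xf32>); otherwise the value is the rank.
using MaybeRank = std::optional<int64_t>;

// Returns true when every ranked operand/result has the same rank.
// Unranked entries are skipped, so they can never cause a mismatch.
bool allRanksMatchIfKnown(const std::vector<MaybeRank> &types) {
  MaybeRank expected;  // rank of the first ranked type seen so far
  for (const MaybeRank &rank : types) {
    if (!rank) continue;  // unranked: rank not known yet, ignore it
    if (!expected) {
      expected = rank;  // first known rank becomes the reference
    } else if (*expected != *rank) {
      return false;  // two known ranks disagree
    }
  }
  return true;
}
```

For tosa.add with types (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>, the ranks are {1, 1, unranked} and the check passes; (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> gives {1, 0, 1} and fails.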

Motivation:

TOSA broadcasting requires matching ranks for operands and results, whereas the verifier for the ResultsBroadcastableShape trait allows mismatched ranks.
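
The difference between the two rules can be sketched in C++ under simplifying assumptions: static shapes only (no dynamic dimensions), and a hypothetical broadcastShapes function rather than the helper the ResultsBroadcastableShape verifier actually calls. With requireEqualRank set to false, shapes are aligned from the trailing dimension and the shorter one is implicitly padded with leading 1s (NumPy-style); with it set to true, any rank mismatch is rejected, as TOSA requires.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical sketch: returns the broadcast result shape, or std::nullopt
// if the shapes are incompatible. Assumes all dimensions are static.
std::optional<Shape> broadcastShapes(const Shape &a, const Shape &b,
                                     bool requireEqualRank) {
  // TOSA-style rule: ranks must already be equal.
  if (requireEqualRank && a.size() != b.size()) return std::nullopt;

  const std::size_t rank = a.size() > b.size() ? a.size() : b.size();
  Shape result(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    // Walk from the trailing dimension; missing leading dims count as 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    // Dimensions are compatible if equal or if either one is 1.
    if (da != db && da != 1 && db != 1) return std::nullopt;
    result[rank - 1 - i] = (da == 1) ? db : da;
  }
  return result;
}
```

For example, shapes {4} and {} (a rank-0 tensor) broadcast to {4} under the NumPy-style rule but are rejected under the equal-rank rule, which is exactly the kind of mismatch the existing verifier lets through.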

Until now, a broadcastable TOSA operator could be created with operands of mismatched ranks, relying on the TosaMakeBroadcastable pass to insert Reshape operations that equalize the operand ranks. However, no pass explicitly verifies that broadcastable operators (eventually) have operands and results with matching ranks. For example, any TOSA operations created after the TosaMakeBroadcastable pass(es) have run can violate TOSA broadcast requirements without being caught.

This has been reported as issue #61822 in llvm/llvm-project: "[mlir][TOSA] TOSA MLIR dialect does not properly verify broadcast compatibility".

Proposal:

We propose adding a trait that requires, on construction, that broadcastable TOSA operators have operands and results with matching ranks. This puts some burden on legalization code, but has the advantage that transforms no longer have to worry about possible rank mismatches.

We tried using the existing AllRanksMatch trait, but found that it does not work, because TOSA allows unranked operands and results whose shapes and ranks are determined later by shape propagation across operations. For example:

// CHECK-LABEL: @test_multiple
func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
  // CHECK: [[ADD:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
  // CHECK: [[LOG:%.+]] = "tosa.log"(%0) : (tensor<4xf32>) -> tensor<4xf32>
  %1 = "tosa.log"(%0) : (tensor<*xf32>) -> tensor<*xf32>
 
  // CHECK: [[SUB:%.+]] = "tosa.sub"(%0, %arg2) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

So, we propose adding a TOSA specific trait, AllRanksMatchIfKnown, which will require operands and results to have matching ranks except for unranked operands/results.

We will add this trait to all broadcastable TOSA operators.

As the TosaInferShapes pass assigns shapes to unranked tensors, we will explicitly invoke the AllRanksMatchIfKnown verifier checks to double-check that the resulting operands/results have matching ranks.

This will also render the TosaMakeBroadcastable pass obsolete, and we will remove it later on.

Existing tf/tfl legalization

We will fix up existing tf/tfl legalization and tests to pass the verifier for the new trait.

Helper Function EqualizeRanks

We have refactored the helper function EqualizeRanks out of the TosaMakeBroadcastable pass into mlir/Dialect/Tosa/Utils/ConversionUtils.h:

/// Common code to create the reshape op where necessary to make the rank of two
/// values equal. input1 and input2 will be updated when the rank has
/// changed. The caller is expected to use these to rewrite the original
/// operator with the RESHAPE now in the graph.
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
                            Value &input1, Value &input2);
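
The shape arithmetic behind EqualizeRanks can be sketched without MLIR. This assumes, as we understand the helper carried over from TosaMakeBroadcastable to behave, that the lower-rank value is reshaped by prepending size-1 dimensions so trailing dimensions stay aligned. padShapeToRank and equalizeShapeRanks are hypothetical names working on plain shape vectors; the real helper inserts a tosa.reshape op and also reports failure for incompatible dimensions, which this sketch does not check.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical: the shape a reshape would produce to lift `shape` to
// `targetRank` by prepending size-1 dimensions.
Shape padShapeToRank(const Shape &shape, std::size_t targetRank) {
  assert(targetRank >= shape.size() && "can only increase the rank");
  Shape padded(targetRank - shape.size(), 1);  // leading 1s
  padded.insert(padded.end(), shape.begin(), shape.end());
  return padded;
}

// Shape-only analogue of EqualizeRanks: updates the lower-rank shape in
// place so both shapes have the same rank. Returns true if a reshape would
// have been inserted, false if the ranks already matched.
bool equalizeShapeRanks(Shape &a, Shape &b) {
  if (a.size() == b.size()) return false;
  if (a.size() < b.size()) {
    a = padShapeToRank(a, b.size());
  } else {
    b = padShapeToRank(b, a.size());
  }
  return true;
}
```

For instance, equalizing {4} with {2, 3, 4} turns the first shape into {1, 1, 4}, i.e. a reshape from tensor<4xf32> to tensor<1x1x4xf32>.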

This can be useful for constructing valid broadcastable TOSA operations. For example, the operands in this code snippet could previously have mismatched ranks:

rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias);

Adding an EqualizeRanks call ensures matching ranks (“outputValue” and “bias” are Value variables):

if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
  return failure();
}
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias);

So, we propose adding a TOSA specific trait

Can this be a generic MLIR trait, if there are no complaints? My out-of-tree dialect may have some uses for this trait.

Thanks for the feedback. Our current thinking is that it may be best to add this trait as TOSA-specific for now, and then separately see whether there are more requests to promote it to a generic trait.

Hey Tai,

This sounds good. I don’t see anything TOSA-specific in it, though, and it seems like it would fit in with the generality of the others like SameOperandAndResultTypes, so adding it there sounds good. The helper function could be TOSA-specific for now (as it probably inserts TOSA ops as needed), unless I’m missing something.

Thanks

Hi!

Great to hear that verifiers are coming for this. I implemented something very similar in ONNX-MLIR. I would love to replace it with common infrastructure coming from upstream!

Also a +1 from me for having a generic MLIR trait. Might be useful for other folks, too.

Thanks to the encouragement above, I have put up an MLIR patch for review:
D156369 [mlir] Add trait AllRanksMatchIfKnown
which adds a native op trait, AllRanksMatchIfKnown, along with its verifier and tests, to MLIR.