In a previous discussion, "[RFC] Moving TOSA construction lib into MLIR core", we proposed a progressive move of the TOSA dialect construction utility functions into MLIR core.
These functions currently exist as identical copies in two frontends:
- TensorFlow: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tosa/transforms
  (legalize_common.* and legalize_utils.*)
- Torch-MLIR: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToTosa
  (TosaLegalizeCommon and TosaLegalizeUtils)
The code is largely framework invariant. The few places where framework-implemented support functions are used will be replaced with standalone code. The Torch side only has the reduction operator support functions right now, but the rest will progressively come into use as new legalizations are added on the Torch side.
Moving these functions into core enables a common level of service for TOSA legalization paths from multiple frameworks, with shape inference and similar capabilities. The related TosaInferShapes pass is already in core.
The feedback was that we needed to test these functions in core by adding a test dialect that would define ops that match the exposed APIs, and a conversion pass from this test dialect to TOSA.
To this end, this RFC describes the TosaTest Dialect and the conversion pass from TosaTest to TOSA, so that we can validate the TOSA construction lib in core.
Example walk through: buildSelectOp
One of the simpler TOSA construction lib functions is buildSelectOp. It is used when converting the TensorFlow Lite SelectOp and SelectV2Op, as well as the TensorFlow SelectV2Op, to TOSA.
The function checks the shape of the condition_value and, if needed, constructs a TOSA Reshape Op to reshape the condition_value to match the ranks of the x_value and y_value inputs.
It then constructs a TOSA Select Op from the (reshaped) condition, x and y values.
The function returns the output value of the resulting TOSA subgraph (or std::nullopt on error), which allows a caller to use a rewriter to replace the original op with this TOSA subgraph.
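The rank-matching step described above can be sketched outside of MLIR as a plain shape computation. This is only an illustration under stated assumptions: `paddedConditionShape` is a hypothetical helper, not part of the proposed library, and it assumes the missing dimensions are filled with leading 1s and that a condition of higher rank than the inputs is treated as an error.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of the proposed library): computes the shape
// the condition would be reshaped to so that its rank matches the x/y inputs.
// Assumption: missing dimensions are filled with leading 1s.
// Returns std::nullopt if the condition has a higher rank than the target
// (treated as an error in this sketch). If the ranks already match, the
// shape is returned unchanged, meaning no reshape would be needed.
std::optional<std::vector<int64_t>>
paddedConditionShape(const std::vector<int64_t> &condShape,
                     std::size_t targetRank) {
  if (condShape.size() > targetRank)
    return std::nullopt;
  // Start with (targetRank - rank) leading 1s, then append the original dims.
  std::vector<int64_t> result(targetRank - condShape.size(), int64_t{1});
  result.insert(result.end(), condShape.begin(), condShape.end());
  return result;
}
```

For the lit test at the end of this post, a condition of shape {1} against rank-3 inputs gives {1, 1, 1} under this assumption, which agrees with the new_shape in the CHECK line.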
To test this function fully within the MLIR codebase, we:
- create a new dialect, the TosaTest dialect, that contains a Select op
- implement a conversion pass from TosaTest to TOSA which, among other things, matches the TosaTest Select op and rewrites it to a TOSA subgraph by calling buildSelectOp
- create a lit test that runs this legalization pass to verify the conversion from the TosaTest Select op to the TOSA subgraph.
Both the TosaTest dialect and the conversion pass live in the mlir/test tree and are therefore for testing only.
TOSA construction lib:
- will be located in core at:
  - mlir/include/mlir/Dialect/Tosa/Utils/: CommonBuilders.h and CommonBuildUtils.h
  - mlir/lib/Dialect/Tosa/Utils/: CommonBuilders.cpp and CommonBuildUtils.cpp
- initially contains only buildPackOp, buildUnpackOp, and buildSelectOp and the utility functions they depend on
- other build functions will be added later
- the buildSelectOp API in CommonBuilders.h is:
// Build a TOSA subgraph for the Select operator.
std::optional<Value> buildSelectOp(OpBuilder &builder, const Location &loc,
Value result_value, Value condition_value,
Value x_value, Value y_value);
TosaTest Dialect:
- will be located in mlir/test/lib/Dialect/Tosa/TosaTestDialect.td:
def TosaTest_Dialect : Dialect {
let name = "tosa_test";
let description = [{
The Tosa Test dialect.
This dialect defines operations for use in testing
of Tosa helper functions for legalizing into TOSA.
}];
let cppNamespace = "::mlir::tosa_test";
}
- mlir/test/lib/Dialect/Tosa/TosaTestOps.td: (only Select op is shown)
include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
include "TosaTestDialect.td"
class TosaTest_Op<string mnemonic, list<Trait> traits = []> :
Op<TosaTest_Dialect, mnemonic, traits>;
def TosaTest_SelectOp : TosaTest_Op<"select"> {
let summary = "Select operator.";
let description = [{ ... }];
let arguments = (ins
I1Tensor:$condition,
Tosa_Tensor:$x,
Tosa_Tensor:$y
);
let results = (outs
Tosa_Tensor:$output
);
}
A pass that converts from TosaTest to TOSA:
- mlir/test/lib/Dialect/Tosa/LegalizeTosaTestPasses.td:
def TosaLegalizeTosaTestPass : Pass<"tosa-legalize-tosa-test", "mlir::func::FuncOp"> {
let summary = "Legalize from TosaTest to TOSA";
let dependentDialects = ["TosaDialect",
"mlir::tosa_test::TosaTestDialect"];
let constructor = "mlir::tosa_test::createTosaLegalizeTosaTestPass()";
}
- mlir/test/lib/Dialect/Tosa/LegalizeTosaTest.cpp:
- contains the following matchAndRewrite function for TosaTest Select Op:
LogicalResult ConvertTosaTestSelectOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tosa_test_sel_op = cast<tosa_test::SelectOp>(op);
std::optional<Value> result =
buildSelectOp(rewriter, op->getLoc(), tosa_test_sel_op.getResult(),
tosa_test_sel_op.getCondition(), tosa_test_sel_op.getX(),
tosa_test_sel_op.getY());
if (!result) return failure();
rewriter.replaceOp(op, {result.value()});
return success();
}
Lit Tests for the construction lib:
- are located in mlir/test/Dialect/Tosa/legalize-tosa-test.mlir (only one test is shown):
// RUN: mlir-opt --tosa-legalize-tosa-test %s | FileCheck %s
// -----
// CHECK-LABEL: test_select
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1>} : (tensor<1xi1>) -> tensor<1x1x1xi1>
// CHECK: %[[VAR2:.*]] = "tosa.select"(%[[VAR1]], %arg0, %arg1)
func.func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<1xi1>) -> tensor<13x21x3xf32> {
%0 = "tosa_test.select"(%arg2, %arg0, %arg1) : (tensor<1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
func.return %0 : tensor<13x21x3xf32>
}