I have been experimenting with the available MLIR source location information. I think this information, paired with LocationSnapshot, is useful for deriving some context about the changes that occurred after a given MLIR pass.
I created a prototype to demonstrate this here: Add IRMutations Util by vamsimanchala · Pull Request #2 · vamsimanchala/llvm-project · GitHub. Some details follow. I am interested in getting early feedback from the community, and I am happy to send a more formal proposal if there is interest.
Objective:
To gather LLVM/MLIR community feedback on a simple approach (a tool/utility) that uses MLIR source locations and LocationSnapshot to derive the list of mutations/changes that occurred to an MLIR entity (op, region, block) after a single pass, a pass pipeline, or even a PatternRewriter rewrite.
Introduction:
Debugging errors or unexpected behavior in MLIR passes can be time-consuming and unintuitive. To let users debug these issues during the execution of a compiler pipeline, it is important to preserve the transformation details and make them accessible and usable to users (or to any tool that helps with easy consumption of those details). The fundamental problem that preserving this transformation information solves is op provenance: users benefit from being able to trace an op in a given dialect or pass back to where it was in each transformation pass within the compiler pipeline, and ultimately back to its source.
Consider the following input MLIR module:
input.mlir
module {
func.func @func_with_tf_add_op(%arg0 : tensor<128x1xf32>) -> tensor<128x1xf32> {
%cst = "tf.Const"() {value = dense<2.0> : tensor<1xf32>} : () -> tensor<1xf32>
%4 = "tf.AddV2"(%arg0, %cst) : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
%5 = "tf.Identity"(%4) : (tensor<128x1xf32>) -> tensor<128x1xf32>
func.return %5: tensor<128x1xf32>
}
func.func @batchmatmul2fullyconnected(%arg0: tensor<128x2xf32>, %arg1: tensor<256xf32>) -> (tensor<128x1xf32>) {
%cst_0 = "tf.Const"() {value = dense<[[1.0], [2.0]]> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
%cst_1 = "tf.Const"() {value = dense<[128, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = "tf.Shape"(%arg0) : (tensor<128x2xf32>) -> tensor<2xi32>
%1 = "tf.Reshape"(%arg1, %0) : (tensor<256xf32>, tensor<2xi32>) -> tensor<128x2xf32>
%2 = "tf.EnsureShape"(%1) {shape = #tf_type.shape<128x2>} : (tensor<128x2xf32>) -> tensor<128x2xf32>
%3 = "tf.BatchMatMul"(%2, %cst_0) : (tensor<128x2xf32>, tensor<2x1xf32>) -> tensor<128x1xf32>
%4 = func.call @func_with_tf_add_op(%3) : (tensor<128x1xf32>) -> tensor<128x1xf32>
func.return %4 : tensor<128x1xf32>
}
}
Running the tf-opt tool to apply the -inline='default-pipeline=''', -tf-shape-inference, -tfl-prepare-tf, -tfl-legalize-tf and -tfl-optimize passes (in that order) produces the following output:
module {
func.func @func_with_tf_add_op(%arg0: tensor<128x1xf32>) -> tensor<128x1xf32> {
%cst = arith.constant dense<2.000000e+00> : tensor<1xf32>
%0 = tfl.add(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
return %0 : tensor<128x1xf32>
}
func.func @batchmatmul2fullyconnected(%arg0: tensor<128x2xf32>, %arg1: tensor<256xf32>) -> tensor<128x1xf32> {
%0 = "tfl.pseudo_const"() {value = dense<[[1.000000e+00, 2.000000e+00]]> : tensor<1x2xf32>} : () -> tensor<*xf32>
%cst = arith.constant dense<[128, 2]> : tensor<2xi32>
%cst_0 = arith.constant dense<2.000000e+00> : tensor<1xf32>
%1 = "tfl.reshape"(%arg1, %cst) : (tensor<256xf32>, tensor<2xi32>) -> tensor<128x2xf32>
%2 = "tfl.fully_connected"(%1, %0, %cst_0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<128x2xf32>, tensor<*xf32>, tensor<1xf32>) -> tensor<128x1xf32>
return %2 : tensor<128x1xf32>
}
}
The input MLIR module underwent a series of modifications in each of the above passes. It is possible that one or more of the applied passes made undesirable changes. One way to debug such issues is to print the MLIR text to the terminal with the available MLIR tooling, such as --mlir-print-ir-after-all and --mlir-print-debuginfo.
For example, this is what is printed to the terminal with the --mlir-print-ir-after-all option set:
// -----// IR Dump After Inliner (inline) //----- //
module {
func.func @func_with_tf_add_op(%arg0: tensor<128x1xf32>) -> tensor<128x1xf32> {
%cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
%1 = "tf.Identity"(%0) : (tensor<128x1xf32>) -> tensor<128x1xf32>
return %1 : tensor<128x1xf32>
}
func.func @batchmatmul2fullyconnected(%arg0: tensor<128x2xf32>, %arg1: tensor<256xf32>) -> tensor<128x1xf32> {
%cst = "tf.Const"() {value = dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
%cst_0 = "tf.Const"() {value = dense<[128, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = "tf.Shape"(%arg0) : (tensor<128x2xf32>) -> tensor<2xi32>
%1 = "tf.Reshape"(%arg1, %0) : (tensor<256xf32>, tensor<2xi32>) -> tensor<128x2xf32>
%2 = "tf.EnsureShape"(%1) {shape = #tf_type.shape<128x2>} : (tensor<128x2xf32>) -> tensor<128x2xf32>
%3 = "tf.BatchMatMul"(%2, %cst) : (tensor<128x2xf32>, tensor<2x1xf32>) -> tensor<128x1xf32>
%cst_1 = "tf.Const"() {value = dense<2.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
%4 = "tf.AddV2"(%3, %cst_1) : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
%5 = "tf.Identity"(%4) : (tensor<128x1xf32>) -> tensor<128x1xf32>
return %5 : tensor<128x1xf32>
}
}
// -----// IR Dump After TensorFlowShapeInferencePass (tf-shape-inference) //----- //
module {
func.func @func_with_tf_add_op(%arg0: tensor<128x1xf32>) -> tensor<128x1xf32> {
%cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
%1 = "tf.Identity"(%0) : (tensor<128x1xf32>) -> tensor<128x1xf32>
return %1 : tensor<128x1xf32>
}
func.func @batchmatmul2fullyconnected(%arg0: tensor<128x2xf32>, %arg1: tensor<256xf32>) -> tensor<128x1xf32> {
%cst = "tf.Const"() {value = dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
%cst_0 = "tf.Const"() {value = dense<[128, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = "tf.Shape"(%arg0) : (tensor<128x2xf32>) -> tensor<2xi32>
%1 = "tf.Reshape"(%arg1, %0) : (tensor<256xf32>, tensor<2xi32>) -> tensor<128x2xf32>
%2 = "tf.EnsureShape"(%1) {shape = #tf_type.shape<128x2>} : (tensor<128x2xf32>) -> tensor<128x2xf32>
%3 = "tf.BatchMatMul"(%2, %cst) : (tensor<128x2xf32>, tensor<2x1xf32>) -> tensor<128x1xf32>
%cst_1 = "tf.Const"() {value = dense<2.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
%4 = "tf.AddV2"(%3, %cst_1) : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
%5 = "tf.Identity"(%4) : (tensor<128x1xf32>) -> tensor<128x1xf32>
return %5 : tensor<128x1xf32>
}
}
// -----// IR Dump After PrepareTFPass (tfl-prepare-tf) //----- //
func.func @func_with_tf_add_op(%arg0: tensor<128x1xf32>) -> tensor<128x1xf32> {
%cst = arith.constant dense<2.000000e+00> : tensor<1xf32>
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
return %0 : tensor<128x1xf32>
}
// -----// IR Dump After LegalizeTFPass (tfl-legalize-tf) //----- //
func.func @func_with_tf_add_op(%arg0: tensor<128x1xf32>) -> tensor<128x1xf32> {
%cst = arith.constant dense<2.000000e+00> : tensor<1xf32>
%0 = tfl.add(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
return %0 : tensor<128x1xf32>
}
// -----// IR Dump After OptimizePass (tfl-optimize) //----- //
func.func @func_with_tf_add_op(%arg0: tensor<128x1xf32>) -> tensor<128x1xf32> {
%cst = arith.constant dense<2.000000e+00> : tensor<1xf32>
%0 = tfl.add(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
return %0 : tensor<128x1xf32>
}
// -----// IR Dump After PrepareTFPass (tfl-prepare-tf) //----- //
func.func @batchmatmul2fullyconnected(%arg0: tensor<128x2xf32>, %arg1: tensor<256xf32>) -> tensor<128x1xf32> {
%cst = "tf.Const"() {value = dense<[[1.000000e+00, 2.000000e+00]]> : tensor<1x2xf32>} : () -> tensor<*xf32>
%cst_0 = arith.constant dense<[128, 2]> : tensor<2xi32>
%cst_1 = arith.constant dense<2.000000e+00> : tensor<1xf32>
%0 = "tf.Reshape"(%arg1, %cst_0) : (tensor<256xf32>, tensor<2xi32>) -> tensor<128x2xf32>
%1 = "tf.MatMul"(%0, %cst) {transpose_a = false, transpose_b = true} : (tensor<128x2xf32>, tensor<*xf32>) -> tensor<128x1xf32>
%2 = "tf.AddV2"(%1, %cst_1) : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
return %2 : tensor<128x1xf32>
}
// -----// IR Dump After LegalizeTFPass (tfl-legalize-tf) //----- //
func.func @batchmatmul2fullyconnected(%arg0: tensor<128x2xf32>, %arg1: tensor<256xf32>) -> tensor<128x1xf32> {
%0 = "tfl.pseudo_const"() {value = dense<[[1.000000e+00, 2.000000e+00]]> : tensor<1x2xf32>} : () -> tensor<*xf32>
%cst = arith.constant dense<[128, 2]> : tensor<2xi32>
%cst_0 = arith.constant dense<2.000000e+00> : tensor<1xf32>
%1 = "tfl.reshape"(%arg1, %cst) : (tensor<256xf32>, tensor<2xi32>) -> tensor<128x2xf32>
%2 = "tfl.no_value"() {value} : () -> none
%3 = "tfl.fully_connected"(%1, %0, %2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<128x2xf32>, tensor<*xf32>, none) -> tensor<128x1xf32>
%4 = tfl.add(%3, %cst_0) {fused_activation_function = "NONE"} : (tensor<128x1xf32>, tensor<1xf32>) -> tensor<128x1xf32>
return %4 : tensor<128x1xf32>
}
// -----// IR Dump After OptimizePass (tfl-optimize) //----- //
func.func @batchmatmul2fullyconnected(%arg0: tensor<128x2xf32>, %arg1: tensor<256xf32>) -> tensor<128x1xf32> {
%0 = "tfl.pseudo_const"() {value = dense<[[1.000000e+00, 2.000000e+00]]> : tensor<1x2xf32>} : () -> tensor<*xf32>
%cst = arith.constant dense<[128, 2]> : tensor<2xi32>
%cst_0 = arith.constant dense<2.000000e+00> : tensor<1xf32>
%1 = "tfl.reshape"(%arg1, %cst) : (tensor<256xf32>, tensor<2xi32>) -> tensor<128x2xf32>
%2 = "tfl.fully_connected"(%1, %0, %cst_0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<128x2xf32>, tensor<*xf32>, tensor<1xf32>) -> tensor<128x1xf32>
return %2 : tensor<128x1xf32>
}
While it is possible to debug issues by manually inspecting the MLIR dumps, it would be better to know exactly what happened in a given pass or PatternRewriter, especially since the MLIR dumps are usually very large.
Use-case:
- Print a list of mutations after every pass or after selected passes, or print the mutations made to a given op, with the help of MLIR CLI options like --mlir-print-ir-mutations-after-all, --mlir-print-ir-mutations-after=<pass1,pass2> or --mlir-print-op-mutations=<Op1,Op2>, etc. For the example above, this could report:
  - tf.BatchMatMul is transformed to tf.MatMul as a result of the -tfl-prepare-tf pass.
  - tf.MatMul is unfused to form tfl.no_value and tfl.fully_connected as a result of the -tfl-legalize-tf pass.
  - tfl.no_value, tfl.fully_connected and tfl.add are fused to form tfl.fully_connected as a result of the -tfl-optimize pass.
- Display mutations as an MLIR diff.
- Or use them purely for graphical visualization of op traceability.
Tool Interface:
This tool can be easily enabled with MLIR PassInstrumentation or MLIR Actions to get the mutations after applying a pass, pipeline or rewriter. Additionally, it could be made part of the MLIRContext to extend its scope and the available information.
- Create a mlir::IRMutationManager with static storage.
- Reset the state of the IRMutationManager prior to running the pass, pipeline or rewriter:
  mutation_manager.reset(/*IRUnit*/ op, /*pass_name*/ transform_tag);
- Get the list of mutations after running the pass, pipeline or rewriter, to print them or supply them to a diff tool or graph visualizer:
  auto mutations_list = mutation_manager.getMutations(op);
  for (auto &&it : mutations_list) {
    it->print();
  }
- Perform a LocationSnapshot to reset/re-number the IR source locations if there were mutations:
  if (!mutations_list.empty()) {
    if (failed(generateLocationsFromIR(pass_name, op, OpPrintingFlags(), tag)))
      return signalPassFailure();
  }
This PR modifies LocationSnapshot.cpp to demonstrate the tool's use. Run -snapshot-op-locations after every pass in the pipeline, like:
./tf-opt -snapshot-op-locations -inline='default-pipeline=''' -snapshot-op-locations -tf-shape-inference -snapshot-op-locations -tfl-prepare-tf -snapshot-op-locations -tfl-legalize-tf -snapshot-op-locations -tfl-optimize -snapshot-op-locations input.mlir
Please note: LocationSnapshot is modified here only for demonstration purposes.