OpRewritePattern behaves unexpectedly

Hi All,

I built a simple graph and want to perform a transformation on the layernorm op. To make this work, I set "let hasCanonicalizer = 1" on the layernorm op definition so that my pattern runs as part of MLIR's built-in canonicalization pass.

I also wrote the following pattern:

struct QuantizedLayerNormRewrite : public OpRewritePattern<QuantizedLayerNormOp> {
  using OpRewritePattern<QuantizedLayerNormOp>::OpRewritePattern;
  QuantizedLayerNormRewrite(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<QuantizedLayerNormOp>(context, benefit) {}

  LogicalResult match(QuantizedLayerNormOp op) const override {
    if (llvm::dyn_cast<QuantizedLayerNormOp>(op.getOperation()))
      return success();
    return failure();
  }

  void rewrite(QuantizedLayerNormOp op, PatternRewriter &rewriter) const override {
    if (succeeded(match(op)))
      std::cout << "matched layernorm!!!" << op.nodeName().str()
                << std::endl;
  }

  // LogicalResult matchAndRewrite(QuantizedLayerNormOp op, PatternRewriter &rewriter) const override {
  //   std::cout << "matched layernorm!!!" << std::endl;
  // }
};

void QuantizedLayerNormOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
  results.add<QuantizedLayerNormRewrite>(context);
}

When I ran the code, the layernorm pattern was matched many times. After some analysis, I believe the rewrite pattern runs exactly as many times as there are operations in the graph, as you can see from the log:
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12
matched layernorm!!!Add_12

module {
func.func @main_graph(%arg0: !xrt.xrtensor<1x9800x256xi8, “quant_tensor_input”, fmt: “MK128”, dst_fmt: “K16M32”, “MATRIX_IFM_PARAM”, inQuant: [63.5000076, ], size: [(0, 0),(0, 9799),(0, 255),]>) → !xrt.xrtensor<9800x256xi8, “dequant_tensor_output, dst_fmt: “C128”, outDequant: [0.0394748934, ], size: [(0, 9799),(0, 255),]> attributes {input_names = [“quant_tensor_input”], output_names = [“dequant_tensor_output”]} {
%0 = “xrt.NoValue”() {value} : () → none
%1 = “xrt.Constant”() {value = dense<0> : tensor<256x256xi8>} : () → !xrt.xrtensor<1x256x256xi8, “37”, fmt: “MATRIX_KERNEL_INT8”, “MATRIX_KERNEL_PARAM”, size: [(0, 0),(0, 255),(0, 255),]>
%2 = “xrt.Constant”() {value = dense<0> : tensor<256x1xi32>} : () → !xrt.xrtensor<256x1xi32, “MatMul_0_quant_params”, fmt: “QUANT”, “QUANT_PARAM”, size: [(0, 255),]>
%3 = “xrt.QMatMul”(%arg0, %1, %0, %0, %2, %0, %0, %0) {XPRTOpType = 12 : i64, act = “”, bias_term = false, config = #xrt.config<NPU_num: 1, DeployCase: “no_slice”, ScanlineIdx: 0, sliceNum: 3, sliceKernelNum: 1, segmentId: 0, srcBuffer: “L1”, ofmBuffer: “L1”>, nodeName = “MatMul_0”, taskInfos = [”{\22allocFlag\22:1,\22biasName\22:\22\22,\22deployMode\22:1,\22dequantName\22:\22\22,\22eltName\22:\22\22,\22engineDeployTask\22:1,\22ifmName\22:[\22quant_tensor_input\22],\22ifmSize\22:[[0,0,0,9799,0,255]],\22kernelName\22:\2237\22,\22kernelSize\22:[0,0,0,255,0,255],\22lastNPU\22:true,\22lutName\22:\22\22,\22lutQuantName\22:\22\22,\22npuId\22:0,\22ofmBuffer\22:4,\22ofmDstBuffer\22:1521064096,\22ofmName\22:\2224\22,\22ofmSize\22:[0,0,0,9799,0,255],\22padding\22:[0,0,0,0],\22quantName\22:\22MatMul_0_quant_params\22,\22quantSize\22:[0,255],\22scanlineId\22:0,\22srcBuffer\22:4,\22taskInfoId\22:4294967295}"], veuFuse = } : (!xrt.xrtensor<1x9800x256xi8, “quant_tensor_input”, fmt: “MK128”, dst_fmt: “K16M32”, “MATRIX_IFM_PARAM”, inQuant: [63.5000076, ], size: [(0, 0),(0, 9799),(0, 255),]>, !xrt.xrtensor<1x256x256xi8, “37”, fmt: “MATRIX_KERNEL_INT8”, “MATRIX_KERNEL_PARAM”, size: [(0, 0),(0, 255),(0, 255),]>, none, none, !xrt.xrtensor<256x1xi32, “MatMul_0_quant_params”, fmt: “QUANT”, “QUANT_PARAM”, size: [(0, 255),]>, none, none, none) → !xrt.xrtensor<1x9800x256xi8, “24”, fmt: “K32M32”, “MATRIX_OFM_PARAM”, size: [(0, 0),(0, 9799),(0, 255),]>
%4 = “xrt.Constant”() {value = dense<0> : tensor<256xi32>} : () → !xrt.xrtensor<256xi32, "layernorm.bias, size: [(0, 255),]>
%5 = “xrt.Constant”() {value = dense<0> : tensor<1x1024xi8>} : () → !xrt.xrtensor<1x1024xi8, "Add_12_lut_params, size: [(0, 0),(0, 1023),]>
%6 = “xrt.Constant”() {value = dense<0> : tensor<9800x1xi32>} : () → !xrt.xrtensor<9800x1xi32, "Add_12_quant_divisor_params, size: [(0, 9799),(0, 0),]>
%7 = “xrt.Constant”() {value = dense<0> : tensor<9800x1xi32>} : () → !xrt.xrtensor<9800x1xi32, "Add_12_quant_mean_params, size: [(0, 9799),(0, 0),]>
%8 = “xrt.Constant”() {value = dense<0> : tensor<256x1xi32>} : () → !xrt.xrtensor<256x1xi32, "Add_12_quant_out_params, size: [(0, 255),(0, 0),]>
%9 = “xrt.QuantizedLayerNorm”(%3, %4, %0, %0, %0, %0, %0, %5, %6, %7, %8) {XPRTOpType = 71 : i64, divisor = 21 : i32, nodeName = “Add_12”, remain_bits = 5 : i32} : (!xrt.xrtensor<1x9800x256xi8, “24”, fmt: “K32M32”, “MATRIX_OFM_PARAM”, size: [(0, 0),(0, 9799),(0, 255),]>, !xrt.xrtensor<256xi32, "layernorm.bias, size: [(0, 255),]>, none, none, none, none, none, !xrt.xrtensor<1x1024xi8, "Add_12_lut_params, size: [(0, 0),(0, 1023),]>, !xrt.xrtensor<9800x1xi32, "Add_12_quant_divisor_params, size: [(0, 9799),(0, 0),]>, !xrt.xrtensor<9800x1xi32, "Add_12_quant_mean_params, size: [(0, 9799),(0, 0),]>, !xrt.xrtensor<256x1xi32, "Add_12_quant_out_params, size: [(0, 255),(0, 0),]>) → !xrt.xrtensor<9800x256xi8, "dequant_tensor_output, dst_fmt: “C128”, outDequant: [0.0394748934, ], size: [(0, 9799),(0, 255),]>
return %9 : !xrt.xrtensor<9800x256xi8, "dequant_tensor_output, dst_fmt: “C128”, outDequant: [0.0394748934, ], size: [(0, 9799),(0, 255),]>
}
}

As you can see, there is only one layernorm in the graph, but the pattern was matched 11 times, and the graph has 11 SSA statements. So I think I am just misusing the rewriter? Can anybody help me? I have been struggling with this for days.

The canonicalization pass repeatedly applies patterns until the IR reaches a fixed point; the driver will keep applying your pattern until match returns failure(), which never happens, since you always return success() and never transform the op. The only reason the pass stops is that there is an upper limit on the number of iterations it will do, which is where the 11 comes from. You'll probably find these useful:
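For what it's worth, here is a minimal sketch of the shape a well-behaved pattern takes. It is only an illustration, not your actual transformation: the "alreadyRewritten" marker attribute is made up for the example, and on newer MLIR versions updateRootInPlace is spelled modifyOpInPlace.

// Minimal sketch only (assumed marker attribute "alreadyRewritten"): the
// pattern must either change the IR and return success(), or return failure()
// so the greedy driver can reach a fixed point instead of looping until its
// iteration limit.
struct QuantizedLayerNormMarkerSketch : public OpRewritePattern<QuantizedLayerNormOp> {
  using OpRewritePattern<QuantizedLayerNormOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(QuantizedLayerNormOp op,
                                PatternRewriter &rewriter) const override {
    // Nothing left to do for this op: report failure so the driver stops
    // re-applying the pattern to it.
    if (op->hasAttr("alreadyRewritten"))
      return failure();

    // Do the actual transformation; here we only tag the op in place as a
    // stand-in for the real rewrite.
    rewriter.updateRootInPlace(op, [&] {
      op->setAttr("alreadyRewritten", rewriter.getUnitAttr());
    });
    return success();
  }
};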


Wow, thank you for your quick response, bro. Your explanation is very helpful!
I'll give your suggestion a try.

Hi troggo,
Following your suggestion, I managed to transform the layernorm op into a constant op. However, I found that other ops in the graph, like NoneOp and QMatMul, are gone!!!

module {
func.func @main_graph(%arg0: !xrt.xrtensor<1x9800x256xi8, “quant_tensor_input”, fmt: “MK128”, dst_fmt: “K16M32”, “MATRIX_IFM_PARAM”, inQuant: [63.5000076, ], size: [(0, 0),(0, 9799),(0, 255),]>) → !xrt.xrtensor<9800x256xi8, "dequant_tensor_output, dst_fmt: “C128”, outDequant: [0.0394748934, ], size: [(0, 9799),(0, 255),]> attributes {input_names = [“quant_tensor_input”], output_names = [“dequant_tensor_output”]} {
%0 = “xrt.Constant”() {value = dense<0> : tensor<256x256xi8>} : () → !xrt.xrtensor<1x256x256xi8, “37”, fmt: “MATRIX_KERNEL_INT8”, “MATRIX_KERNEL_PARAM”, size: [(0, 0),(0, 255),(0, 255),]>
%1 = “xrt.Constant”() {value = dense<0> : tensor<256x1xi32>} : () → !xrt.xrtensor<256x1xi32, “MatMul_0_quant_params”, fmt: “QUANT”, “QUANT_PARAM”, size: [(0, 255),]>
%2 = “xrt.Constant”() {value = dense<0> : tensor<256xi32>} : () → !xrt.xrtensor<256xi32, "layernorm.bias, size: [(0, 255),]>
%3 = “xrt.Constant”() {value = dense<0> : tensor<1x1024xi8>} : () → !xrt.xrtensor<1x1024xi8, "Add_12_lut_params, size: [(0, 0),(0, 1023),]>
%4 = “xrt.Constant”() {value = dense<0> : tensor<9800x1xi32>} : () → !xrt.xrtensor<9800x1xi32, "Add_12_quant_divisor_params, size: [(0, 9799),(0, 0),]>
%5 = “xrt.Constant”() {value = dense<0> : tensor<9800x1xi32>} : () → !xrt.xrtensor<9800x1xi32, "Add_12_quant_mean_params, size: [(0, 9799),(0, 0),]>
%6 = “xrt.Constant”() {value = dense<0> : tensor<256x1xi32>} : () → !xrt.xrtensor<256x1xi32, "Add_12_quant_out_params, size: [(0, 255),(0, 0),]>
%7 = “xrt.Constant”() {value = dense<0> : tensor<9800x256xi8>} : () → !xrt.xrtensor<9800x256xi8, "dequant_tensor_output, dst_fmt: “C128”, outDequant: [0.0394748934, ], size: [(0, 9799),(0, 255),]>
return %7 : !xrt.xrtensor<9800x256xi8, "dequant_tensor_output, dst_fmt: “C128”, outDequant: [0.0394748934, ], size: [(0, 9799),(0, 255),]>
}
}

My code for the layernorm rewrite is as follows:
struct QuantizedLayerNormRewrite : public OpRewritePattern<QuantizedLayerNormOp> {
  using OpRewritePattern<QuantizedLayerNormOp>::OpRewritePattern;
  QuantizedLayerNormRewrite(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<QuantizedLayerNormOp>(context, benefit) {}

  LogicalResult matchAndRewrite(QuantizedLayerNormOp op, PatternRewriter &rewriter) const override {
    // Build a dense constant with the same shape and element type as the
    // layernorm result.
    auto eleType = op.getResult().getType().cast<mlir::ShapedType>().getElementType();
    auto shape = op.getResult().getType().cast<mlir::ShapedType>().getShape();
    auto tensorType = mlir::RankedTensorType::get(shape, eleType);
    auto sType = op.getResult().getType();
    std::vector<int8_t> data(1); // TODO: constant value
    auto denseElmAttr =
        mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(data));

    std::vector<mlir::NamedAttribute> attrs;
    attrs.push_back(rewriter.getNamedAttr(mlir::StringRef("value"), denseElmAttr));

    // Replace the layernorm op with the new constant op.
    rewriter.replaceOpWithNewOp<ConstantOp>(op, sType, ValueRange{}, attrs);
    std::cout << "matched layernorm!!!!!!" << std::endl;
    return success();
  }
};

How can it happen that every operation other than the constant ops disappears?

The greedy rewrite driver erases dead ops as well, irrespective of what patterns you have in the pattern list. I think the documentation at Pattern Rewriting : Generic DAG-to-DAG Rewriting - MLIR needs to be improved/made more complete to mention that.
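Roughly speaking, an op counts as dead when its results have no uses and it is known to be free of side effects, which is exactly what a NoSideEffect trait asserts. A simplified sketch of that check (not the driver's actual upstream code):

// Simplified sketch: the greedy driver erases ops that are "trivially dead",
// i.e. their results are unused and they have no side effects (for example,
// ops carrying the NoSideEffect trait).
bool wouldBeErasedByDriver(mlir::Operation *op) {
  return mlir::isOpTriviallyDead(op);
}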

@bondhugula
Thank you for your kind suggestion! How can I avoid having the so-called dead nodes erased?

You can't with this driver; why do you need to keep dead operations around?

@mehdi_amini Thank you for answering! Actually, NoneOp and QuantizedMatMul are not dead nodes for me; I only want to transform the layernorm op and keep every other op unchanged.

Fortunately, I found there was already an issue about this usage, and all I had to do was remove the NoSideEffect trait from my dialect's op definitions in the .td file.

Thank you anyway!