Not able to replace an MLIR operation with a new one inside a pass

Hello,

I am playing around with the MLIR pass infrastructure and writing a constant-folding pass for my own experimental dialect, `cijie`. This requires replacing operations.

Below is an example MLIR file that uses an operation I defined, `cijie.add`:

```
module {
    func @bar() -> i32 {
        %0 = constant 1 : i32
        %1 = constant 2 : i32
        %res_1 = cijie.add %1 %0 : i32
        return %res_1 : i32
    }
}
```

If the pass works properly, the IR should be transformed into:

```
module {
    func @bar() -> i32 {
        %0 = constant 1 : i32
        %1 = constant 2 : i32
        %res_1 = constant 3 : i32
        return %res_1 : i32
    }
}
```

However, I ran into an issue when replacing the `cijie.add` operation with a constant operation. The `cijie.add` was not replaced by a constant operation with value 3; in fact, the constant operation I created does not appear at all, and the result is invalid IR.

Here is the resulting invalid IR:

```
module  {
  func @bar() -> i32 {
    %c1_i32 = constant 1 : i32
    %c2_i32 = constant 2 : i32
    return <<UNKNOWN SSA VALUE>> : i32
  }
}
```

Here are the implementation details.

Inside the pass, I traverse the IR and save every `cijie.add` operation that can be folded into a constant. After the traversal finishes, I create the constant operations and replace the saved `cijie.add` operations with them.
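The collect-then-rewrite approach can be sketched outside MLIR with a tiny toy IR. Everything here (`ToyOp`, `ToyBlock`, `append`, `collectFolds`) is a hypothetical stand-in, not an MLIR API; it models only the first phase. Recording the folds first, without touching the block, means the traversal never iterates over a list that is being modified. The sketch uses a `std::vector` instead of the post's `std::queue`, and it adds the two *different* operands.

```cpp
#include <cassert>
#include <cstdint>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is not the MLIR API.
struct ToyOp {
  std::string kind;               // "constant", "add" or "return"
  int64_t value;                  // payload of a "constant" (unused otherwise)
  std::vector<ToyOp *> operands;  // each operand points at its defining op
};
using ToyBlock = std::list<std::unique_ptr<ToyOp>>;

// Append a new op at the end of the block and return a pointer to it.
ToyOp *append(ToyBlock &block, std::string kind, int64_t value,
              std::vector<ToyOp *> operands) {
  block.push_back(std::unique_ptr<ToyOp>(
      new ToyOp{std::move(kind), value, std::move(operands)}));
  return block.back().get();
}

// One entry per add that can be folded (mirrors the post's OperationFold).
struct OperationFold {
  ToyOp *op;
  int64_t value;
};

// Phase 1 only: record foldable adds without modifying the block.
std::vector<OperationFold> collectFolds(const ToyBlock &block) {
  std::vector<OperationFold> folds;
  for (const auto &op : block) {
    if (op->kind != "add") continue;
    const ToyOp *lhs = op->operands[0];
    const ToyOp *rhs = op->operands[1];
    if (lhs->kind == "constant" && rhs->kind == "constant")
      folds.push_back({op.get(), lhs->value + rhs->value});  // lhs + rhs
  }
  return folds;
}
```

In real MLIR code, `getOperation()->walk([&](AddOp add) { ... })` is the usual way to visit every `AddOp`, rather than a hand-written region/block/operation recursion.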

I create a `ConstantIntOp` using `OpBuilder::create`:

```cpp
auto constantOperation = opBuilder.create<ConstantIntOp>(f.op->getLoc(), f.value, 32);
```

I replace the cijie.add operations using IRRewriter::replaceOp:

```cpp
rewriter.replaceOp(f.op, {constantOperation});
```

These are the two key steps that I expected to perform the replacement.
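To see how these two steps can yield `<<UNKNOWN SSA VALUE>>`, here is a minimal toy model that mimics them. All names are hypothetical, not MLIR's API. It has three parts. The first is a builder that links a new op into a block only when it has an insertion point, roughly what `OpBuilder::insert` does. The second is a `replaceOp` that redirects uses and erases the old op without checking where the replacement lives. The third is a printer that names values by their position in the block.

```cpp
#include <cassert>
#include <cstdint>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is not the MLIR API.
struct ToyOp {
  std::string kind;               // "constant", "add" or "return"
  int64_t value;                  // payload of a "constant"
  std::vector<ToyOp *> operands;  // each operand points at its defining op
};
using ToyBlock = std::list<std::unique_ptr<ToyOp>>;

ToyOp *append(ToyBlock &block, std::string kind, int64_t value,
              std::vector<ToyOp *> operands) {
  block.push_back(std::unique_ptr<ToyOp>(
      new ToyOp{std::move(kind), value, std::move(operands)}));
  return block.back().get();
}

// Toy builder: links a new op into a block only if an insertion point is set.
struct ToyBuilder {
  ToyBlock *block = nullptr;  // nullptr means "no insertion point"
  ToyBlock::iterator insertPt{};
  std::vector<std::unique_ptr<ToyOp>> unlinked;  // keeps unlinked ops alive

  ToyOp *createConstant(int64_t v) {
    std::unique_ptr<ToyOp> op(new ToyOp{"constant", v, {}});
    ToyOp *raw = op.get();
    if (block)
      block->insert(insertPt, std::move(op));
    else
      unlinked.push_back(std::move(op));  // created "in thin air"
    return raw;
  }
};

// Toy replaceOp: redirect all uses of oldOp to newOp, then erase oldOp.
// It does not check whether newOp is in the block.
void replaceOp(ToyBlock &block, ToyOp *oldOp, ToyOp *newOp) {
  for (auto &user : block)
    for (ToyOp *&operand : user->operands)
      if (operand == oldOp) operand = newOp;
  block.remove_if(
      [&](const std::unique_ptr<ToyOp> &op) { return op.get() == oldOp; });
}

// Name a value by its defining op's position, or print the placeholder
// seen in the post when the definition is not in the block.
std::string nameOf(const ToyBlock &block, const ToyOp *def) {
  int index = 0;
  for (const auto &op : block) {
    if (op.get() == def) return "%" + std::to_string(index);
    ++index;
  }
  return "<<UNKNOWN SSA VALUE>>";
}
```

With a default-constructed `ToyBuilder` (no insertion point), the new constant ends up in `unlinked`, `replaceOp` points the `return` at it, and `nameOf` can no longer find its definition.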

Here is the full code, which lives in my opt tool's source file:

```cpp
struct OperationFold {
    Operation* op;
    int32_t value;
};

//The pass for constant folding
struct ConstantFolding : public PassWrapper<ConstantFolding, OperationPass<>> {

    StringRef getArgument() const final { return "constant-folding"; }
    StringRef getDescription() const final { return "Constant Folding"; }

    std::queue<OperationFold> operationsToFold;

    void runOnOperation() override {

        Operation* op = getOperation(); //root op

        llvm::outs() << "IR before pass\n";
        llvm::outs() << *op << "\n";

        visitOperation(op);

        //folding the operations into constants.
        OpBuilder opBuilder(&getContext());
        IRRewriter rewriter(&getContext());

        while (!operationsToFold.empty()) {
            OperationFold f = operationsToFold.front();
            operationsToFold.pop();
            auto constantOperation = opBuilder.create<ConstantIntOp>(f.op->getLoc(), f.value, 32);
            rewriter.replaceOp(f.op, {constantOperation});
        }

        llvm::outs() << "IR after pass\n";
        llvm::outs() << *op << "\n";
    }

    void visitOperation(Operation* op) {

        //test if this operation is an add operation
        if (isa_impl<AddOp, Operation>::doit(*op)) {
            mlir::Value lhs = op->getOperand(0);
            mlir::Value rhs = op->getOperand(1);
            auto lhs_c = lhs.getDefiningOp<ConstantIntOp>(); //left operand of the addition
            auto rhs_c = rhs.getDefiningOp<ConstantIntOp>(); //right operand of the addition
            if (lhs_c && rhs_c) { // both constants

                OperationFold fold;
                fold.op = op;
                fold.value = lhs_c.getValue() + rhs_c.getValue(); // sum of both operands
                operationsToFold.push(fold);
            }
        }

        for (Region &region : op->getRegions())
            visitRegion(region);

    }

    void visitRegion(Region &region) {
        for (Block &block : region.getBlocks())
            visitBlock(block);
    }

    void visitBlock(Block &block) {
        for (Operation &op : block.getOperations()) {
            visitOperation(&op);
        }

    }


};

int main(int argc, char **argv) {

  mlir::DialectRegistry registry;
  registry.insert<mlir::cijie::CijieDialect>();
  registry.insert<mlir::StandardOpsDialect>();

  PassRegistration<ConstantFolding>();
  mlir::MLIRContext context;
  mlir::PassManager pm(&context);
  pm.addPass(std::make_unique<ConstantFolding>());

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Cijie optimizer driver\n", registry));
}
```

Can you point out where I went wrong? This has confused me for days!

Thanks!

You don't seem to have set the insertion point of the builder/rewriter, so it doesn't put the new operation in the block where you expect it; it has no way of knowing where that is. The new operation is created “in thin air”, and the value it defines cannot be used in `return` because of SSA visibility rules, which leads to the invalid IR you see.
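The SSA visibility rule can be illustrated with a simplified, single-block check over a hypothetical toy IR. This is not MLIR's verifier, which reasons about dominance across blocks and nested regions. In this sketch, every operand must be defined by an op that appears earlier in the same block, so a value defined by an op that was never inserted fails the check.

```cpp
#include <cassert>
#include <cstdint>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is not the MLIR API.
struct ToyOp {
  std::string kind;               // "constant", "add" or "return"
  int64_t value;                  // payload of a "constant"
  std::vector<ToyOp *> operands;  // each operand points at its defining op
};
using ToyBlock = std::list<std::unique_ptr<ToyOp>>;

ToyOp *append(ToyBlock &block, std::string kind, int64_t value,
              std::vector<ToyOp *> operands) {
  block.push_back(std::unique_ptr<ToyOp>(
      new ToyOp{std::move(kind), value, std::move(operands)}));
  return block.back().get();
}

// Single-block simplification of SSA visibility: an operand is visible to an
// op only if its defining op appears earlier in the same block.
bool usesAreVisible(const ToyBlock &block) {
  std::set<const ToyOp *> definedSoFar;
  for (const auto &op : block) {
    for (const ToyOp *operand : op->operands)
      if (definedSoFar.count(operand) == 0)
        return false;  // defined later, or not in the block at all
    definedSoFar.insert(op.get());
  }
  return true;
}
```

The example from the question passes the check, and fails it as soon as the `return` is pointed at a constant that was never inserted.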

Use something like `OpBuilder builder(f.op);` or `builder.setInsertionPoint(f.op);` to insert new operations immediately before `f.op`.
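A toy sketch of “insert before `op`”, using hypothetical types rather than MLIR's. Because the toy ops don't know their parent block, the block is passed explicitly. An MLIR operation does know its block, which is why `OpBuilder builder(op);` needs only the op. After the insertion point is set before the add, the new constant lands directly in front of it, so it precedes every user of the add.

```cpp
#include <cassert>
#include <cstdint>
#include <iterator>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is not the MLIR API.
struct ToyOp {
  std::string kind;               // "constant", "add" or "return"
  int64_t value;                  // payload of a "constant"
  std::vector<ToyOp *> operands;  // each operand points at its defining op
};
using ToyBlock = std::list<std::unique_ptr<ToyOp>>;

ToyOp *append(ToyBlock &block, std::string kind, int64_t value,
              std::vector<ToyOp *> operands) {
  block.push_back(std::unique_ptr<ToyOp>(
      new ToyOp{std::move(kind), value, std::move(operands)}));
  return block.back().get();
}

// Toy builder whose insertion point is "immediately before op".
struct ToyBuilder {
  ToyBlock *block = nullptr;
  ToyBlock::iterator insertPt{};

  void setInsertionPoint(ToyBlock &b, const ToyOp *op) {
    block = &b;
    insertPt = b.begin();
    while (insertPt != b.end() && insertPt->get() != op)
      ++insertPt;  // stays at end() if op is not in the block
  }

  ToyOp *createConstant(int64_t v) {
    assert(block && "no insertion point set");
    auto it = block->insert(
        insertPt, std::unique_ptr<ToyOp>(new ToyOp{"constant", v, {}}));
    return it->get();
  }
};
```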

Irrespective of the above, using a separately created `OpBuilder` and `IRRewriter` is an amazing footgun, since each one has its own insertion point. Create the rewriter from the builder instead. Also, `IRRewriter` is a builder, so all builder methods can be called directly on it. Don't use `isa_impl`: as its name indicates, it's an implementation detail. Use `isa` instead.
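The footgun is that every builder object carries its own insertion point. A toy sketch, with hypothetical names modeled loosely on `IRRewriter` being an `OpBuilder`, shows a rewriter that inherits from a builder and copies the insertion point when it is constructed from one. A separately constructed rewriter starts with none. The rewrite loop then uses a single object to set the insertion point, create the constant, and replace the add.

```cpp
#include <cassert>
#include <cstdint>
#include <iterator>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is not the MLIR API.
struct ToyOp {
  std::string kind;               // "constant", "add" or "return"
  int64_t value;                  // payload of a "constant"
  std::vector<ToyOp *> operands;  // each operand points at its defining op
};
using ToyBlock = std::list<std::unique_ptr<ToyOp>>;

ToyOp *append(ToyBlock &block, std::string kind, int64_t value,
              std::vector<ToyOp *> operands) {
  block.push_back(std::unique_ptr<ToyOp>(
      new ToyOp{std::move(kind), value, std::move(operands)}));
  return block.back().get();
}

// Toy builder: the insertion point is state stored in the builder object.
struct ToyBuilder {
  ToyBlock *block = nullptr;
  ToyBlock::iterator insertPt{};

  void setInsertionPoint(ToyBlock &b, const ToyOp *op) {
    block = &b;
    insertPt = b.begin();
    while (insertPt != b.end() && insertPt->get() != op) ++insertPt;
  }
  ToyOp *createConstant(int64_t v) {
    assert(block && "no insertion point set");
    auto it = block->insert(
        insertPt, std::unique_ptr<ToyOp>(new ToyOp{"constant", v, {}}));
    return it->get();
  }
};

// Toy rewriter that *is* a builder; constructing it from a builder copies
// that builder's insertion point.
struct ToyRewriter : ToyBuilder {
  ToyRewriter() = default;
  explicit ToyRewriter(const ToyBuilder &b) : ToyBuilder(b) {}

  // Redirect all uses of oldOp to newOp, then erase oldOp.
  void replaceOp(ToyOp *oldOp, ToyOp *newOp) {
    assert(block && "no insertion point set");
    for (auto &user : *block)
      for (ToyOp *&operand : user->operands)
        if (operand == oldOp) operand = newOp;
    block->remove_if(
        [&](const std::unique_ptr<ToyOp> &op) { return op.get() == oldOp; });
  }
};

struct OperationFold {
  ToyOp *op;
  int64_t value;
};

// One object creates and replaces; the insertion point is set before each op.
void applyFolds(ToyBlock &block, const std::vector<OperationFold> &folds) {
  ToyRewriter rewriter;
  for (const OperationFold &f : folds) {
    rewriter.setInsertionPoint(block, f.op);
    ToyOp *cst = rewriter.createConstant(f.value);
    rewriter.replaceOp(f.op, cst);  // insertPt now dangles; reset next time
  }
}
```

In MLIR terms, the loop body roughly corresponds to `rewriter.setInsertionPoint(f.op);`, `rewriter.create<ConstantIntOp>(...)`, and `rewriter.replaceOp(f.op, ...)`. Treat that mapping as a sketch rather than tested MLIR code.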

Finally, MLIR has first-class support for folding and canonicalization, so I wouldn't bother writing a constant-folding pass except for purely educational purposes.
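For comparison with a hand-written pass: an MLIR op can declare a folder (`let hasFolder = 1;` in ODS) and implement a `fold` method. That method receives the constant values of its operands, when they are known, and returns the folded result. The canonicalizer and `OpBuilder::createOrFold` call it. The toy function below is hypothetical plain C++, not the MLIR signature, and only shows the shape of that logic. It assumes `cijie.add` has wrapping i32 semantics, which the post does not state.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical toy fold hook for an i32 add (not MLIR's API).
// An operand is std::nullopt when its value is not a known constant; the
// result is std::nullopt when the op cannot be folded.
std::optional<int32_t> foldAdd(std::optional<int32_t> lhs,
                               std::optional<int32_t> rhs) {
  if (!lhs || !rhs) return std::nullopt;  // need both constants to fold
  // Assumes wrapping i32 semantics: add as unsigned, then convert back.
  uint32_t sum = static_cast<uint32_t>(*lhs) + static_cast<uint32_t>(*rhs);
  return static_cast<int32_t>(sum);
}
```

For example, `foldAdd(2, 1)` yields 3, matching the fold from the question, while `foldAdd(std::nullopt, 1)` yields no value, so the op would be left alone.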

Thanks! The issue is solved.
Yes, I am getting familiar with MLIR, which is why I wrote this constant-folding pass.