Hello,
I am playing around with MLIR pass infra and writing a constant folding pass to transform my own experimental dialect called cijie
. This requires me to replace operations.
Below is an example MLIR file where there is an operation I defined called cijie.add
.
module {
func @bar() -> i32 {
%0 = constant 1 : i32
%1 = constant 2 : i32
%res_1 = cijie.add %1 %0 : i32
return %res_1 : i32
}
} }
If this pass works properly, the IR should be transformed into
module {
func @bar() -> i32 {
%0 = constant 1 : i32
%1 = constant 2 : i32
%res_1 = constant 3 : i32
return %res_1 : i32
}
} }
However, I encounted an issue when replacing the cijie.add
operation by a constant
operation. The cijie.add
was not replaced by the constant operation with value 3. In fact, the constant operation I created failed to appear and resulted in an invalid IR.
Here is the error IR.
module {
func @bar() -> i32 {
%c1_i32 = constant 1 : i32
%c2_i32 = constant 2 : i32
return <<UNKNOWN SSA VALUE>> : i32
}
}
Here is the implementation detail.
Inside that pass I traverse the IR operations and save all the cijie.add
operations that can be folded into the constants. When finishing traversing the IR I create constant operations and replace those cijie.add
operations by those constant operations I created.
I create a ConstantIntOp
using OpBuilder::create :
auto constantOperation = opBuilder.create<ConstantIntOp>(f.op->getLoc(), f.value, 32);
I replace the cijie.add
operations using IRRewriter::replaceOp:
rewriter.replaceOp(f.op, {constantOperation});
Above are the two key steps that I believe will do the replacement.
Here is the full code snippet I wrote inside an opt file.
struct OperationFold {
Operation* op;
int32_t value;
};
//The pass for constant folding
struct ConstantFolding : public PassWrapper<ConstantFolding, OperationPass<>> {
StringRef getArgument() const final { return "constant-folding"; }
StringRef getDescription() const final { return "Constant Folding"; }
std::queue<OperationFold> operationsToFold;
void runOnOperation() override {
Operation* op = getOperation(); //root op
llvm::outs() << "IR before pass\n";
llvm::outs() << *op << "\n";
visitOperation(op);
//folding the operations into constants.
OpBuilder opBuilder(&getContext());
IRRewriter rewriter(&getContext());
while (!operationsToFold.empty()) {
OperationFold f = operationsToFold.front();
operationsToFold.pop();
auto constantOperation = opBuilder.create<ConstantIntOp>(f.op->getLoc(), f.value, 32);
rewriter.replaceOp(f.op, {constantOperation});
}
llvm::outs() << "IR after pass\n";
llvm::outs() << *op << "\n";
}
void visitOperation(Operation* op) {
//test if this operation is an add operation
if (isa_impl<AddOp, Operation>::doit(*op)) {
mlir::Value lhs = op->getOperand(0);
mlir::Value rhs = op->getOperand(1);
auto lhs_c = lhs.getDefiningOp<ConstantIntOp>(); //left operator of addition
auto rhs_c = rhs.getDefiningOp<ConstantIntOp>();
if (lhs_c && rhs_c) { // both constants
OperationFold fold;
fold.op = op;
fold.value = lhs_c.getValue() + lhs_c.getValue();
operationsToFold.push(fold);
}
}
for (Region ®ion : op->getRegions())
visitRegion(region);
}
void visitRegion(Region ®ion) {
for (Block &block : region.getBlocks())
visitBlock(block);
}
void visitBlock(Block &block) {
for (Operation &op : block.getOperations()) {
visitOperation(&op);
}
}
};
int main(int argc, char **argv) {
mlir::DialectRegistry registry;
registry.insert<mlir::cijie::CijieDialect>();
registry.insert<mlir::StandardOpsDialect>();
PassRegistration<ConstantFolding>();
mlir::MLIRContext context;
mlir::PassManager pm(&context);
pm.addPass(std::make_unique<ConstantFolding>());;
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "Cijie optimizer driver\n", registry));
}
Can you point out where I am wrong. It has confused me for days!
Thanks!