The code is based on MLIR 78a09cbd3e2890b5cd03cd66b2cc98d83811a728 (May 8, 2023)
Our project haven’t update MLIR for a while.
I want to rewrite the following function
to combine two function arguments into one tuple function arguments
Before
// hbir.add is a custom Op
func.func @foo(%arg0: tensor<1x300x300x1xui8>, %arg1: tensor<1x300x300x1xui8>) -> (tensor<1x300x300x3xui8>) {
%0 = "hbir.add"(%arg0, %arg1): (tensor<1x300x300x1xui8>, tensor<1x300x300x1xui8>) -> tensor<1x300x300x3xui8>
return %0 : tensor<1x300x300x3xui8>
}
After
// hbir.tuple_get is a custom Op
func.func @foo(%arg0: tuple<tensor<1x300x300x1xui8>, tensor<1x300x300x1xui8>>) -> (tensor<1x300x300x3xsi8>) {
%0 = "hbir.tuple_get"(%arg0): {index = 0:i32} (tensor<1x300x300x1xui8>) -> tensor<1x300x300x3xui8>
%1 = "hbir.tuple_get"(%arg1): {index = 1:i32} (tensor<1x300x300x1xui8>) -> tensor<1x300x300x3xui8>
%2 = "hbir.add"(%arg0, %arg1): (tensor<1x300x300x1xui8>, tensor<1x300x300x1xui8>) -> tensor<1x300x300x3xui8>
return %2 : tensor<1x300x300x3xui8>
}
Here is my pattern rewriter code
It can correctly print transfomed function in LLVM_DEBUG
at the end of rewriter function.
But it crashes later after the pass that calls this pattern (in the dtor of MLIRContext). I do not what I did wrong.
Note that this is the only pass that I called in the program that I am running a single pass test
class MakeTuple : public OpConversionPattern<func::FuncOp> {
public:
explicit MakeTuple(MLIRContext *ctx, PatternBenefit benefit = 1)
: OpConversionPattern<func::FuncOp>(ctx, benefit) {}
LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override;
}
LogicalResult MakeTuple::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
size_t startIndex = 0;
size_t num = 2;
if (funcOp.getNumArguments() != num) {
return failure();
}
if (funcOp.getArgument(startIndex).getType().isa<TupleType>()) {
return failure();
}
llvm::SmallVector<mlir::Type> tupleChildren;
for (size_t i = startIndex; i < startIndex + num; ++i) {
auto value = funcOp.getArgument(i);
auto type = value.getType();
tupleChildren.push_back(type);
}
auto tupleType = mlir::TupleType::get(funcOp->getContext(), tupleChildren);
auto oldAttr = funcOp.getArgAttrDict(startIndex);
auto oldLoc = funcOp.getArgument(startIndex).getLoc();
rewriter.updateRootInPlace(funcOp, [&]() {
funcOp.insertArgument(startIndex, tupleType, oldAttr, oldLoc);
});
auto tupleValue = funcOp.getArgument(startIndex);
rewriter.setInsertionPointToStart(&funcOp.front());
for (size_t i = 0; i < num; ++i) {
auto getTupleElementOp = rewriter.create<hbir::TupleGetOp>(tupleValue.getLoc(), tupleType.getType(i), tupleValue,
static_cast<uint32_t>(i));
rewriter.replaceAllUsesWith(funcOp.getArgument(startIndex + i + 1), getTupleElementOp);
}
rewriter.updateRootInPlace(funcOp, [&]() {
llvm::BitVector argumentsToRemove(funcOp.getNumArguments());
argumentsToRemove.set(startIndex + 1, startIndex + 1 + num);
funcOp.eraseArguments(argumentsToRemove);
});
LLVM_DEBUG(llvm::dbgs() << "Op after transform " << funcOp << "\n");
return success();
}