Crash when using Pattern Rewriter to update function arguments

The code is based on MLIR 78a09cbd3e2890b5cd03cd66b2cc98d83811a728 (May 8, 2023)
Our project hasn’t updated MLIR for a while.

I want to rewrite the following function
to combine its two arguments into a single tuple argument.

Before

// hbir.add is a custom Op
func.func @foo(%arg0: tensor<1x300x300x1xui8>, %arg1: tensor<1x300x300x1xui8>) -> (tensor<1x300x300x3xui8>) {
  %0 = "hbir.add"(%arg0, %arg1): (tensor<1x300x300x1xui8>, tensor<1x300x300x1xui8>) ->  tensor<1x300x300x3xui8>
  return %0 : tensor<1x300x300x3xui8>
}

After

// hbir.tuple_get is a custom Op

func.func @foo(%arg0: tuple<tensor<1x300x300x1xui8>, tensor<1x300x300x1xui8>>) -> (tensor<1x300x300x3xui8>) {
  %0 = "hbir.tuple_get"(%arg0) {index = 0 : i32} : (tuple<tensor<1x300x300x1xui8>, tensor<1x300x300x1xui8>>) -> tensor<1x300x300x1xui8>
  %1 = "hbir.tuple_get"(%arg0) {index = 1 : i32} : (tuple<tensor<1x300x300x1xui8>, tensor<1x300x300x1xui8>>) -> tensor<1x300x300x1xui8>
  %2 = "hbir.add"(%0, %1) : (tensor<1x300x300x1xui8>, tensor<1x300x300x1xui8>) -> tensor<1x300x300x3xui8>
  return %2 : tensor<1x300x300x3xui8>
}

Here is my pattern rewriter code.
It correctly prints the transformed function via LLVM_DEBUG at the end of the rewrite function,
but it crashes later, after the pass that calls this pattern has finished (in the dtor of MLIRContext). I do not know what I did wrong.
Note that this is the only pass the program runs; I am running a single-pass test.

class MakeTuple : public OpConversionPattern<func::FuncOp> {
public:
  explicit MakeTuple(MLIRContext *ctx, PatternBenefit benefit = 1)
      : OpConversionPattern<func::FuncOp>(ctx, benefit) {}
  LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override;
};

LogicalResult MakeTuple::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter) const {
  size_t startIndex = 0;
  size_t num = 2;
  if (funcOp.getNumArguments() != num) {
    return failure();
  }
  if (funcOp.getArgument(startIndex).getType().isa<TupleType>()) {
    return failure();
  }
  llvm::SmallVector<mlir::Type> tupleChildren;
  for (size_t i = startIndex; i < startIndex + num; ++i) {
    auto value = funcOp.getArgument(i);
    auto type = value.getType();
    tupleChildren.push_back(type);
  }
  auto tupleType = mlir::TupleType::get(funcOp->getContext(), tupleChildren);

  auto oldAttr = funcOp.getArgAttrDict(startIndex);
  auto oldLoc = funcOp.getArgument(startIndex).getLoc();
  rewriter.updateRootInPlace(funcOp, [&]() {
    funcOp.insertArgument(startIndex, tupleType, oldAttr, oldLoc);
  });
  auto tupleValue = funcOp.getArgument(startIndex);
  rewriter.setInsertionPointToStart(&funcOp.front());
  for (size_t i = 0; i < num; ++i) {
    auto getTupleElementOp = rewriter.create<hbir::TupleGetOp>(tupleValue.getLoc(), tupleType.getType(i), tupleValue,
                                                               static_cast<uint32_t>(i));
    rewriter.replaceAllUsesWith(funcOp.getArgument(startIndex + i + 1), getTupleElementOp);
  }
  rewriter.updateRootInPlace(funcOp, [&]() {
    llvm::BitVector argumentsToRemove(funcOp.getNumArguments());
    argumentsToRemove.set(startIndex + 1, startIndex + 1 + num);
    funcOp.eraseArguments(argumentsToRemove);
  });
  LLVM_DEBUG(llvm::dbgs() << "Op after transform " << funcOp << "\n");
  return success();
}
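
The index bookkeeping in the pattern above (insert the tuple at startIndex, then erase the originals, which have shifted by one) can be checked in isolation. The following is a minimal sketch in plain C++ under a toy model: the argument list is a std::vector of type names, ArgList and combineIntoTuple are hypothetical names, and std::vector<bool> stands in for llvm::BitVector. It does not use MLIR and does not reproduce the crash; it only confirms that the half-open range [startIndex + 1, startIndex + 1 + num) selects the right arguments.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for a function's argument list; not an MLIR type.
using ArgList = std::vector<std::string>;

// Insert one tuple argument at startIndex, then erase the `num` original
// arguments, which now sit at [startIndex + 1, startIndex + 1 + num).
ArgList combineIntoTuple(ArgList args, std::size_t startIndex, std::size_t num) {
  assert(startIndex + num <= args.size());

  // Build the tuple type name from the arguments being combined.
  std::string tuple = "tuple<";
  for (std::size_t i = 0; i < num; ++i) {
    if (i != 0)
      tuple += ", ";
    tuple += args[startIndex + i];
  }
  tuple += ">";
  args.insert(args.begin() + static_cast<std::ptrdiff_t>(startIndex), tuple);

  // Same role as llvm::BitVector::set(begin, end): mark a half-open range.
  std::vector<bool> toRemove(args.size(), false);
  for (std::size_t i = startIndex + 1; i < startIndex + 1 + num; ++i)
    toRemove[i] = true;

  // Keep every argument that is not marked for removal.
  ArgList kept;
  for (std::size_t i = 0; i < args.size(); ++i)
    if (!toRemove[i])
      kept.push_back(args[i]);
  return kept;
}
```

With two arguments and startIndex = 0, the result is a single tuple argument, so the arithmetic itself is not the source of the crash.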

This looks like a memory corruption of some sort. You can try running under address sanitizer, though it sometimes doesn’t detect the problems with rewrite patterns due to pointer reuse.

My suspicion here is that the problem is somewhere around

    funcOp.insertArgument(startIndex, tupleType, oldAttr, oldLoc);

that bypasses the rewriter and directly mutates the IR. Direct mutation almost always leads to tricky issues when used in conversion patterns and should be avoided at all costs. Or it could be this line

 auto tupleValue = funcOp.getArgument(startIndex);

that takes the value from the operation being rewritten, instead of the adaptor that should be used. Generally, new IR must not use values that are not (transitively) coming from the adaptor as they may point to the old IR that will get erased. The code takes a value from the list of function arguments and later erases some arguments, which may trigger some invalidation.
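
To make the stale-value hazard concrete, here is a small sketch in plain C++ rather than MLIR. It assumes a toy model in which a function owns its arguments through std::shared_ptr and outside code holds std::weak_ptr handles; ToyArg, ToyFunc and handleSurvivesErase are hypothetical names. MLIR values are plain pointer-like handles with no such expiry check, so in real IR the same mistake is a silent dangling reference instead of something you can test for.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Toy model, not MLIR: the function owns its arguments, and outside code
// holds non-owning handles to them.
struct ToyArg {
  std::string type;
};

struct ToyFunc {
  std::vector<std::shared_ptr<ToyArg>> args;

  // Hand out a non-owning handle to argument i.
  std::weak_ptr<ToyArg> handle(std::size_t i) const { return args[i]; }

  // Erasing an argument destroys the object that handles may still refer to.
  void eraseArg(std::size_t i) {
    args.erase(args.begin() + static_cast<std::ptrdiff_t>(i));
  }
};

// Returns true if a handle captured before the erase is still valid after it.
bool handleSurvivesErase(std::size_t captured, std::size_t erased) {
  ToyFunc f;
  f.args.push_back(std::make_shared<ToyArg>(ToyArg{"tensor<1x300x300x1xui8>"}));
  f.args.push_back(std::make_shared<ToyArg>(ToyArg{"tensor<1x300x300x1xui8>"}));
  std::weak_ptr<ToyArg> h = f.handle(captured);
  f.eraseArg(erased);
  return !h.expired();
}
```

A handle to an argument that is later erased does not survive, while a handle to an untouched argument does. The sketch only shows the general hazard; whether a particular value in the pattern above is affected depends on what the conversion infrastructure erases or replaces afterwards.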

Also note that updateRootInPlace doesn’t track changes to regions of the op being updated. And adding/removing a function argument also changes the entry block of the function body region.

Generally, prefer using signature conversion functions from TypeConverter to change function signatures. Argument materialization allows one to insert additional operations after signature conversion, so a direct rewrite pattern wouldn’t be necessary.

Thanks.
I used the same logic but wrote the code directly in the pass, without using the pattern rewriter,
and it works.

// The caller greedily calls this function many times,
// to simulate pattern matching
LogicalResult rewriteFuncInputToTuple(size_t startIndex, size_t num) {
    auto funcOp = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "FuncOp before tuple transform " << funcOp << "\n");
    llvm::SmallVector<mlir::Type> tupleChildren;
    for (size_t i = startIndex; i < startIndex + num; ++i) {
      auto value = funcOp.getArgument(i);
      auto type = value.getType();
      tupleChildren.push_back(type);
    }
    auto tupleType = mlir::TupleType::get(funcOp->getContext(), tupleChildren);
    auto oldAttr = funcOp.getArgAttrDict(startIndex);
    auto oldLoc = funcOp.getArgument(startIndex).getLoc();
    funcOp.insertArgument(startIndex, tupleType, oldAttr, oldLoc);
    auto tupleValue = funcOp.getArgument(startIndex);

    mlir::OpBuilder builder(funcOp->getContext());
    builder.setInsertionPointToStart(&funcOp.front());
    for (size_t i = 0; i < num; ++i) {
      auto getTupleElementOp = builder.create<hbir::TupleGetOp>(tupleValue.getLoc(), tupleType.getType(i), tupleValue,
                                                                static_cast<uint32_t>(i));
      funcOp.getArgument(startIndex + i + 1).replaceAllUsesWith(getTupleElementOp);
    }
    llvm::BitVector argumentsToRemove(funcOp.getNumArguments());
    argumentsToRemove.set(startIndex + 1, startIndex + 1 + num);
    funcOp.eraseArguments(argumentsToRemove);
    LLVM_DEBUG(llvm::dbgs() << "FuncOp after tuple transform " << funcOp << "\n");
    return success();
}

So I think my usage of pattern matching is wrong.
The code not shown in this post does not cause any memory corruption.

@ftynse Can you give a simple example of how to do this using PatternRewriter?

I am confused because I do not see how to create my “tuple_get” op
without funcOp.insertArgument,
since the input of tuple_get is a BlockArgument.

Like I said above, direct IR mutation (i.e. not using the PatternRewriter object) is not allowed within the rewriting infra.
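
As a rough illustration of why this matters: the conversion rewriter keeps a record of the changes made through it, so that they can be undone if the conversion does not succeed. Anything changed behind its back is missing from that record. The sketch below is plain C++, not MLIR; ToyIR, ToyRewriter and mutateThenRollback are hypothetical names, and the undo log is a simplified model rather than the actual implementation.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy model, not MLIR: the "IR" is just a list of argument types.
using ToyIR = std::vector<std::string>;

// Hypothetical rewriter that records an undo action for every change it makes.
class ToyRewriter {
public:
  explicit ToyRewriter(ToyIR &ir) : irRef(ir) {}

  // Tracked mutation: apply the change and remember how to revert it.
  void appendArg(const std::string &type) {
    irRef.push_back(type);
    undoLog.push_back([this] { irRef.pop_back(); });
  }

  // Undo every recorded change, newest first.
  void rollback() {
    for (auto it = undoLog.rbegin(); it != undoLog.rend(); ++it)
      (*it)();
    undoLog.clear();
  }

private:
  ToyIR &irRef;
  std::vector<std::function<void()>> undoLog;
};

// Mutate the IR either through the rewriter or directly, then roll back.
ToyIR mutateThenRollback(bool bypassRewriter) {
  ToyIR ir = {"tensor"};
  ToyRewriter rewriter(ir);
  if (bypassRewriter)
    ir.push_back("tuple"); // direct mutation: nothing is recorded
  else
    rewriter.appendArg("tuple"); // tracked mutation
  rewriter.rollback();
  return ir;
}
```

After a rollback, the tracked change is gone but the direct one is still there, so the rewriter's view of the IR and the actual IR no longer agree.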

There is no simple example. The closest thing I can think of is the func-to-llvm conversion. It consists of a rewrite pattern, https://github.com/llvm/llvm-project/blob/3d51b40c4a4855a36a503d269c6c484404652949/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L473, a type converter rule for function signatures that changes the type, https://github.com/llvm/llvm-project/blob/3d51b40c4a4855a36a503d269c6c484404652949/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L251, and an argument materialization for memref types, https://github.com/llvm/llvm-project/blob/3d51b40c4a4855a36a503d269c6c484404652949/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L160-L176. Argument materialization is where you can insert your tuple_get, the same way func-to-llvm inserts undef and insertvalue to reconstruct the memref descriptor. You can read more about the signature changes performed during func-to-llvm in the “LLVM IR Target” page of the MLIR documentation, and about the type converter setup in the “Dialect Conversion” page. This is arguably the most complex piece of the infrastructure though.