How to write RewritePattern to replace 3 ops of one dialect with single op of another dialect

I am trying to write Rewrite Pattern to replace 3 ops of one dialect with single op of another dialect.
For example,
%0=dialect1.op1 %arg0,%arg1
%1=dialect1.op2 %arg0,%arg2
%2=dialect1.op3 %0,%1

expected conversion pattern looks like:
%0=dialect2.newop %arg0,%arg1,%arg2

How do we achieve this in ConversionPattern?

How do I specify the 3 operations, in the same order, for the RewritePattern class to match against?

Thanks

If op1 and op2 don’t have any side effects, the most common solution is to add a RewritePattern matching the root operation (dialect1.op3 in your example), then follow the SSA links to check that the operands come from op1 and op2 (and fail otherwise).
Then you can do a rewriter.replaceOpWithNewOp<dialect2::newOp>(op3, ...) and leave op2 and op1 to dead code elimination.
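
To make the "match the root, then follow the SSA links" idea concrete, here is a minimal sketch in plain C++ using only the standard library. It is a toy model, not the MLIR API: the Op struct and the functions matchesFusedPattern and buildNewOp are hypothetical names made up for illustration, and a block argument is modelled as an Op without operands.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for an SSA operation (not an MLIR class). Each operand points
// at the op that defines it; a block argument is an Op with no operands.
struct Op {
  std::string name;
  std::vector<const Op *> operands;
};

// Start at the root (op3) and follow the operand links to the defining ops.
// The match fails unless they are op1 and op2, in that order.
bool matchesFusedPattern(const Op &root) {
  if (root.name != "dialect1.op3" || root.operands.size() != 2)
    return false;
  const Op *op1 = root.operands[0];
  const Op *op2 = root.operands[1];
  if (op1->name != "dialect1.op1" || op1->operands.size() != 2)
    return false;
  if (op2->name != "dialect1.op2" || op2->operands.size() != 2)
    return false;
  // Both producers must read the same first operand (%arg0 in the example).
  return op1->operands[0] == op2->operands[0];
}

// Build the replacement from the operands of op1 and op2. The old op1 and
// op2 are left untouched: once op3 is replaced they have no users, so a
// later dead code elimination can drop them.
Op buildNewOp(const Op &root) {
  assert(matchesFusedPattern(root) && "call matchesFusedPattern first");
  const Op *op1 = root.operands[0];
  const Op *op2 = root.operands[1];
  Op fused{"dialect2.newop", {}};
  fused.operands.push_back(op1->operands[0]);  // %arg0
  fused.operands.push_back(op1->operands[1]);  // %arg1
  fused.operands.push_back(op2->operands[1]);  // %arg2
  return fused;
}
```

In a real pattern the same two steps would live inside matchAndRewrite: the checks become early returns with a match failure, and the construction becomes the replaceOpWithNewOp call.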

In general you’ll want to use ConversionPattern if the transformation from dialect1 to dialect2 requires type conversion; otherwise, a plain RewritePattern is the simpler solution.

If op1/op2 have side effects then it is more complicated, let me know if this is your case.

Thank you. In my case op1 is multiply, op2 is compare and op3 is select.

%2 = mhlo.multiply %arg0, %1 : tensor<4xf32>
%3 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%4 = "mhlo.select"(%3, %arg0, %2) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

and I am trying to replace these ops with new op

as %0=dialect2.newop %arg0,%1,%cst

Makes sense. In this case you should be able to write a simple RewritePattern that looks for the mhlo.select with the right operands and replaces it with the newop. You don’t need to worry about deleting mhlo.multiply and mhlo.compare, as they will get removed by dead code elimination as long as they don’t have any other uses.
I hope this is clear enough, let me know if you need more details.

Actually, dialect conversion visits the operands before the uses. So it will never reach the select if the two other ops are illegal: it’ll fail early (depending on how it is configured, but I’ll keep it simple here).
So the general principle is to call rewriter.eraseOp(multiply_op) when matching the multiply. Nothing is erased at that point, but it tells the framework that it can safely continue to match the users. The conversion will fail if there is a user of the multiply left.
(same principle for the compare)

If we delete the multiply_op and compare op before dialect conversion then how can we get output of these to new op as argument?

When you do a destructive action in Dialect Conversion (erasing operations, replacing uses, etc.), it doesn’t actually do it immediately. It records the fact that you intend to do it, and keeps going. These recorded updates are only applied if the entire conversion process succeeds; otherwise they are all rolled back and discarded. So when you say rewriter.eraseOp(op), it doesn’t delete the operation; it just notes down: "oh okay, you are saying to me that this operation should be erased. I'll erase it for you at the end of the conversion (but you need to make sure all uses of it are removed by then, otherwise I will fail the conversion process)".

– River
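
The record-now, apply-at-the-end behaviour described above can be sketched with a toy model. This is plain standard-library C++ and deliberately not the MLIR implementation: Op, DeferredRewriter, replaceOp and commit are hypothetical names, and the real framework tracks far more state. The point is only that an op passed to eraseOp stays readable until the whole conversion commits, and that a leftover use makes everything roll back.

```cpp
#include <algorithm>
#include <map>
#include <string>
#include <vector>

// Toy op: a name plus pointers to the ops that define its operands.
struct Op {
  std::string name;
  std::vector<Op *> operands;
};

// Hypothetical rewriter: eraseOp/replaceOp only record requests, and
// commit() applies all of them or none of them.
class DeferredRewriter {
public:
  explicit DeferredRewriter(std::vector<Op *> &block) : block_(block) {}

  // Record the intent to erase. The op stays in the block for now, so a
  // later pattern can still read its operands.
  void eraseOp(Op *op) { erased_.push_back(op); }

  // Record that oldOp should be replaced by newOp (also deferred).
  void replaceOp(Op *oldOp, Op *newOp) { replaced_[oldOp] = newOp; }

  // Apply every recorded change, or discard them all if a surviving op
  // would still use an erased op.
  bool commit() {
    for (Op *op : block_) {
      if (isErased(op) || replaced_.count(op)) continue;
      if (usesErasedOp(op)) return discardPending();
    }
    for (const auto &entry : replaced_)
      if (usesErasedOp(entry.second)) return discardPending();

    std::vector<Op *> result;
    for (Op *op : block_) {
      if (isErased(op)) continue;
      auto it = replaced_.find(op);
      Op *kept = (it == replaced_.end()) ? op : it->second;
      // Redirect uses of replaced ops to their replacements.
      for (Op *&operand : kept->operands) {
        auto r = replaced_.find(operand);
        if (r != replaced_.end()) operand = r->second;
      }
      result.push_back(kept);
    }
    block_ = result;
    erased_.clear();
    replaced_.clear();
    return true;
  }

private:
  bool isErased(const Op *op) const {
    return std::find(erased_.begin(), erased_.end(), op) != erased_.end();
  }
  bool usesErasedOp(const Op *op) const {
    return std::any_of(op->operands.begin(), op->operands.end(),
                       [this](const Op *o) { return isErased(o); });
  }
  // Roll back: forget all requests and leave the block untouched.
  bool discardPending() {
    erased_.clear();
    replaced_.clear();
    return false;
  }

  std::vector<Op *> &block_;
  std::vector<Op *> erased_;
  std::map<Op *, Op *> replaced_;
};
```

With this model, "erasing" the multiply and the compare first and then replacing the select still works, because the select pattern can read the operands of both ops right up to commit().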

If I write the rewrite class as below, it will replace the select op with the new op. But in this case, how can I get the inputs of the multiply and compare ops as inputs to the new op?

class ConvertMhloSelectToNew : public OpRewritePattern<SelectOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(SelectOp select_op,
                                PatternRewriter& rewriter) const override {
    // Replaces every select unconditionally, reusing the select's own operands.
    rewriter.replaceOpWithNewOp<mlir::NewOp>(select_op,
      select_op.getResult().getType(),
      select_op.operand(),
      select_op.operand1(),
      select_op.operand2());
    return success();
  }
};

Also, how can I make sure it replaces the select op only if its operands come from a multiply op and a compare op?

I’d think you’d write something like this:

class ConvertMhloSelectToNew : public OpRewritePattern<SelectOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(SelectOp select_op,
                                PatternRewriter& rewriter) const override {
    // Follow the select's operands back to the ops that define them.
    auto compareOp = select_op.pred().getDefiningOp<CompareOp>();
    if (!compareOp)
      return rewriter.notifyMatchFailure(select_op, "Expect a CompareOp for the predicate");
    auto mulOp = select_op.on_false().getDefiningOp<MulOp>();
    if (!mulOp)
      return rewriter.notifyMatchFailure(select_op, "Expect a MulOp for the false value");
    // All three ops must share the same input value (%arg0 in the example).
    if (mulOp.lhs() != select_op.on_true())
      return rewriter.notifyMatchFailure(select_op, "MulOp LHS does not match Select true value");
    if (mulOp.lhs() != compareOp.lhs())
      return rewriter.notifyMatchFailure(select_op, "MulOp LHS does not match Compare LHS");
    if (compareOp.comparison_direction() != "GT")
      return rewriter.notifyMatchFailure(select_op, "Compare direction isn't GT");

    // Build the new op from the inputs of the multiply and the compare.
    rewriter.replaceOpWithNewOp<mlir::NewOp>(select_op,
      select_op.getResult().getType(),
      mulOp.lhs(),
      mulOp.rhs(),
      compareOp.rhs());
    return success();
  }
};