[dialect conversion] How to change the return type of `func.func`?

I want to convert

module {
func.func @const_shape() -> tensor<3xindex> {
  %shape = shape.const_shape [1, 2, 3] : tensor<3xindex>
  return %shape : tensor<3xindex>
}
}

into

module {
func.func @const_shape() -> i32 {
  %0 = "arith.constant"() <{value = 1 : i32}> : () -> i32
  %1 = "arith.constant"() <{value = 2 : i32}> : () -> i32
  %2 = "arith.addi"(%0, %1) : (i32, i32) -> i32
  return %2 : i32
}
}

It's meaningless, but simple enough to illustrate the conversion process. I wrote a conversion pattern that creates two arith.constant ops and replaces shape.const_shape with an arith.addi op. But that alone is not enough, because of the mismatch with the function's return type. What should I do to change the return type?
Here is my conversion pattern:

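#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Transforms/DialectConversion.h"

namespace {
using namespace mlir;

// Rewrites shape.const_shape into arith.addi(arith.constant 1, arith.constant 2).
class MyOpConversion : public OpConversionPattern<shape::ConstShapeOp> {
public:
  using OpConversionPattern<shape::ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value op1 = rewriter.create<arith::ConstantOp>(op.getLoc(), rewriter.getI32IntegerAttr(1));
    Value op2 = rewriter.create<arith::ConstantOp>(op.getLoc(), rewriter.getI32IntegerAttr(2));
    // Replace the shape op with the i32 sum of the two constants.
    auto res = rewriter.replaceOpWithNewOp<arith::AddIOp>(op, op1, op2);
    // Also redirect the remaining uses of `op` (here: func.return) to the sum.
    rewriter.replaceAllUsesWith(op, res);
    // Debug: print the enclosing block (the `^bb0:` dump in the log below).
    op->getBlock()->dump();
    return success();
  }
};
} // namespace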

Here is my error log


  * Pattern : 'shape.const_shape -> ()' {
Trying to match "(anonymous namespace)::MyOpConversion"
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::detail::ConstantOpGenericAdaptorBase::Properties)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface::Trait<Empty>)
    ** Insert  : 'arith.constant'(0x55b9d1c910d0)
    ** Insert  : 'arith.constant'(0x55b9d1c91140)
    ** Insert  : 'arith.addi'(0x55b9d1c89760)
    ** Replace : 'shape.const_shape'(0x55b9d1c789e0)
mlir-asm-printer: Verifying operation: builtin.module
type of return operand 0 ('i32') doesn't match function result type ('tensor<3xindex>') in function @const_shape
mlir-asm-printer: 'builtin.module' failed to verify and will be printed in generic form
^bb0:
  %0 = "arith.constant"() <{value = 1 : i32}> : () -> i32
  %1 = "arith.constant"() <{value = 2 : i32}> : () -> i32
  %2 = "arith.addi"(%0, %1) : (i32, i32) -> i32
  %3 = "shape.const_shape"() <{shape = dense<[1, 2, 3]> : tensor<3xindex>}> : () -> tensor<3xindex>
  "func.return"(%2) : (i32) -> ()
"(anonymous namespace)::MyOpConversion" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'func.return'(0x55b9d1c8bff0) {
      "func.return"(%2) : (i32) -> ()

      * Fold {
xxx
      } -> FAILURE : unable to fold
    } -> FAILURE : no matched legalization pattern
    //===-------------------------------------------===//
  } -> FAILURE : failed to legalize operation updated in-place 'func.return'
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x55b9d1c8bff0) {
  "func.return"(%0) : (tensor<3xindex>) -> ()

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

In this scenario, it seems more appropriate to write a pass to address this issue, rather than using conversion.

During dialect conversion, we need to preserve the legality of the IR. Generally, we do not change an op's operand or result types, except through a type converter or in a few other specific scenarios. Otherwise, other ops (such as func.return in your log) may fail verification.

In your case, you need to change the function's result type from tensor<3xindex> to i32, which is clearly not something a type converter can model. Therefore, my suggestion is to add a pass that modifies the function signature. If you also use dialect conversion to rewrite other ops, keep the function's return type consistent during that process so that verification passes. You can use builtin.unrealized_conversion_cast or similar ops to maintain compatibility.

Thanks @cxy. I've figured out a solution, and it's exactly what you advised. I'd like to add more detail on how to achieve it.
First, write a conversion pattern on func.FuncOp that changes its signature. You need to update the FunctionType of func.FuncOp and the input/output arguments of its block so that they stay consistent.
Then, use addDynamicallyLegalOp on func.FuncOp so that a FuncOp is legal only if its func.return operands match its result types. Anything that breaks this constraint will then trigger the conversion of the FuncOp.
Finally, put these patterns together in one pass and watch the magic happen :slight_smile:

I’m not sure I follow: dialect conversion is organized around the TypeConverter which is meant to change the types.
The usual approach is to ensure that a TypeConverter is used consistently for every type change during the conversion.

Why not just set up the type converter to map tensor<3xindex> to i32?

Sorry for not making myself clear. Setting up the type converter to map tensor<3xindex> to i32 does correctly rewrite the function signature. If every tensor<3xindex> should be converted to i32, a type converter is the best solution in this case. :slightly_smiling_face:

In this specific example, the change to the function signature seems to follow a convention rather than a type conversion. First, other conversions do not guarantee that return types stay consistent. Second, converting tensor<3xindex> into i32 does not appear to be a uniform, one-to-one conversion across the whole program. For such convention-based transformations, a dedicated pass is recommended over dialect conversion. A typical similar case is createBufferResultsToOutParamsPass, which turns a function's memref results into output parameters.
