Issues when playing with AnyTypeOf in Op definition

I have defined a custom op like this:

def MyExampleOp : FooOp<"dummy"> {
  let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
  let results = (outs AnyTypeOf<[F64Tensor, F64MemRef]>:$output);
  ...
}

In my input MLIR file, I am using the op like this:

%0 = foo.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = foo.dummy(%0 : tensor<2x3xf64>) to tensor<3x2xf64>

I’ve written a lowering pass that converts this MLIR into the affine, memref, and std dialects (similar to the Toy example). I want to keep the foo.dummy op untouched after this first lowering, instead of lowering it (because I want to do that in a second lowering). As expected, foo.constant gets lowered to the memref dialect, so foo.dummy has to change its input type from tensor to memref. That’s fine, and it is why I used AnyTypeOf<[F64Tensor, F64MemRef]> in the arguments. However, the output type is still a tensor (I would like it to be a memref too!). In other words, after lowering I get:

...
%0 = memref.alloc() : memref<2x3xf64>
...
%1 = foo.dummy(%0 : memref<2x3xf64>) to tensor<3x2xf64>

but I was expecting to get:

...
%0 = memref.alloc() : memref<2x3xf64>
...
%1 = foo.dummy(%0 : memref<2x3xf64>) to memref<3x2xf64>

How can I achieve that? I also set the results with AnyTypeOf to let MLIR know that memref is valid there too, but it always chooses tensor. I tried changing the order of the types in AnyTypeOf, but got the same result.

Today this would require two parts: 1) implementing the type inference trait for the op, and 2) running a type refinement pass afterwards. We don’t yet have a generic type refinement pass upstream (probably 4 weeks out, given other constraints). The latter is required because one would need to propagate the changed type further. (If the op has a constraint that the input type matches the output type, then that simple case is generated today, so 1 could be handled by adding that constraint.)
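For part 1, here is a minimal sketch of what the type inference hook could look like, assuming MyExampleOp declares InferTypeOpInterface in its TableGen definition; the exact hook signature varies across MLIR versions, and the shape logic is purely illustrative (operand element type kept, 2-D shape transposed, matching the 2x3 -> 3x2 example above):

// Sketch only: assumes MyExampleOp implements InferTypeOpInterface.
LogicalResult MyExampleOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Type input = operands.front().getType();
  // Keep the operand's kind (memref or tensor) and element type, and
  // transpose the 2-D shape, as in the foo.dummy example.
  if (auto memref = input.dyn_cast<MemRefType>()) {
    inferredReturnTypes.push_back(MemRefType::get(
        {memref.getShape()[1], memref.getShape()[0]},
        memref.getElementType()));
    return success();
  }
  if (auto tensor = input.dyn_cast<RankedTensorType>()) {
    inferredReturnTypes.push_back(RankedTensorType::get(
        {tensor.getShape()[1], tensor.getShape()[0]},
        tensor.getElementType()));
    return success();
  }
  return failure();
}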

A short-term solution is to manually refine the types during lowering, but you might hit this a few times.

Right now there’s nothing telling the system that it should modify the return type of your example op. It sounds like Jacques has some forward-looking generic type refinement options, but for now I think you can do this with dialect conversion. If you’re expecting there to be no tensors left, you can make op legality dynamic based on a TypeConverter (Dialect Conversion - MLIR), something like this one in SCF (just grabbed a random example). Then you can add a conversion pattern that matches your example op. Something like (untested):

class ConvertExampleOpTypes : public OpConversionPattern<ExampleOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ExampleOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Ask the TypeConverter for the converted result type
    // (e.g. tensor<3x2xf64> -> memref<3x2xf64>).
    Type newType = typeConverter->convertType(op.getType());
    if (!newType)
      return failure();
    // Rebuild the op with the converted result type and the
    // already-converted operands.
    rewriter.replaceOpWithNewOp<ExampleOp>(op, newType, operands);
    return success();
  }
};
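For completeness, a hedged sketch of how the surrounding conversion setup might be wired up (the function name is hypothetical, the tensor-to-memref mapping is just one possible choice, and older MLIR versions spell RewritePatternSet as OwningRewritePatternList):

// Sketch of the driver side; names here are illustrative.
void convertExampleOpTypes(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  TypeConverter typeConverter;
  // Leave types we don't care about unchanged.
  typeConverter.addConversion([](Type type) { return type; });
  // Map ranked tensors to memrefs with the same shape and element type.
  typeConverter.addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });

  ConversionTarget target(*ctx);
  // ExampleOp is legal only once all of its operand/result types are
  // legal according to the converter (i.e. no tensors remain).
  target.addDynamicallyLegalOp<ExampleOp>(
      [&](ExampleOp op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  patterns.add<ConvertExampleOpTypes>(typeConverter, ctx);
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    module.emitError("example op type conversion failed");
}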

Thank you very much for your suggestions! I really appreciate your comment, @gcmn. I’m an MLIR newbie, but your instructions were really clear and your suggestion works like a charm.

One additional question: sometimes I’m a bit lost in MLIR because I don’t know where certain functions come from. For example, op.getType() looks like a pretty interesting function. I guess that ExampleOp is an instance of Op (MLIR: mlir::Op< ConcreteType, Traits > Class Template Reference), but in the docs I can’t see any reference to the getType() method. Why? Where can I find the documentation for this method?

getType() is usually exposed on Value (the individual SSA values) and not on operations. It is always possible for a specific operation class to add methods and other accessors, sometimes inherited from traits.

In this particular case, you get it because the op has one result, so you implicitly get the OneResult and OneTypedResult traits.
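To make that concrete, here is a small illustration (using ExampleOp from above): for a single-result op, all three of these return the same Type.

// Illustration only: the trait-provided getType() is shorthand for asking
// the single result Value for its type.
static void inspectResultType(ExampleOp op) {
  Type t1 = op.getType();                              // from OneTypedResult
  Type t2 = op.getResult().getType();                  // via the result Value
  Type t3 = op.getOperation()->getResult(0).getType(); // generic Operation API
  assert(t1 == t2 && t2 == t3);
}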

As to the general question, I rely on IDE autocompletion and on jumping to the definitions of things, but I still spend a lot of time manually tracing the class hierarchy in MLIR. There’s quite a lot of cleverness in there (e.g. CRTP to avoid virtual methods, ops being just smart pointers) that has important implications for performance, but I think it harms discoverability quite a bit. I’d love to see a tool that could flatten out the entire set of methods available on a class so I could see them all at once, as well as click through to where they were defined. The worst offender in this is OpBuilder::create, I think, which uses std::forward and completely defeats any IDE I’ve used. You basically have to know that it’s calling OpTy::build under the hood, go look at the definitions of build, and then figure out the right calling pattern (and then let C++ error messages push you in the right direction when you screw up). I have spent like half an hour writing a single call to create before.
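For anyone hitting the same wall, a hedged sketch of that relationship (ExampleOp and the build overload shown are illustrative; the real overloads are generated by TableGen from the op definition):

// OpBuilder::create<OpTy>(loc, args...) std::forwards args... to
// OpTy::build(builder, state, args...), so a create call only compiles if
// a matching generated build overload exists, e.g.:
//   static void build(OpBuilder &builder, OperationState &state,
//                     Type resultType, Value input);
static ExampleOp makeExample(OpBuilder &builder, Location loc,
                             Type resultType, Value input) {
  return builder.create<ExampleOp>(loc, resultType, input);
}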

Yes! I spent a lot of time a few days ago figuring out that OpTy::build thing. I use gdb to understand how the execution flow works. Sometimes it’s a bit limited, but…

If even you still have trouble tracing the class hierarchy, then I feel a bit better about my own struggles with it. Thanks for your help!