[PSA] Verification of SameOperandsAndResultType is being tightened

Note: If you use SameOperandsAndResultType and expect that it means “pointer-identical types”, you don’t really need to keep reading. This shouldn’t affect you.

For those who don’t know, we currently use and treat the trait SameOperandsAndResultType differently than it is actually implemented. We have many uses (both in ODS and in all of our upstream operations) that assume that SameOperandsAndResultType means the types are pointer-identical, i.e. literally the “same” type. The trait implementation, however, allows shaped types to be “compatible”, meaning that they can differ in some cases (e.g. tensor<1xf32> is “compatible” with tensor<*xf32>). This makes sense if your IR is a tensor/shape-based IR that allows shape mismatches in certain cases, but it is quite problematic for a general-purpose trait.
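To make “compatible” concrete, here is a small standalone illustration (not part of the patch, just assuming the public MLIR C++ API) using the same verifyCompatibleShape helper that the trait’s current verifier calls:

  #include <cassert>

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/MLIRContext.h"
  #include "mlir/IR/TypeUtilities.h"
  #include "mlir/Support/LogicalResult.h"

  using namespace mlir;

  int main() {
    MLIRContext ctx;
    auto f32 = Float32Type::get(&ctx);
    Type ranked = RankedTensorType::get({1}, f32); // tensor<1xf32>
    Type unranked = UnrankedTensorType::get(f32);  // tensor<*xf32>
    // Not the "same" (pointer-identical) type...
    assert(ranked != unranked);
    // ...but "compatible" under the shape rules the trait currently applies.
    assert(succeeded(verifyCompatibleShape(ranked, unranked)));
  }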

Given our mixed interpretation/implementation of the trait, an operation using it in certain ways could:

  • adopt invariants it doesn’t expect (every upstream op essentially, and likely many downstream)

  • lose information

    • The assembly format assumes that “same” means “pointer-identical”, so an assembly format of
      operands : type($result) could actually be incorrect.
  • Crash/break/rm *.*/etc.

    • Given some things treat the trait differently, an operation could encounter situations that it doesn’t expect. I’m not saying things right now would crash/break/explode, but such differences between an API’s expectations and its implementation usually invite those things.

I sent out a commit that rectifies this by tightening the verification to require pointer-identical types. If users actually rely on the looser behavior, they can get close to it by using a combination of InferTypeOpInterface + SameOperandsAndResultShape + SameOperandsAndResultElementType.
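For reference, the tightened check is conceptually just exact type equality. A minimal sketch (not the actual commit, just assuming the usual Operation API; types in MLIR are uniqued, so == on mlir::Type is pointer identity):

  static LogicalResult verifyStrictSameOperandsAndResultType(Operation *op) {
    if (op->getNumOperands() == 0 && op->getNumResults() == 0)
      return success();
    // Pick an arbitrary reference type; every other operand/result type must
    // be exactly equal to it, i.e. literally the "same" type.
    Type refType = op->getNumResults() ? op->getResult(0).getType()
                                       : op->getOperand(0).getType();
    for (Type type : op->getOperandTypes())
      if (type != refType)
        return op->emitOpError(
            "requires the same type for all operands and results");
    for (Type type : op->getResultTypes())
      if (type != refType)
        return op->emitOpError(
            "requires the same type for all operands and results");
    return success();
  }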

– River

As you noticed, we have many users of SameOperandsAndResultType who really want CompatibleOperandsAndResultShape + SameOperandsAndResultElementType + InferTypeOpInterface. I think we should have a trait in MLIR that actually provides this; in particular, the “result type inference” part isn’t necessarily obvious to implement, and it would be nice to have consistent behavior across clients here. That is assuming we find a reasonable default for result type inference: I’ve been struggling a bit once you include in the mix the possibility of tensor encodings (for sparse) and quantized element types.

I assume you are referring to TensorFlow + friends (the original cause of this) here? Given the mixed interpretation and semantics surrounding the current SameOperandsAndResultType, I’ll have to defer to you on what is actually desired from your point of view. Happy to have a trait for this, but right now the full desired semantics aren’t clear to me.

– River

1 Like

Right: this is TensorFlow but also XLA and other internal ML projects.
I think we should just try to take a stance in MLIR with respect to “tensor programs” and type inference in the context of things like encoding and quantization. @stellaraccident may have good insights on this as well.

Do we have any upstream ops that actually want the old behavior of this trait that we could use to guide a batteries-included replacement upstream? (TOSA?)

Speaking for Torch-MLIR, I haven’t found such a trait very useful. It seems to be generally subsumed by

  1. A more complete shape refinement system (In a sense, what we are discussing here is a very limited / partial specification of a shape transfer function)
  2. An “AllowsTypeRefinement” trait which allows us to know when we can safely update the type of an op in place (e.g. during shape refinement) or whether we need to cast back to the original type.
  3. There aren’t actually very many ops that have this behavior (essentially only elementwise and maybe some broadcast-like ops?). And the check the trait does is quite weak – it only checks that the first result is compatible with the other results and operands. Since “is compatible” is not transitive, that means it would allow things that are probably not intended, such as
my_elementwise_add %0, %1 : (tensor<7xf32>, tensor<4xf32>) -> tensor<?xf32>

This would fail the trait’s check, I believe; maybe you meant something like the following instead?

my_elementwise_add %0, %1 : (tensor<?xf32>, tensor<4xf32>) -> tensor<7xf32>

The first operand type would be picked and compared to the second operand type and to the result. It is compatible with both, even though the second operand type is incompatible with the result.

The “reference” type is the first result:

  // The "reference" type is the first result; every other result and operand
  // is checked for the same element type and a *compatible* shape against it,
  // but not against each other – which is where the non-transitivity bites.
  auto type = op->getResult(0).getType();
  auto elementType = getElementTypeOrSelf(type);
  for (auto resultType : llvm::drop_begin(op->getResultTypes())) {
    if (getElementTypeOrSelf(resultType) != elementType ||
        failed(verifyCompatibleShape(resultType, type)))
      return op->emitOpError()
             << "requires the same type for all operands and results";
  }
  for (auto opType : op->getOperandTypes()) {
    if (getElementTypeOrSelf(opType) != elementType ||
        failed(verifyCompatibleShape(opType, type)))
      return op->emitOpError()
             << "requires the same type for all operands and results";
  }
  return success();

But yeah, the big picture is that because compatibility isn’t transitive, this trait ends up checking something unexpected.

:man_facepalming: … I thought it was the first argument, your example was the right one!

Just for background, we want to allow things like

 %0 = some-op(%arg0, %arg1) : (tensor<10x20xf32, #CSR>,
                               tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32, #CSR>

under some sort of trait that says SameOperandsAndResultTypeButIgnoringSparsityEncoding. Such operations would basically tell the compiler to apply an operation to sparse tensors, possibly in different formats, and generate the resulting sparse tensor in an again possibly different format, but under the same type rules that would apply in the dense case.
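If such a trait existed (it doesn’t upstream; the name above is hypothetical), one plausible way to implement its verifier would be to strip the encoding before comparing, along these lines:

  // Hypothetical helper: drop the tensor encoding so that only shape and
  // element type participate in the "same type" comparison.
  static Type dropTensorEncoding(Type type) {
    if (auto ranked = type.dyn_cast<RankedTensorType>())
      return RankedTensorType::get(ranked.getShape(), ranked.getElementType());
    return type;
  }

and then run the usual same-type check over the stripped operand and result types.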

  1. Re: expressive power. Similarly, for quantization, we want to allow things like:
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<10x20x!quant.uniform<i8:f32, 1.0:15>,
                                 tensor<10x20x!quant.uniform<i8:f32, 2.0:16>) ->
                                 tensor<10x20x!quant.uniform<i8:f32, 3.0:17>,  

So it would be good to have a solution that generalizes to multiple similar features like quantization and sparsity (and, depending on the outcome of the discussion about dynamism, to dynamism as well).

In MHLO, we’re thinking about replacing [SameOperandsAndResultType] with something like [HLO_CompatibleOperandsAndResultShape] that would compare operand and result types with each other using compatibility rules that account for dynamism, quantization and sparsity.

  2. Re: type inference. As Mehdi mentioned, quantization and sparsity have tricky interactions with type inference, given that e.g. dense tensors can be added together and produce a sparse tensor. Similarly, quantized tensors can be added together and produce a non-quantized tensor.

I think it follows that implementing InferTypeOpInterface for ops like that is out of the question, because inferReturnTypes can sometimes have several (or, if quantization is involved, a ton of) potential results.

However, I think that InferShapedTypeOpInterface still makes sense for these ops, given that neither quantization nor sparsity affects shapes, and that inferReturnTypeComponents says “Unknown (e.g., unranked) shape and nullptrs for element type and attribute may be returned by this function while returning success”.
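A hedged sketch of what that could look like (the op name is made up, this is not existing MHLO code, and the signature is only approximately the one the interface exposes):

  LogicalResult MyAddOp::inferReturnTypeComponents(
      MLIRContext *context, Optional<Location> location,
      ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
    // Infer only the shape; the element type (and with it quantization and
    // encoding) is deliberately left unspecified, which the interface permits.
    ShapeAdaptor shape = operands.getShape(0);
    if (!shape.hasRank()) {
      // Unknown (unranked) shape with a null element type is a valid answer.
      inferredReturnShapes.emplace_back();
      return success();
    }
    SmallVector<int64_t> dims;
    shape.getDims(dims);
    inferredReturnShapes.emplace_back(dims); // Element type stays nullptr.
    return success();
  }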

1 Like

This seems like a gap in the implementation rather than something fundamental: this should be a meet performed on all the types. Where the meet fails to produce a valid result, verification should fail. One could then allow introducing meet configuration too (which fits in with the type “hierarchy” below), but that’s a separate issue.
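For the shape part, such a meet could look roughly like this sketch (helper name made up; element types and encodings deliberately left out, which is where the configuration question comes in):

  // Meet of shapes: dynamic dims get refined by static dims; two different
  // static dims, or a rank mismatch, mean there is no valid meet and
  // verification should fail.
  static FailureOr<SmallVector<int64_t>> meetShapes(TypeRange types) {
    SmallVector<int64_t> meet;
    bool seenRanked = false;
    for (Type type : types) {
      auto shaped = type.dyn_cast<ShapedType>();
      if (!shaped || !shaped.hasRank())
        continue; // Unranked types constrain nothing.
      ArrayRef<int64_t> shape = shaped.getShape();
      if (!seenRanked) {
        meet.assign(shape.begin(), shape.end());
        seenRanked = true;
        continue;
      }
      if (meet.size() != shape.size())
        return failure();
      for (unsigned i = 0, e = meet.size(); i != e; ++i) {
        if (meet[i] == ShapedType::kDynamicSize)
          meet[i] = shape[i];
        else if (shape[i] != ShapedType::kDynamicSize && shape[i] != meet[i])
          return failure();
      }
    }
    return meet;
  }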

These are quite different. For encodings, the expectations around them are a bit fuzzy. They are currently in a state similar to unregistered attributes in TF, where what they are allowed to represent, and whether they are load-bearing or can be skipped, is still somewhat in flux. So currently it is unclear whether it is safe to skip checking the tensor encoding or not; one could slap in a list (in some way) of encodings which are valid to skip while others are load-bearing [although perhaps this shows we need iteration on the design].

Quantization is actually a different element type for many of these, so there one would most likely use a same-shape trait and be done, as there is no real equality of the types.

Of course both could be represented by changing to a compatibility trait. And one could even introduce type hierarchies and a refinement lattice. That could be generally useful.

InferTypeOpInterface/InferShapedTypeOpInterface also allows querying this: if the return types of all consumers of an op can be inferred, then the result type of the op being considered may be changed, since shape inference/type inference would be able to propagate it. Or do you mean the more general frontend case, where consumers often don’t care in their checks whether the shape is exact, so it can be refined without propagation (e.g., in TF, function calls can accept args with less or more refined shapes)?

The issue with quantization and sparsity is that the result isn’t uniquely determined: there is no inference that can produce it. Hence the compatibility function on the interface (which is what we use with subtype inference in TF too, where we hit something similar). Return type inference returns the most exact type that it can given the inputs & op attributes & regions; that is used when the user doesn’t specify a result type, and is checked to be compatible using the user-provided function. Quantization and sparsity would need to be attributes of the op to enable inference, but currently they are captured in the type by some process which has additional info/heuristics, with nothing in the op itself or uniquely deducible from the op’s input-output relationships. That means no inference can work for these, as they actually represent additional constraints not captured in the op as defined. So if you really wanted them to be inferred, you’d have to expand the op definition. But you could still use this op interface; it would still work for the general case, enable simple builders, do verification, etc. I think the result inferred today is actually one of the valid result types; it may not be what the user wants, but it is a member of that set.

1 Like

Oh and yes, this should really be changed to an OpTraitList; I didn’t get to it yet, but that was part of the reason why it was moved.

1 Like

Forgot to mention: this is also how layouts work in XLA. Layout assignment & propagation is a pass; layout is part of the shape there (which is more akin to the ranked tensor type here), and XlaBuilder doesn’t try to use layout (https://github.com/tensorflow/tensorflow/blob/bd98b8be634287f0074472cce87d892c0f9c715c/tensorflow/compiler/xla/client/xla_builder.cc#L700), similar to how the build method generated using the infer return type op interface shouldn’t. There, there are multiple variants of Equal (https://github.com/tensorflow/tensorflow/blob/bd98b8be634287f0074472cce87d892c0f9c715c/tensorflow/compiler/xla/shape.cc#L155), Compatible (https://github.com/tensorflow/tensorflow/blob/bd98b8be634287f0074472cce87d892c0f9c715c/tensorflow/compiler/xla/shape_util.cc#L722) and Same’ness.

It is fair that Same and Equal are overloaded and many would read them as exactly equal. And we should revise old code that is confusing :slight_smile: The goal with moving this and (in future) changing it to an OpTraitList is to make it less special and less hardcoded. This doesn’t need to be as privileged.

1 Like

Right: this is what I meant to refer to earlier with “with respect to ‘tensor programs’ and type inference in the context of things like encoding and quantization”. Right now I don’t think we have implemented this “returns the most exact type that it can” aspect anywhere.

I’m not sure I have a strong opinion on how to generalize this. This always felt to me like one of those things that came from the “early days” when MLIR was an experimental mapping of TensorFlow, et al. If we were doing this today, we would likely be conservative and:

  • Push this to the dialects if it was primarily a vehicle for verification and/or tensor op building.
  • Bias against a core trait unless it is also useful for systematic transformations in some way.
  • Look at interfaces, which were not a thing when most of these domain specific traits were made.

I tend to agree with Sean – I’m not aware of a lot of utility coming from this somewhat hard-to-express definition when it comes to actual lowerings.

Since this feels like a “mid-level frontend thing”, I would look at TOSA, ATen, and XLA to see if there is a benefit to something common. Totally supportive of the weird in the downstreams and pushing bounds there, but in general, I think complicated definitions probably live closer to the using project, and we offer a lot of tools to express them.

This is a pretty weak opinion/judgment call. I think I agree with River’s description at the top.

3 Likes