Allow shape concretization (or type concretization) in rewrites

When folding or canonicalizing, there is often an opportunity to concretize a type. For example, the extent tensors in the shape dialect are usually dynamically sized (tensor<?xindex>), but when folding and canonicalizing them, the rank of the shape may eventually become known, leading to a statically shaped extent tensor.

In the following example, the shape is constant and could be folded to const_shape [1, 2, 3] : tensor<3xindex>; however, the subsequent operation does not accept such a type concretization.

%0 = ... : tensor<1x2x3xf32>
%1 = shape.shape_of %0 : tensor<1x2x3xf32> -> tensor<?xindex>
%2 = memref.buffer_cast %1 : memref<?xindex>
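
For concreteness, the folded form would look roughly like this (a sketch assuming the usual shape.const_shape assembly): the producer now yields tensor<3xindex>, while the memref.buffer_cast consumer was built against the original tensor<?xindex> operand type.

%0 = ... : tensor<1x2x3xf32>
%1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
// the consumer does not accept the concretized operand type
%2 = memref.buffer_cast %1 : memref<?xindex>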

In the offline discussion with @herhut, we found at least three ways to deal with this problem:

  1. Generally allow more concrete types, i.e., every op that accepts tensor<?xindex> operands must also accept tensor<3xindex> operands. This requires a general consensus on concreteness, which is quite obvious for the shape lattice but may be controversial elsewhere.

  2. Enforce strict type equality in rewrites and make it the rewrite patterns’ obligation to insert casts where needed (see the sketch after this list). This can make pattern matching quite fiddly with all the explicit casts in the IR, and in the end it requires many canonicalization patterns whose sole purpose is to eliminate redundant casts where ops actually accept more concrete types.

  3. Allow rewrite patterns to change types and have a mechanism that inserts casts for consumer ops that do not accept the new type. This would require op verification that supports operand substitution, so that we do not have to create the new op just to check whether it would be valid. This is essentially case (1) with a fallback mechanism.
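
To make option 2 concrete, here is a rough sketch (not taken from the discussion above) of what the folding pattern would have to emit: it materializes a tensor.cast back to the original tensor<?xindex> type so that consumers see an unchanged operand type, and a later canonicalization can remove the cast wherever a consumer does accept the more concrete type.

%1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
// cast back to the less concrete type expected by the existing consumer
%1_dyn = tensor.cast %1 : tensor<3xindex> to tensor<?xindex>
%2 = memref.buffer_cast %1_dyn : memref<?xindex>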

I don’t think that 1. is possible. Moving away from pointer-equality for types here seems like it would have lots of ramifications. E.g. std.return would have to know about it. Furthermore, IREE has ops where the number of operands depends on the tensor type. E.g., for this op, an operand of type tensor<3x?x4xf32> implies (and the verifier checks) that there is one additional index argument giving the size of the ?. If you change the type to tensor<3x7x4xf32>, that op becomes invalid and crashes the compiler.
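
To illustrate that concern with a made-up op (the actual IREE op is the one referred to above; “iree_like.op” and its operand convention here are purely hypothetical):

// one extra index argument is expected per dynamic dimension of the tensor type
%r = "iree_like.op"(%t, %d1) : (tensor<3x?x4xf32>, index) -> tensor<3x?x4xf32>
// refining the operand type to tensor<3x7x4xf32> without also dropping %d1
// leaves no dynamic dimension for %d1 to describe, so the op no longer verifies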

Number 2. is effectively what we do today. I haven’t found it very problematic; maybe a few more motivating cases would be useful to get a feeling for the issues here.

Number 3. is interesting. To move towards modeling this better and making that possible, one design is to introduce two things:

  1. A trait, AllowsTypeRefinement, that indicates that an op allows its operand and result types to be refined.
  2. A type interface, Refinable, which provides a mechanism for
    a. checking whether a type is a refinement of a given type, and
    b. inserting a potentially type-specific “derefine cast” that casts the refined type back to the original type (we might want a builtin op for this to handle the common cases, like we have for unrealized_conversion_cast).

We could then move towards core infra better supporting the various use cases, such as by teaching the pattern rewriter to automatically insert derefinement casts unless all users of a value have AllowsTypeRefinement.
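
As a hypothetical illustration of that mechanism (neither the trait nor the automatic cast insertion exists yet, and “refinable.consumer” is a made-up op): after a rewrite refines %1 to tensor<3xindex>, a user carrying AllowsTypeRefinement keeps using the refined value directly, while a user without the trait gets a derefining cast inserted in front of it.

%1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
// user with AllowsTypeRefinement: accepts the refined type directly
%2 = "refinable.consumer"(%1) : (tensor<3xindex>) -> index
// user without the trait: the rewriter inserts a derefining cast
%3 = tensor.cast %1 : tensor<3xindex> to tensor<?xindex>
%4 = memref.buffer_cast %3 : memref<?xindex>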

Speaking for npcomp, we could use this in two places:

A. TorchScript’s IR allows a more refined type anywhere a less refined type is expected. Today, we model that with special “torch.derefine” ops that derefine values to produce pointer-equal types. (torch.derefine op, test showing situations where it is needed)

B. In our RefineTypes pass (which mainly does shape inference on tensors), we have a hardcoded check for whether a value allows its type to be refined (e.g. aten.mul allows tensor<?xf32> to be refined to tensor<4xf32>); if that is not possible, we need to insert casts. See the hardcoded function and especially the TODO inside that function laying out the concerns.

This (removing redundant casts for inferable ops) is mostly what we currently do in the shape inference pass on the TF side.

Indeed, the same holds on the TF side: we check the dialect, whether the op implements the type inference trait, or whether it uses a trait (such as same element and result type) from which the result type can be inferred; i.e., we allow changing the type if the consumer can be refined, and otherwise add a cast.

An issue is that even knowing which case is needed requires a query in general (currently, on the TF side, this is based purely on the types, but that assumes we know which types are interacting, which may not be true in general).

I think this goes back to the above: if an op implements the type inference trait, then one can refine it, since I think one also needs to be able to refine in that case. There is a question of an interface asking whether an op allows “compatible” types (like we have for return types, but for input types too), so that we don’t end up with invalid IR after just changing an operand type.