RFC: Allow shape dynamicism in SCF

While tensor types support dynamic shapes (dynamic dimensions and unranked), the interfaces defined in ControlFlowInterfaces.* (used for example by SCF) impose type equality constraints along control-flow edges.

So for example,

func @test(%condition: i1,
           %then_value : tensor<?xi32>,
           %else_value : tensor<?xi32>) {
  %res = scf.if %condition -> (tensor<?xi32>) {
    scf.yield %then_value : tensor<?xi32>
  } else {
    scf.yield %else_value : tensor<?xi32>
  }
  return
}

is legal. But

func @test(%condition: i1,
           %then_value : tensor<1xi32>,
           %else_value : tensor<2xi32>) {
  %res = scf.if %condition -> (tensor<?xi32>) {
    scf.yield %then_value : tensor<1xi32>
  } else {
    scf.yield %else_value : tensor<2xi32>
  }
  return
}

is not.

This prevents dialects that want to allow dynamic shapes to reuse those interfaces.

There is a question here as to the meaning of dynamic dimensions and unranked tensors. Do dynamic dimensions and unranked shapes mean “fixed but unknown”, or can they be dynamic at runtime? My understanding is that MLIR does not want to prevent runtime dynamicism of the shapes.

So the proposal is to update the type constraints in control-flow interfaces to allow such cases.

From local work, this should be achievable with very little code (but a significant number of tests!).
In particular, we would define join and meet functions for builtin types (with almost all of the logic dealing with shapes). Once these are available, the constraints in ControlFlowInterfaces.* become simple checks with join and meet.

Does that sound reasonable?
If so, I’ll clean up the code locally and upload things for review.

Thanks for the RFC/interest.

In general, the yielded types must match the declared return type of the outer if operation, and part of the process of lowering into something like SCF (which is a relatively low level dialect and has invariants amenable to code generation) must ensure this either directly or via some kind of shape refinement/casting mechanism.

In related projects that are lowering from unconstrained frontend systems, this is accomplished by adding casts which interop with type refinement. See for example the numpy.static_info_cast op. Something algorithmically needs to decide how much to refine or de-refine such decisions. I’ve also done this before by propagating special unknown types during program specification and then using a type refinement algorithm to fully type everything. For code generation, it is typically important to be precise and consistent on the data flow for how dynamic tensors are.

It’d be good to hear from others wrt whether I am viewing this as too constrained. Also, it is completely valid to define higher level control flow operations which have whatever constraints are appropriate for your problem area.

We don’t want to prevent “dynamicism” indeed, but I see the question differently here: it is about how implicit vs explicit we should be in the IR about the type.
Right now the core dialects are leaning toward being explicit about type mismatch; for example, you would need to write:

func @test(%condition: i1,
           %then_value : tensor<1xi32>,
           %else_value : tensor<2xi32>) {
  %res = scf.if %condition -> (tensor<?xi32>) {
    %dyn_then_value = tensor.cast %then_value : tensor<1xi32> to tensor<?xi32>
    scf.yield %dyn_then_value : tensor<?xi32>
  } else {
    %dyn_else_value = tensor.cast %else_value : tensor<2xi32> to tensor<?xi32>
    scf.yield %dyn_else_value : tensor<?xi32>
  }
  return
}

There is no loss of expressiveness (as far as I know), but you have to handle such “reinterpret_cast” explicitly.
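
To make the "add explicit casts while lowering" idea concrete, here is a minimal sketch of the decision a lowering would have to make per control-flow edge. It is a toy model in plain C++, not MLIR API: `Dims`, `kDynamic`, `EdgeAction` and `classifyEdge` are hypothetical names, only ranked shapes are modelled, and element types are assumed to already match.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model of a ranked tensor shape; kDynamic stands for a `?` dimension.
// Illustration only: none of these names come from MLIR.
constexpr int64_t kDynamic = -1;
using Dims = std::vector<int64_t>;

enum class EdgeAction { None, InsertCast, Unsupported };

// Decides what is needed so that a value of shape `yielded` can flow into a
// result declared with shape `declared` under exact type equality.
EdgeAction classifyEdge(const Dims &yielded, const Dims &declared) {
  if (yielded == declared)
    return EdgeAction::None; // types already equal, nothing to do
  if (yielded.size() != declared.size())
    return EdgeAction::Unsupported; // rank mismatch is out of scope here
  for (std::size_t i = 0; i < yielded.size(); ++i) {
    // A static declared dimension must match exactly; `?` accepts any size.
    if (declared[i] != kDynamic && declared[i] != yielded[i])
      return EdgeAction::Unsupported;
  }
  // The declared shape only forgets static sizes, so a cast is enough.
  return EdgeAction::InsertCast;
}
```

For the example above, both `tensor<1xi32>` and `tensor<2xi32>` flowing into `tensor<?xi32>` are classified as `InsertCast`, which corresponds to the two `tensor.cast` ops. The sketch deliberately refuses to refine (going from `?` to a static size), since that is a separate decision.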

One more thing I should add: the upstream dialects (including scf) are very explicit, and nothing prevents lowering a higher-level implicit representation to them by adding the explicit constructs as needed. It is true, though, that the interface implementation could limit external dialects that want more implicit conversions in their control flow, and that would be unfortunate.

Right now I’m not sure where this assumption is baked in, but I can see transformations not being ready to handle this correctly, because they would potentially need to insert casts when code motion changes the type.
Right now the perfect type equality is checked here: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Interfaces/ControlFlowInterfaces.cpp#L127-L132 ; but this has to be called manually I think, there is no guarantee that every operation that implements the interface would do it (unless I missed a code path?).

Processing your replies.
The check happens via verify for RegionBranchOpInterface.

Ok. Of course for a higher-level dialect the change in the constraint would look cleaner. But the cast solution does make sense if we want to preserve the type equality constraints in SCF. I am not yet sure how annoying dealing with the cast would be, but it should be ok. I’ll come back to comment here if anything interesting comes up.

Is there any interest in taking join and meet functions for the set of builtin types?
I have found them to be useful when dealing with type inference and co.

A sample of the declarations:

/// The join function for the set of builtin types, and the partial order "less
/// specialized than or equal", noted `≤`.
/// It returns the most specialized type (if it exists), that is less
/// specialized than both `ty1` and `ty2`.
/// The join `j` of `ty1` and `ty2` is such that:
/// * j ≤ ty1, and j ≤ ty2
/// * For any type t such that t ≤ ty1 and t ≤ ty2, t ≤ j.
/// For example:
///  ty1               | ty2               | ty1 v ty2
///  i8                | i8                | i8
///  i8                | i32               | <none> (null type)
///  tensor<1xf32>     | tensor<?xf32>     | tensor<?xf32>
///  tensor<1x2x?xf32> | tensor<1x?x3xf32> | tensor<1x?x?xf32>
///  tensor<4x5xf32>   | tensor<6xf32>     | tensor<*xf32>
///  tensor<1xi32>     | i32               | <none> (null type)
///  tensor<1xi32>     | tensor<i32>       | tensor<*xi32>
Type join(Type ty1, Type ty2, Optional<Location> location = None);

/// More doc.
Type meet(Type ty1, Type ty2, Optional<Location> location = None);

/// Indicates whether `ty1` is compatible with `ty2`, and less specialized than
/// or equal to `ty2`.
inline bool isLessSpecialized(Type ty1, Type ty2) {
  return join(ty1, ty2) == ty1;
}

/// Indicates whether `ty1` is compatible with `ty2`, and more specialized than
/// `ty2`.
inline bool isMoreSpecialized(Type ty1, Type ty2) {
  return meet(ty1, ty2) == ty1;
}
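
As a rough illustration of the shape part of these functions, here is a self-contained sketch that reproduces the tensor rows of the table above. It is a toy model, not the proposed patch: `Shape`, `kDynamic`, `rankedShape`, `unrankedShape`, `joinShapes` and `meetShapes` are hypothetical names, element types are ignored, and the behaviour of meet is my reading of "the least specialized shape that is more specialized than both".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Toy shape model (not an MLIR type). kDynamic stands for a `?` dimension.
constexpr int64_t kDynamic = -1;

struct Shape {
  bool hasRank;              // false models tensor<*x...>
  std::vector<int64_t> dims; // empty when unranked
};

bool operator==(const Shape &a, const Shape &b) {
  return a.hasRank == b.hasRank && a.dims == b.dims;
}

Shape rankedShape(std::vector<int64_t> dims) {
  return Shape{true, std::move(dims)};
}
Shape unrankedShape() { return Shape{false, {}}; }

// Join: the most specialized shape that is less specialized than both inputs.
// For shapes it always exists, since unranked is a valid fallback.
Shape joinShapes(const Shape &a, const Shape &b) {
  if (!a.hasRank || !b.hasRank || a.dims.size() != b.dims.size())
    return unrankedShape();
  std::vector<int64_t> dims(a.dims.size());
  for (std::size_t i = 0; i < dims.size(); ++i)
    dims[i] = (a.dims[i] == b.dims[i]) ? a.dims[i] : kDynamic;
  return rankedShape(std::move(dims));
}

// Meet: the least specialized shape that is more specialized than both inputs.
// Returns false when no such shape exists (conflicting ranks or static sizes).
bool meetShapes(const Shape &a, const Shape &b, Shape &result) {
  if (!a.hasRank) { result = b; return true; }
  if (!b.hasRank) { result = a; return true; }
  if (a.dims.size() != b.dims.size())
    return false;
  std::vector<int64_t> dims(a.dims.size());
  for (std::size_t i = 0; i < dims.size(); ++i) {
    int64_t x = a.dims[i], y = b.dims[i];
    if (x != kDynamic && y != kDynamic && x != y)
      return false; // two different static sizes cannot be reconciled
    dims[i] = (x != kDynamic) ? x : y;
  }
  result = rankedShape(std::move(dims));
  return true;
}
```

In this model, joining `1x2x?` with `1x?x3` gives `1x?x?`, while their meet is `1x2x3`. Join never fails on shapes, whereas meet fails for inputs such as `1` and `2`, which is why the sketch reports failure through the return value.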

Do you have a code pointer? I see:

    static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) {
      static_assert(!ConcreteOpType::template hasTrait<OpTrait::ZeroRegion>(),
                    "expected operation to have non-zero regions");
      return success();
    }

On RegionBranchOpInterfaceTrait.

https://github.com/llvm/llvm-project/blob/809435e390e91355f64bee0142a65c4fe6e9f488/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td#L172

Right but verifyTypes isn’t called automatically by operations implementing this interface, that’s what I referred to in my earlier post as “this has to be called manually I think”.
I don’t know what the intent with this method is or why it wasn’t hooked to the verifier entirely.

I am pretty sure it is called automatically, as I see it fire via verification. Checking.

We had a recent thread about this same situation. I sketched some ideas forward in this thread along with my use cases for this: Allow shape concretization (or type concretization) in rewrites - #2 by _sean_silva

Changing SCF for this seems undesirable (that’s my gut). But yeah, higher-level dialects should be able to define their own, and we should definitely think about how RegionBranchOpInterface (and possibly BranchOpInterface) can be adjusted to allow this (or add a new interface), as it’s otherwise way too tedious to have your own control flow ops/verification.

Having join/meet functions for builtin types seems useful!!! If you look at the post I linked, I describe a “Refinable” type interface that encodes subtyping relationships that would allow implementing join/meet functions. The other thing we need is the way to upcast/downcast through the type lattice (which the “Refinable” interface could also provide). That would allow code to generically insert the necessary casts to make types agree at “exact type equality” places.

btw, I did start to implement in npcomp the AllowsTypeRefinement trait and it works really well for my use cases there: https://github.com/llvm/mlir-npcomp/pull/212/commits/de4fe59b580e6abf26a54ec1a97a97ccaaa55e69

All hail the backtrace…

(lldb) bt
* thread #1, queue = 'com.apple.main-thread', stop reason = breakpoint 1.11
  * frame #0: 0x0000000102b759f3 mlir-opt`mlir::Operation::emitError(this=0x0000000123e08c70, message=0x00000001223ffae8) at Operation.cpp:286:3
    frame #1: 0x0000000102b752c3 mlir-opt`mlir::Operation::emitOpError(this=0x0000000123e08c70, message=0x00000001223ffc80) at Operation.cpp:581:10
    frame #2: 0x000000010669f1a5 mlir-opt`verifyTypesAlongAllEdges(op=0x0000000123e08c70, sourceNo=Optional<unsigned int> @ 0x00000001223ffea8, getInputsTypesForRegion=mlir::function_ref<Optional<mlir::TypeRange> (Optional<unsigned int>)> @ 0x00000001223ffe98)>) at ControlFlowInterfaces.cpp:139:39
    frame #3: 0x000000010669ebc5 mlir-opt`mlir::detail::verifyTypesAlongControlFlowEdges(op=0x0000000123e08c70) at ControlFlowInterfaces.cpp:214:16
    frame #4: 0x00000001014fe8c5 mlir-opt`mlir::RegionBranchOpInterface::verifyTypes(op=0x0000000123e08c70) at ControlFlowInterfaces.h.inc:175:14
    frame #5: 0x00000001014f2fb8 mlir-opt`verify(op=IfOp @ 0x0000000122400348) at SCF.cpp:901:10
    frame #6: 0x00000001014f2eab mlir-opt`mlir::scf::IfOp::verify(this=0x00000001224006e8) at SCFOps.cpp.inc:688:10
    frame #7: 0x00000001015063d4 mlir-opt`mlir::Op<mlir::scf::IfOp, mlir::OpTrait::NRegions<2u>::Impl, mlir::OpTrait::VariadicResults, mlir::OpTrait::ZeroSuccessor, mlir::OpTrait::OneOperand, mlir::RegionBranchOpInterface::Trait, mlir::OpTrait::SingleBlockImplicitTerminator<mlir::scf::YieldOp>::Impl, mlir::OpTrait::HasRecursiveSideEffects, mlir::OpTrait::NoRegionArguments>::verifyInvariants(op=0x0000000123e08c70) at OpDefinition.h:1774:39
    frame #8: 0x0000000102bbc7d1 mlir-opt`mlir::AbstractOperation::verifyInvariants(this=0x0000000123e07ea8, op=0x0000000123e08c70) const at OperationSupport.h:103:12
    frame #9: 0x0000000102bbc250 mlir-opt`(anonymous namespace)::OperationVerifier::verifyOperation(this=0x0000000122402438, op=0x0000000123e08c70) at Verifier.cpp:191:32

But yeah, higher-level dialects should be able to define their own, and we should definitely think about how RegionBranchOpInterface (and possibly BranchOpInterface) can be adjusted to allow this (or add a new interface), as it’s otherwise way too tedious to have your own control flow ops/verification.

That would be convenient.

Having join/meet functions for builtin types seems useful!!! If you look at the post I linked

The join/meet functions I implemented only supported the builtin types, with no way to “plug in” support for user types. I’ll check out your thread, and see what I can contribute.

Right, this is a manual call from the verifier of scf::IfOp to this method: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/SCF/SCF.cpp#L901 ; it isn’t injected by the interface verifier on ops that implement it (even if that may have been the intent?).

I was thinking in terms of “called when verifying SCF”, but you are absolutely right.

I’d like to get to the point where we also allow:

  1. an unknown type, that could be later inferred to a known type like ‘i32’ or ‘tensor<*xi32>’. This would be useful to represent both truly dynamic languages and things like C++ template instantiation.

  2. Parameterized types, like tensor, so we can typecheck things that have a compile-time-unknown but runtime-constant size, like a matrix multiply of tensor * tensor. This is quite different from the current tensor<?x?xi32> that has a runtime variable size.

I think 1. can already be achieved by using a custom type. You would still need to cast from and to this type (unbox the dynamic value) if you want to actually compute on it with dialects like math. And the same cast would be used to ensure the type equality constraints that scf has.

I also find 2. appealing and maybe a custom type would be the right way to go there, as well. The main question I had was how expressive the type system should be. We have an analysis that essentially derives these parameterized types (in a side data structure) and we went with affine expressions in its “type language” for now.

I also do not have a good intuition how well such a type system would compose with transformations. The advantage of an analysis is that we can just invalidate it after IR rewrites. If you put this into the type system, you have to make sure rewrites preserve the information. For higher level rewrites I can imagine this but at lower levels this seems difficult. Do you have an intuition on that?

1 reminds me of the type in TFG and what is in npcomp too: effectively a tensor<* x ?>.

This only says compile time unknown; it doesn’t say runtime variable, it just says I don’t know at compile time.

Could you expand on 2 a bit? Are you asking for a value-dependent type?

1. doesn’t have to be a tensor. It could turn into ‘int’ or ‘float’, for instance.

The point of 2 is to be able to typecheck something like a call to matmul without specializing it. For instance, I have defined matmul(tensor<?x?xf32>, tensor<?x?xf32>) and I call it from a bunch of places in my code. I want to have one version of the matmul code that has the sizes as runtime variables. Today in MLIR, we’d cast away the size information at each call site and then presumably catch the error dynamically inside matmul at runtime. I’d like to have richer types that would allow catching this error statically at the call site with only the type signature information:

tensor<7x4xi32> x;
matmul(x,x) // Error should have been matmul(x, transpose(x))
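
A minimal sketch of what such a call-site check could look like, assuming a hypothetical signature language in which every formal dimension is a named symbol, e.g. matmul(MxK, KxN). Nothing here exists in MLIR: `SymbolicShape`, `ConcreteShape` and `checkCall` are made-up names, and only fully static argument shapes are handled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical signature language: each formal dimension is a symbol name.
using SymbolicShape = std::vector<std::string>;
using ConcreteShape = std::vector<int64_t>;

// Binds every symbol to the concrete size seen at the call site. The call is
// rejected if the argument count or a rank mismatches, or if one symbol would
// have to take two different sizes.
bool checkCall(const std::vector<SymbolicShape> &formals,
               const std::vector<ConcreteShape> &actuals) {
  if (formals.size() != actuals.size())
    return false;
  std::map<std::string, int64_t> bindings;
  for (std::size_t arg = 0; arg < formals.size(); ++arg) {
    if (formals[arg].size() != actuals[arg].size())
      return false;
    for (std::size_t d = 0; d < formals[arg].size(); ++d) {
      // emplace keeps the existing binding if the symbol was already seen.
      auto res = bindings.emplace(formals[arg][d], actuals[arg][d]);
      if (!res.second && res.first->second != actuals[arg][d])
        return false; // same symbol, different size: static type error
    }
  }
  return true;
}
```

With the signature matmul(MxK, KxN), the call matmul(x, x) on a 7x4 tensor is rejected because K would have to be both 4 and 7, while matmul(x, transpose(x)) is accepted. The callee itself is never specialized: the check only uses the signature and the argument types.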

This isn’t all about the type though: this rule can be implemented with what we have, I think. It is just not possible with the standard call operation.