[RFC] Symbolic Shape Analysis


During dynamic-shape code generation, we often want to find the relationships among the dynamic shapes. These relationships include, but are not limited to:

  1. Where does a dimension’s dynamism come from? For the example below, we’d like to deduce that the dynamic dimensions of %0, %2, and %3 all come from %arg0.
func.func @several_ops(%arg0: tensor<?x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4xf32>) -> tensor<?x4xf32> {
  %0 = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_contracting_dimensions = [1],
      rhs_contracting_dimensions = [0]
    >,
    precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
  } : (tensor<?x4xf32>, tensor<4x4xf32>) -> tensor<?x4xf32>
  %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
  %2 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
  %3 = mhlo.add %0, %2 : tensor<?x4xf32>
  return %3 : tensor<?x4xf32>
}
  2. The constraints on the dynamic shapes. Suppose dimensions a and b are both dynamic. We’d like to deduce whether a equals b, whether a + b equals some static size, etc.

To address these problems, I’m considering implementing a generic symbolic shape analysis.

Related Work

I found an interesting pass in the mlir-hlo repo. It uses affine expressions to describe how a dynamic dimension is calculated. However, I’m afraid we would duplicate a lot of work if the symbolic shape analysis were based on affine expressions, because:

  1. We cannot reuse the reifyReturnTypeShapes method of InferShapedTypeOpInterface.
  2. If we want to simplify or further analyze the symbolic expressions, we cannot reuse existing passes (e.g. CSE, canonicalize) or write new passes with the MLIR pass framework.

Proposed Solution

Here is a draft of a new approach to expressing symbolic shape calculations. It works in three steps:

  1. Create an auxiliary symbolic shape inference function for each original function.
  2. Run the shape-reification pass on the created auxiliary functions.
  3. Run the analysis on the auxiliary functions.

We again use the IR from the background section as the input. After step 1, the auxiliary function below is created. Its arguments are the same as those of the original function, and its results consist of two parts: the shapes of all intermediate values, followed by the values themselves. In the example there are four intermediate values, %0, %1, %2, and %3, so the returned results are [shape_of %0, shape_of %1, shape_of %2, shape_of %3, %0, %1, %2, %3].

func.func private @_shape_infer_several_ops(%arg0: tensor<?x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4xf32>) -> (!shape.shape, !shape.shape, !shape.shape, !shape.shape, tensor<?x4xf32>, tensor<2xindex>, tensor<?x4xf32>, tensor<?x4xf32>) attributes {auxiliary_of = "several_ops"} {
  %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} : (tensor<?x4xf32>, tensor<4x4xf32>) -> tensor<?x4xf32>
  %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
  %2 = shape.value_as_shape %1 : tensor<2xindex> -> !shape.shape
  %3 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
  %4 = shape.shape_of %3 : tensor<2xindex> -> tensor<1xindex>
  %5 = shape.value_as_shape %4 : tensor<1xindex> -> !shape.shape
  %6 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
  %7 = shape.shape_of %6 : tensor<?x4xf32> -> tensor<2xindex>
  %8 = shape.value_as_shape %7 : tensor<2xindex> -> !shape.shape
  %9 = mhlo.add %0, %6 : tensor<?x4xf32>
  %10 = shape.shape_of %9 : tensor<?x4xf32> -> tensor<2xindex>
  %11 = shape.value_as_shape %10 : tensor<2xindex> -> !shape.shape
  return %2, %5, %8, %11, %0, %3, %6, %9 : !shape.shape, !shape.shape, !shape.shape, !shape.shape, tensor<?x4xf32>, tensor<2xindex>, tensor<?x4xf32>, tensor<?x4xf32>
}

After step 2, which runs the shape-reification pass and a CSE pass on the auxiliary function, we get the following IR:

func.func private @_shape_infer_several_ops(%arg0: tensor<?x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4xf32>) -> (!shape.shape, !shape.shape, !shape.shape, !shape.shape, tensor<?x4xf32>, tensor<2xindex>, tensor<?x4xf32>, tensor<?x4xf32>) attributes {auxiliary_of = "several_ops"} {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape.const_shape [2] : tensor<1xindex>
  %1 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} : (tensor<?x4xf32>, tensor<4x4xf32>) -> tensor<?x4xf32>
  %2 = tensor.dim %arg0, %c0 : tensor<?x4xf32>
  %3 = tensor.from_elements %2, %c4 : tensor<2xindex>
  %4 = shape.value_as_shape %3 : tensor<2xindex> -> !shape.shape
  %5 = shape.value_as_shape %0 : tensor<1xindex> -> !shape.shape
  %6 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
  %7 = mhlo.add %1, %6 : tensor<?x4xf32>
  return %4, %5, %4, %4, %1, %3, %6, %7 : !shape.shape, !shape.shape, !shape.shape, !shape.shape, tensor<?x4xf32>, tensor<2xindex>, tensor<?x4xf32>, tensor<?x4xf32>
}

Finally, we can run analyses on the auxiliary function. Note that the reported symbolic shape sources should be Values in the original function, not ones in the auxiliary function.

============= symbolic expr sources table for @several_ops =============
original value: %0 = "mhlo.dot_general"...
symbolic shape sources: 
<block argument> of type 'tensor<?x4xf32>' at index: 0

original value: %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
symbolic shape sources: 

original value: %2 = "mhlo.dynamic_broadcast_in_dim"...
symbolic shape sources: 
<block argument> of type 'tensor<?x4xf32>' at index: 0

original value: %3 = mhlo.add %0, %2 : tensor<?x4xf32>
symbolic shape sources: 
<block argument> of type 'tensor<?x4xf32>' at index: 0
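A minimal sketch of how the sources table above could be computed, assuming a toy string-based encoding of the reified auxiliary function rather than the MLIR C++ API. `REIFIED`, `SHAPE_OF_ORIGINAL`, and `shape_sources` are hypothetical names. The sketch walks each returned shape backwards through its operands and records a function argument whenever it reaches a `tensor.dim` on that argument. It assumes static dims have already been folded to constants, and for brevity it ignores `shape.shape_of` applied directly to arguments.

```python
# Toy encoding of the reified auxiliary function from step 2:
# SSA name -> (op name, operand names). Hypothetical; a real pass would
# walk mlir::Value use-def chains instead of strings.
REIFIED = {
    "%c0": ("arith.constant", []),
    "%c4": ("arith.constant", []),
    "%0": ("shape.const_shape", []),
    "%2": ("tensor.dim", ["%arg0", "%c0"]),
    "%3": ("tensor.from_elements", ["%2", "%c4"]),
    "%4": ("shape.value_as_shape", ["%3"]),
    "%5": ("shape.value_as_shape", ["%0"]),
}
ARGS = {"%arg0", "%arg1", "%arg2"}

# The auxiliary function returns %4, %5, %4, %4 as the shapes of the
# original values %0, %1, %2, %3 (its first four results, in order).
SHAPE_OF_ORIGINAL = {"%0": "%4", "%1": "%5", "%2": "%4", "%3": "%4"}


def shape_sources(shape_value, ir=REIFIED, args=ARGS):
    """Collect the function arguments whose dims feed `shape_value`."""
    sources, stack, seen = set(), [shape_value], set()
    while stack:
        v = stack.pop()
        if v in seen:
            continue
        seen.add(v)
        op, operands = ir.get(v, (None, []))
        if op == "tensor.dim" and operands[0] in args:
            sources.add(operands[0])  # a dim read straight off an argument
            continue
        stack.extend(operands)
    return sources


for original, shape in SHAPE_OF_ORIGINAL.items():
    print(original, sorted(shape_sources(shape)))
```

Going through the return positions is what keeps the reported sources expressed in terms of the original function's values, as required above.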


  1. How should we express constraints on the dynamic shapes? The auxiliary function in my current solution only addresses how a dynamic shape is calculated. Constraints would be useful in cases like this:
a : [?, 100]
b : [?, 100]
c = concat(a, b, axis=0)

If we know that a.shape[0] + b.shape[0] = 1024, either from an annotation or by deduction from previous operations, then we know that c has the static shape [1024, 100].
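To make the arithmetic concrete, here is a sketch that models symbolic sizes as linear expressions and folds c.shape[0] once the fact a.shape[0] + b.shape[0] = 1024 is known. The representation and the helpers `sym`, `add`, and `fold` are assumptions for illustration; a fact only matches when the coefficients agree exactly, so scaled or combined facts are not handled.

```python
# A symbolic size is a linear expression: ({symbol: coeff}, constant).
# Hypothetical representation for illustration only.


def sym(name):
    return ({name: 1}, 0)


def add(x, y):
    coeffs = dict(x[0])
    for s, c in y[0].items():
        coeffs[s] = coeffs.get(s, 0) + c
    # Drop zero coefficients so equal expressions compare equal.
    return ({s: c for s, c in coeffs.items() if c != 0}, x[1] + y[1])


def fold(expr, facts):
    """Return a static size for `expr` if it is constant or matches a known
    fact `lhs == value` up to a constant offset; otherwise None (dynamic)."""
    coeffs, const = expr
    if not coeffs:
        return const
    for (lhs_coeffs, lhs_const), value in facts:
        if coeffs == lhs_coeffs:
            return value - lhs_const + const
    return None


a0, b0 = sym("a0"), sym("b0")
c0 = add(a0, b0)                 # concat along axis 0: c0 = a0 + b0
facts = [(add(a0, b0), 1024)]    # annotation: a0 + b0 == 1024
c_shape = [fold(c0, facts), 100]
print(c_shape)                   # [1024, 100]
```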

  2. I found that reifyReturnTypeShapes can be implemented in a very flexible way. Instead of shape.shape_of, developers may use tensor.dim (or possibly other ops I’m not aware of) to get the actual dimensions of the operands’ shapes. Therefore I had to add a new DimOp pattern to the shape-reification pass. Is it possible to restrict the flexibility of the reifyReturnTypeShapes interface?

  3. How do we map a Value in the auxiliary function back to the original function? Since some passes run on the auxiliary functions, the mapping becomes harder.


This is very timely! We have been discussing this over the last couple of weeks (in a general context rather than an MHLO one, initially applied to TOSA). I’ll give quick answers now and respond more completely later.

This is where meet comes in, but I think you are asking more about how we could take advantage of information from the inference context here. That seems to belong to the pass that uses the shape library functions (the shape library is the association mechanism we currently have upstream; I haven’t looked at yours yet).

It is indeed more generic than it needs to be, and we should restrict it more. But I don’t think you need to add a new op to your reification; that op should be inserted when lowering to arith/tensor, etc.

There are a couple of variants here. In many of our uses it is currently done implicitly (we don’t inline without context, so things are already strung along as needed and we don’t need to keep that mapping in the IR), and there is an existing “shape associate” op that one can use (it basically ties the shape to the result). I’m also planning to propose another op “soon”, taking inspiration from torch-mlir/TorchOps.td at 06750815d14371a7b8f0e964f2380517583e745f · llvm/torch-mlir · GitHub (though I’m still wondering a bit about the overlap, since there would then be multiple ways of expressing this; one could just be the higher-level one).


(See llvm-project/test-shape-fn-report.mlir at main · llvm/llvm-project · GitHub for an example of using the shape.function_library op with shape association. There, the association to ops is in the op’s mapping attribute, while functions define their own shape function association.)

Thanks for your reply, @jpienaar. Of course, the analysis should be generic and not bound to a single dialect. I chose the mhlo dialect here only because many mhlo ops implement the reifyReturnTypeShapes interface.

Since you’re also working on this, is there a design doc that could be shared publicly?

Please take a look at DotGeneral’s reifyReturnTypeShapes: it uses tensor.dim directly. That’s why I added the DimOp pattern.

It seems shape.function_library is mainly used to associate a shape inference function with an op. Take the IR below as an example again. Here a and b could come from the same op or from different ops, and from the shape inference function we can infer that c.shape[0] = a.shape[0] + b.shape[0]. The question is how to express that we know a.shape[0] + b.shape[0] actually equals 1024, and how to replace c.shape[0] with the static size 1024.

a : [?, 100]
b : [?, 100]
c = concat(a, b, axis=0)

Hello all, we are also currently working on this topic. Below are some of our thoughts on this problem, along with preliminary code:

The basic idea is to use shape constraint IR to model both structure shape constraints (e.g. two symbolic dimensions are equal) and shape distribution constraints (e.g. shape ranges or likely values).

We found that both types of shape constraints are useful for performance optimization in dynamic-shape scenarios.

Relying on shape constraint analysis alone is not enough to exploit both kinds of shape constraints.


One way (not necessarily the most concise) is:

%c = shape.const_size 1024
%c0 = arith.constant 0 : index
%a0 = shape.get_extent %a, %c0 : !shape.shape, index -> !shape.size
%b0 = shape.get_extent %b, %c0 : !shape.shape, index -> !shape.size
%sum = shape.add %a0, %b0 : !shape.size, !shape.size -> !shape.size
%cc = shape.meet %c, %sum : !shape.size, !shape.size -> !shape.size

This establishes the equality in the IR using the constraints; it would be injected by the analysis. The replacement itself would be done by a propagation pass. The one I was playing around with, in the context of a dependent-type example I was working on, was quite simple: it just walked along meet nodes, which signify equality. Coupling that simple propagation with canonicalization and CSE worked quite well, but I was not trying to capture algebraic terms in the propagation. For that you’d probably want a map keyed by affine expressions in canonical form (commutative matching etc. may need additional work, so a custom mapping could work better). I believe JITRT in the TF repo is a full-fledged example.

That said, if you know they are equal, I’d expect you to capture that in the inference context while propagating, and perhaps not materialize it at all: you’d just replace the shape of the concat with a constant and be done.
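The “map keyed by affine expressions in canonical form” idea can be sketched as follows, assuming expressions are nested Python tuples. `canon` and `known` are hypothetical names: `canon` flattens and sorts the operands of `add` and `mul` so that reorderings and regroupings of the same sum produce one key. It deliberately does no distribution or constant folding; that is where the additional commutative-matching work would come in.

```python
def canon(expr):
    """Canonical form for expressions like ("add", "a0", ("mul", 2, "b0")).

    Leaves are symbol names (str) or ints. Operands of the commutative,
    associative ops "add" and "mul" are flattened and sorted, so different
    orderings and groupings map to the same key. No distribution or folding.
    """
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [canon(a) for a in args]
    if op in ("add", "mul"):
        flat = []
        for a in args:
            if isinstance(a, tuple) and a[0] == op:
                flat.extend(a[1:])  # (a + b) + c  ->  a + b + c
            else:
                flat.append(a)
        return (op, *sorted(flat, key=repr))
    return (op, *args)


# Facts known from meet/equality constraints, keyed by canonical form.
known = {canon(("add", "a0", "b0")): 1024}

print(known.get(canon(("add", "b0", "a0"))))  # 1024
print(canon(("add", ("add", "c0", "a0"), "b0"))
      == canon(("add", "a0", ("add", "b0", "c0"))))  # True
```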

That’s interesting; can you expand on this? Do you use the same information for both, or do you have dedicated patterns for each?

(@yaochengji the tie_shape op in the Alibaba example is effectively shape.with_shape, and was equivalent to shapex.tie_shape until that was removed; it is a way of associating a shape with a result.)

Not yet. We are approaching it from a different angle, and in particular I’d like to avoid reification where possible. But I want to build a couple of prototypes before a larger design.

The basic idea is shown below.

func @main() {
  %0 = any_dialect.any_operation(...) : tensor<?x?xf32, [@S0, @S1]>
}

disc_shape.SymbolDim @S0 {
  range list : [[...], [...], ...]
  likely_values : [...]
  symbolic_shape_graph: @shape_constraint_graph
}

disc_shape.SymbolDim @S1 {
  range list : [[...], [...], ...]
  likely_values : [...]
  symbolic_shape_graph: @shape_constraint_graph
}

// A separate function to store shape constraint predicates between different symbolic dimensions.
// Each symbolic dim is bound to either a `disc_shape.dim` op or a `disc_shape.bind_dim` op.
func @shape_constraint_graph(...) {
  %d0 = disc_shape.dim() {ref: @S0} : index
  %d1 = disc_shape.dim() {ref: @S1} : index
  disc_shape.tie_predicate_divisible(%d0, %d1) // d0 % d1 == 0
  // other tie_* ops
  //   disc_shape.tie_predicate_eq(d0, d1)  // d0 == d1
  //   disc_shape.tie_predicate_lt(d0, d1)  // d0 < d1
  //   disc_shape.tie_predicate_mul_eq(d0, d1, d2, ...) // d0 = d1 * d2 * ...
  //   // d0 * d1 = s0 * s1
  //   disc_shape.tie_predicate_group_mul_eq([d0, d1, ..], [s0, s1, ...])
  //   // d0 = affine.apply(d1, d2, ...) {affine_attr = ...}
  //   disc_shape.tie_predicate_affine_eq(d0, d1, d2, ...) {affine_attr = ...}
}

As shown in the IR above, there are two kinds of shape constraints:

  1. Structure shape constraints: predicates (relationships) between different symbolic dimensions. For example, the size of one dimension of a tensor equals the size of another dimension of the same tensor.

  2. Shape distribution constraints. For example:
    d0 % 4 == 0, d0 != 1, d0 in range [2, 100], d0 is most likely 4

We use shape constraint IR to encode this shape constraint information explicitly, instead of running a shape constraint analysis on the data computation IR each time we need it, for the following reasons:

  • Some structure or shape distribution constraints are injected by users. Such information cannot be deduced directly from the data computation IR, so we need dedicated IR to encode it.
  • Some shape constraint information may be lost during dialect conversion. During dialect conversion, we may decompose a high-level op into several finer-grained low-level ops, and the shape constraint information captured by the semantics of the high-level op may be lost. For example, tf.SplitOp is lowered into a series of mhlo.RealDynamicSliceOp. From the definition of tf.SplitOp, we know that all outputs of the split op have the same shape and that (size of the input’s split dimension) % (number of outputs) == 0. Depending on the concrete lowering strategy, we may lose one of these two shape constraints. With shape constraint IR, we can lower not only the data computation semantics but also the corresponding shape constraint information.
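A toy illustration of the tf.Split point, assuming a hypothetical tuple encoding of slice bounds: `lower_split` returns both the per-output slices and a divisibility constraint. The slices evaluate without complaint for any size of the split dimension, so the fact dim % n == 0 cannot be recovered from them and has to be carried alongside, as the shape constraint IR proposes. `lower_split`, `evaluate`, and `satisfied` are illustrative names, not DISC APIs.

```python
# Expressions: symbol names (str), ints, or (op, lhs, rhs) with op in
# {"div", "mul"}. Hypothetical encoding for illustration.


def lower_split(dim, n):
    """Lower split(dim, n) to n (start, limit) slice bounds plus the
    constraint that tf.Split implies; the slices alone do not encode it."""
    chunk = ("div", dim, n)
    slices = [(("mul", i, chunk), ("mul", i + 1, chunk)) for i in range(n)]
    constraints = [("divisible", dim, n)]  # cf. disc_shape.tie_predicate_divisible
    return slices, constraints


def evaluate(expr, env):
    if isinstance(expr, int):
        return expr
    if isinstance(expr, str):
        return env[expr]
    op, lhs, rhs = expr
    x, y = evaluate(lhs, env), evaluate(rhs, env)
    return x // y if op == "div" else x * y


def satisfied(constraints, env):
    return all(evaluate(d, env) % n == 0
               for kind, d, n in constraints if kind == "divisible")


slices, constraints = lower_split("d0", 4)
env = {"d0": 12}
print([(evaluate(start, env), evaluate(limit, env)) for start, limit in slices])
print(satisfied(constraints, {"d0": 12}), satisfied(constraints, {"d0": 10}))
```

With d0 = 10 the slices would still evaluate, silently covering only the first 8 rows; only the carried constraint says such an input is invalid for the original split.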

These are expressible in the shape dialect today; “likely” isn’t. But there it is expressed as operations, whereas here you have global symbolic constants that you add to the tensor type as an attribute, effectively serializing the inference context JITRT uses. In which places do you use this?

Does this mean it can have multiple closed intervals?

FYI, a few links related to JitRt symbolic shapes:

  1. Analysis example: tensorflow/shape-component-analysis.mlir at 1f9ff31ed56a500245d7dabc095228b32cb143b4 · tensorflow/tensorflow · GitHub
  2. runtime/symbolic_shape.h at master · tensorflow/runtime · GitHub

We only care about equalities, to remove broadcasts (and we always materialize dims of size 1), but we are also interested in a more generic shape analysis.


Yes, it can be multiple intervals. For example, a list of discrete values is a special case: [[2, 2], [5, 5]].
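A sketch of how such a range list could be normalized and combined, assuming closed integer intervals `[lo, hi]` as in the `range list` field above. `normalize`, `intersect`, and `contains` are hypothetical helpers; intersecting two lists is one thing an analysis might do when a tie_predicate_eq makes two symbolic dims equal.

```python
def normalize(intervals):
    """Sort and merge overlapping or adjacent closed integer intervals."""
    merged = []
    for lo, hi in sorted(intervals):
        if merged and lo <= merged[-1][1] + 1:
            merged[-1][1] = max(merged[-1][1], hi)
        else:
            merged.append([lo, hi])
    return merged


def intersect(a, b):
    """Intersect two normalized interval lists with a two-pointer sweep."""
    out, i, j = [], 0, 0
    while i < len(a) and j < len(b):
        lo, hi = max(a[i][0], b[j][0]), min(a[i][1], b[j][1])
        if lo <= hi:
            out.append([lo, hi])
        # Advance whichever interval ends first.
        if a[i][1] < b[j][1]:
            i += 1
        else:
            j += 1
    return out


def contains(intervals, x):
    return any(lo <= x <= hi for lo, hi in intervals)


s0 = normalize([[5, 5], [2, 2]])   # discrete values {2, 5}
s1 = normalize([[1, 4], [3, 8]])   # merges into [1, 8]
print(s0, s1, intersect(s0, [[1, 4]]), contains(s0, 3))
```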

Some use cases we are interested in:

1. Broadcast/reshape/… simplification. For example, removing redundant broadcasts, or simplifying mhlo.real_dynamic_slice ops (e.g. finding dimensions that are actually fully sliced).

2. Fusion decisions. Using the shape constraint information, we can make better fusion decisions under dynamic-shape semantics, e.g. only fusing ops that are known to be shape-compatible into one kernel.

3. CodeGen optimization. For example, removing redundant index calculations when we know some symbolic dimensions have the same size, which is important when the fusion pattern is relatively large. Tiling/vectorization bound checks can be simplified if we know s0 % tile_size == 0 (or s0 % vector_size == 0).

4. Using likely values to make better speculation decisions (e.g. placement, layout, codegen schedule, implicit broadcast speculation).
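For the tiling/vectorization point in item 3, a sketch of why knowing s0 % tile_size == 0 removes the bound check. `needs_bound_check` and `tiled_sum` are hypothetical; the known divisor stands in for whatever divisibility facts the shape constraint IR records, and the `if` inside `tiled_sum` plays the role of the guard codegen would otherwise emit.

```python
def needs_bound_check(known_divisor, tile):
    """True unless s0 % tile == 0 follows from a known fact s0 % known_divisor == 0.

    `known_divisor` is a hypothetical summary of divisibility facts on s0
    (1 when nothing is known). If tile divides known_divisor, it divides s0.
    """
    return known_divisor % tile != 0


def tiled_sum(xs, tile, check_bounds):
    """Sum xs tile by tile; the inner guard is the bound check codegen would emit."""
    total = 0
    for start in range(0, len(xs), tile):
        for k in range(tile):
            i = start + k
            if check_bounds and i >= len(xs):
                break  # partial last tile
            total += xs[i]
    return total


xs = list(range(12))  # s0 = 12, and we know s0 % 4 == 0
check = needs_bound_check(known_divisor=4, tile=4)
print(check, tiled_sum(xs, 4, check))              # False 66
print(needs_bound_check(known_divisor=1, tile=4))  # True
```

Without the divisibility fact, dropping the guard would read past the end of `xs` whenever its length is not a multiple of the tile size.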

We do this based on the output of the shape analysis linked earlier, I believe: the rewrite patterns query the shape analysis. The main difference in approach is that we do not materialize the result of the shape analysis itself in the IR, only the properties of operations that we have inferred (such as whether a broadcast is expanding, whether a reshape is just a collapse/expand, etc.).

We have found that we need some caching of shape analysis results to avoid rerunning the analysis, but we shied away from using the IR for this, as it requires keeping that information correct across rewrites. @wyzero, have you seen this become an issue in practice?

Our approach here has been that we assume dimensions line up for all HLO operations. We use an encoding with shape.assuming regions to guard operations with their prerequisites and then propagate those to form fusion islands.

Interesting. Do you have a code pointer for this? I am curious to see how this works. Do you make index computations explicit at the IR level and then CSE dim operations (conceptually) where you know they are the same?

1. I agree that explicitly using IR to store shape constraint information may introduce extra work to keep the IR consistent during transformations. However, some kinds of shape constraints, for example shape distribution constraints, cannot be naturally captured by the semantics of the data computation IR, and these are also useful for performance optimization. On the other hand, we are not trying to store the calculation logic for a symbolic shape in a second place; we only store high-level predicates between different symbols (e.g. a < b, a % b == 0, a * b = c * d). To reduce the overhead of maintaining the shape constraint IR, we also provide a shape optimization pass that automatically infers and propagates the shape constraint information captured by the semantics of the data computation IR (e.g. mhlo ops), so authors of rewrite patterns only need to insert the constraints that would otherwise be lost (like the tf.split → mhlo example above).

2. Index calculation simplification. The basic idea is shown below.

  • Binding all symbolic-equal dimensions to the same SSA value

// Step #1: original input graph
func @main(%arg0 : tensor<?x?xf32, [@S0, @S0]>, %arg1 : tensor<?x?xf32, [@S0, @S0]>) -> tensor<?x?xf32, [@S0, @S0]> {
  // a specialized version of mhlo.add: squared add
  %ret = mhlo.add(%arg0, %arg1) : tensor<?x?xf32, [@S0, @S0]>
  return %ret : tensor<?x?xf32, [@S0, @S0]>
}

// Step #2: Explicitly binding all known symbolic-equal dimensions to the same SSA value right before going from tensor world to buffer world.

func @main(%arg0 : tensor<?x?xf32, [@S0, @S0]>, %arg1 : tensor<?x?xf32, [@S0, @S0]>) -> tensor<?x?xf32, [@S0, @S0]> {
  %c0 = constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32, [@S0, @S0]>
  %new_arg0 = disc_shape.tie_shape(%arg0, %d0, %d0) : tensor<?x?xf32, [@S0, @S0]>
  %new_arg1 = disc_shape.tie_shape(%arg1, %d0, %d0) : tensor<?x?xf32, [@S0, @S0]>
  %ret = mhlo.add(%new_arg0, %new_arg1) : tensor<?x?xf32, [@S0, @S0]>
  %new_ret = disc_shape.tie_shape(%ret, %d0, %d0) : tensor<?x?xf32, [@S0, @S0]>
  return %new_ret : tensor<?x?xf32, [@S0, @S0]>
}

// bufferize to lmhlo
//   - Symbolic dims are resolved into the same SSA value
//   - Enable more opportunities for normal CSE & canonicalization patterns.

func @main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> memref<?x?xf32> {
  %c0 = constant 0 : index
  %d0 = memref.dim %arg0, %c0 : memref<?x?xf32>
  %new_arg0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [%d0, %d0], …
  %new_arg1 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [%d0, %d0], …
  %ret = memref.alloc(%d0, %d0) : memref<?x?xf32>
  "lmhlo.add"(%new_arg0, %new_arg1, %ret) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %ret : memref<?x?xf32>
}
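The binding step can be sketched as a single pass that emits one `tensor.dim` per distinct symbol and reuses its SSA name for every dimension that carries that symbol. `bind_symbolic_dims` is a hypothetical helper that produces op text only for illustration, and it assumes index constants `%c0`, `%c1`, ... already exist.

```python
def bind_symbolic_dims(tensors):
    """tensors: value name -> per-axis symbolic dim names (static dims as ints).

    Emits one tensor.dim per distinct symbol and maps every (value, axis)
    carrying that symbol to the same SSA name, mirroring the tie_shape step
    above. The printed op text is illustrative only.
    """
    first = {}     # symbol -> SSA name of the tensor.dim that defines it
    dim_ops = []   # ops to emit, in order
    binding = {}   # (value, axis) -> SSA name or static int
    for value, dims in tensors.items():
        for axis, d in enumerate(dims):
            if isinstance(d, int):
                binding[(value, axis)] = d
                continue
            if d not in first:
                first[d] = f"%d{len(dim_ops)}"
                dim_ops.append(f"{first[d]} = tensor.dim {value}, %c{axis}")
            binding[(value, axis)] = first[d]
    return dim_ops, binding


ops, binding = bind_symbolic_dims({
    "%arg0": ["@S0", "@S0"],
    "%arg1": ["@S0", "@S0"],
    "%ret": ["@S0", "@S0"],
})
print(ops)                    # ['%d0 = tensor.dim %arg0, %c0']
print(set(binding.values()))  # {'%d0'}
```

Once every size is the same SSA value, ordinary CSE and canonicalization can see that index computations over these dims coincide.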
  • Example use case: eliminating redundant linearize/delinearize pairs in a fusion pattern (relying on the first step).

Suppose we have a fusion pattern abs -> reshape -> transpose -> exp, where L(op) denotes the linear index of an op and multi-dim(op) denotes its multi-dimensional indices. We can then remove the last linearize/delinearize pair (assuming we start from the output).
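A sketch of the rewrite, assuming index expressions are nested tuples and shapes are tuples of SSA names produced by the binding step: a delinearize whose operand is a linearize over the identical shape folds to the original index. `fold_pair` is a hypothetical name; the numeric helpers check the round trip, which holds only for in-bounds indices.

```python
def linearize(idx, shape):
    """Row-major linear offset of multi-dim index `idx` within `shape`."""
    lin = 0
    for i, n in zip(idx, shape):
        lin = lin * n + i
    return lin


def delinearize(lin, shape):
    """Inverse of linearize for an in-bounds offset and the same shape."""
    idx = []
    for n in reversed(shape):
        idx.append(lin % n)
        lin //= n
    return idx[::-1]


def fold_pair(expr):
    """Rewrite ("delinearize", ("linearize", idx, s1), s2) -> idx when s1 == s2.

    Shapes are tuples of SSA names; after binding symbolic-equal dims to one
    SSA value, syntactic equality of the tuples implies equal sizes.
    """
    if (isinstance(expr, tuple) and expr[0] == "delinearize"
            and isinstance(expr[1], tuple) and expr[1][0] == "linearize"
            and expr[1][2] == expr[2]):
        return expr[1][1]
    return expr


shape = ("%d0", "%d1", "%d2")
print(fold_pair(("delinearize", ("linearize", "idx", shape), shape)))  # idx
print(delinearize(linearize([2, 1, 3], [3, 4, 5]), [3, 4, 5]))         # [2, 1, 3]
```

The check is conservative: shapes that are equal in value but bound to different SSA names are left alone, which is why the binding step comes first.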

The feature is still a work in progress; part of the code is linked below:

  • Using the same SSA value for all known symbolic-equal dimensions whenever possible, link
  • Flatten memref accesses explicitly, link
  • Linearize and Delinearize pair elimination, link


Here %c is a constant, so the propagation pass can tell that it’s better to replace %sum with %c.

But in the general case, shape.meet %x, %y, how do we decide whether %x should be replaced by %y, or vice versa?
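One possible answer, sketched under assumptions rather than taken from any existing pass: treat each `shape.meet` as a union in a union-find structure and pick the representative of each class by a fixed preference, first a known constant, then the value defined earliest in the block. In straight-line code the earliest value dominates the later uses; across blocks you would need real dominance information, and constants are assumed to be hoistable so they dominate all uses. Replacing every member with its representative then gives the rewrite of %sum with %c. `EqClasses` and its `meet`/`find` methods are hypothetical names.

```python
class EqClasses:
    """Union-find over SSA values that shape.meet declares equal.

    Representative policy (an assumption, not an upstream rule): prefer a
    known constant, otherwise the value defined earliest in program order.
    """

    def __init__(self, order, constants):
        self.parent = {}
        self.order = order          # value -> position in the block
        self.constants = constants  # value -> static integer, if known

    def find(self, v):
        self.parent.setdefault(v, v)
        while self.parent[v] != v:
            self.parent[v] = self.parent[self.parent[v]]  # path halving
            v = self.parent[v]
        return v

    def _key(self, v):
        # Smaller key = better representative: constants first, then earliest.
        return (v not in self.constants, self.order.get(v, float("inf")))

    def meet(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self._key(ry) < self._key(rx):
            rx, ry = ry, rx
        self.parent[ry] = rx  # the better-ranked root becomes the representative


order = {"%a0": 0, "%b0": 1, "%sum": 2, "%c": 3, "%x": 4, "%y": 5}
eq = EqClasses(order, constants={"%c": 1024})
eq.meet("%c", "%sum")  # shape.meet %c, %sum
eq.meet("%x", "%y")    # shape.meet %x, %y
print(eq.find("%sum"), eq.find("%y"))  # %c %x
```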

Hey folks, I’m a little late to this thread, but I think you might find what we are doing in Torch-MLIR interesting. We have a pretty sophisticated shape refinement system (doc). It is similar to a lot of the approaches in this thread, but has a few notable characteristics:

  1. The shape functions are written in a restricted Python subset, which has many benefits:
    a. It is extremely easy to author, test, and debug the shape functions.
    b. They are shared with upstream PyTorch here.
    c. It is trivially easy to insert certain kinds of shape guards and assertions (and test them!)
  2. We found that a nontrivial amount of “list computation” and control flow is needed when dealing with operators which are allowed to have multiple ranks – for example iterating over all the batch dimensions of an operator and building the output shape.
  3. We reuse our TorchScript compiler to compile the shape functions themselves.
  4. Due to how we model our reified shape IR, identifying symbolically identical dimensions comes naturally from CSE after simplification.
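Point 4 relies on the same mechanism CSE uses: structurally identical computations get a single value. A minimal hash-consing sketch (the `ValueNumbering` class and the op names used as keys are assumptions, not Torch-MLIR code) shows two shape functions that both read `tensor.dim` of the same input collapsing to one value, while a different computation stays distinct. It does no algebraic simplification itself, so it only finds equalities that are syntactically identical once simplification has run.

```python
class ValueNumbering:
    """Hash-consing: structurally identical (op, operands) share one number."""

    def __init__(self):
        self.table = {}

    def get(self, op, *operands):
        key = (op, operands)
        if key not in self.table:
            self.table[key] = len(self.table)  # new value number
        return self.table[key]


vn = ValueNumbering()
x = vn.get("arg", 0)
# Shape functions for relu(x) and tanh(x) each read the input's dim 0.
relu_d0 = vn.get("tensor.dim", x, 0)
tanh_d0 = vn.get("tensor.dim", x, 0)
plus_one = vn.get("arith.addi", vn.get("tensor.dim", x, 1), 1)
print(relu_d0 == tanh_d0, relu_d0 == plus_one)  # True False
```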

Here is example IR for a trivial unary function (gist). Note that the shape function might look scary in the IR, but it is really just this trivial Python function (link). Upstream PyTorch decided that shape functions should never return aliased lists, so we make a copy. Handling this falls out naturally from the more general kinds of shape functions we need to handle.

Here are some examples of some pretty interesting shape functions that we handle:
max_pool_2d - very complex logic and error checks
matmul - handles many different rank combinations
linear - demonstrates composition of other shape functions
transpose - conditional logic inside a loop is the natural way to express this shape function
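As an illustration of the kind of list computation and control flow described above, here is a matmul-style shape function in plain Python: it broadcasts the batch dims by walking both shapes from the right, then appends [M, N]. This is a simplified sketch written for this post, not the upstream PyTorch shape function; it handles only rank >= 2 operands and static integer sizes, and `broadcast_shapes`/`matmul_shape` are illustrative names.

```python
def broadcast_shapes(a, b):
    """NumPy-style broadcast of two static shapes (lists of ints)."""
    out = []
    for i in range(max(len(a), len(b))):  # walk dims from the right
        x = a[-1 - i] if i < len(a) else 1
        y = b[-1 - i] if i < len(b) else 1
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible dims {x} and {y}")
        out.append(y if x == 1 else x)
    return out[::-1]


def matmul_shape(lhs, rhs):
    """Batched matmul for rank >= 2 operands: broadcast batch dims, append [M, N]."""
    if len(lhs) < 2 or len(rhs) < 2:
        raise NotImplementedError("this sketch only covers rank >= 2")
    if lhs[-1] != rhs[-2]:
        raise ValueError("contracting dims differ")
    return broadcast_shapes(lhs[:-2], rhs[:-2]) + [lhs[-2], rhs[-1]]


print(matmul_shape([2, 1, 3, 4], [5, 4, 6]))  # [2, 5, 3, 6]
```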

In our case, we have to deal with a large and wild set of operators, and our shape functions are the Torch-level shape functions which have to be defensive against arbitrary user inputs. For a system like MHLO which is more orthogonalized and has more verifier-enforced properties, other choices probably make sense.

Thanks @wyzero for sharing the details.

Have you seen cases where a rewrite would make inferring constraints impossible, i.e. where the shape constraint propagation would need to coarsen a constraint because it can no longer be derived from the program? Or would you, in such a case, keep the previous, more precise constraints and just have weaker ones for the rewritten IR?

I think the tf.Split case mentioned above could be an example: we need to inject split_dim % num_outputs == 0 when lowering it to a series of mhlo.real_dynamic_slice ops.