Shape constraint representation

In MLIR, we can easily represent (partially) dynamic shapes (e.g. tensor<?x?xf32>). This feature is really helpful when supporting lowering and compilation of dynamic-shape computation graphs.

In some cases, we may know the relationship between the unknown dimensions even though we do not know their actual values.

case #1: shape constraint of a tensor

suppose we have %0 = ... : tensor<?x?xf32>, and we know the first dimension and the second dimension are equal.
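
To make the missing piece concrete, here is a minimal sketch (std-era MLIR; the dim ops are just illustrative): we can compute both extents, but there is no standard place to record that they are equal.

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = dim %0, %c0 : tensor<?x?xf32>
  %d1 = dim %0, %c1 : tensor<?x?xf32>
  // the invariant we would like to carry through lowering: %d0 == %d1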

case #2: shape constraint between different tensors

  %0 = ... : tensor<?x?xf32>
  // these two slice ops are lowered from a split-like op (e.g. dividing a tensor evenly)
  %1 = slice(%0, ...) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = slice(%0, ...) : (tensor<?x?xf32>) -> tensor<?x?xf32>

we know %1 and %2 have the same shape and want to keep this information through the whole lowering pipeline in order to generate more efficient code.

The problem is:
1. some shape constraints are produced in the graph-level dialect (e.g. the example in case #2) and consumed at mid-level or even lower levels (e.g. when doing codegen), so we need to carry this information through the whole lowering pipeline (across different dialects).

2. these shape constraints are just hints for optimization and should not break passes that are unaware of them.

3. these shape constraints are dialect-neutral and should compose with different dialects.

Any advice about how to achieve this?

Some possible solutions:
1. a new tensor type with special attributes, using the attributes to record shape constraints.
2. extending the shape dialect with some “annotation” IR to represent these shapes (but how do we keep such annotations alive during transformations?).

Thanks for starting this thread. I have thought a bit about this recently in the context of XLA-like dynamic shapes, where certain dimensions are dynamic but constrained (size <= compile_time_constant). As you mentioned, adding annotations in the IR does not work because these annotations can get dropped at any time. I have been thinking along similar lines as you:

  1. Adding these constraints to tensor and memref types. Both tensor and memref types with a dynamic dimension (the ? in the type) can have additional constraints attached to them in the form of an IntegerSet attribute (one or more affine equalities or inequalities); a hypothetical sketch follows below. Being baked into the type, they will not get dropped and can be preserved throughout the lowering pipeline. We need to discuss whether to restrict the kind of integer inequalities that can be attached as constraints.

  2. Existing transformations will continue to work as is, and we can then gradually change them to adopt these constraints. As an example, bufferization can merge several memrefs with known size constraints into a single larger memref, to model XLA-style buffer allocation where all intermediate buffers are allocated from a single temporary buffer. Or GPU code generation can choose whether and how much shared memory to use based on the known size limits of one or more of these memrefs.
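
To illustrate the first point, a hypothetical spelling of such a constrained type could look like the following (attaching an IntegerSet to a tensor type is not existing MLIR syntax; the affine_set itself is standard notation):

  // d0 and d1 name the two dynamic dimensions; the set encodes d0 == d1 and d0 <= 64.
  %0 = ... : tensor<?x?xf32, affine_set<(d0, d1) : (d0 - d1 == 0, -d0 + 64 >= 0)>>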

This does not address the second case though, where the constraint is between two tensors/memrefs.

Hey,

So this is a big reason why we have the shape dialect :slight_smile: A lot of this also fits into the inference context and the corresponding analysis (which is more a sketch than actual code at the moment); it also fits in with the assuming regions there, which can capture that equality.

Encoding anything more than a plain constant upper bound into the type isn’t too desirable: it makes the type more unwieldy and constrains expressiveness. Currently what we need and plan to support are symbolic equality constraints and constant upper bounds, but we don’t want to hard-code those as the only abilities.

Now, we can of course have bounded tensors in a dialect. Going back to enabling the tensor type to have an attribute for layout would also cover this, and that was the original plan, but the layout attribute is still pending (it should IMHO be an Attribute).

I think the solution requires both of the following:

  1. Extend the tensor type to have a layout attribute. At least the simple case of a bound given by a compile-time constant can be represented as part of the layout, and it makes sense as part of that/memory planning. I would not want symbolic constants or the like in there, as that would require type rewriters to know how to manage shape-inference/layout behavior and propagate it.
  2. The shape dialect has explicit ops that can represent these constraints; they persist through optimizations and, if lowered to assuming regions, stay out of the way (a small sketch follows this list). You don’t want to operate on those ops directly in all kinds of queries; instead, an analysis with an inference context should be created (just as one would create a symbol table rather than walking over the module repeatedly). The inference context can then be queried for the equalities/inequalities/relationships. It need not be fixed to what we can represent in the type system, and there need not be only one (you can use a plain static inference context, symbolic equality, static bound, symbolic equality & static bound, an anything-that-can-be-represented-with-Z3 context, …) depending on needs. The requirement is then on the use side, and we don’t have to use a hammer if not needed. Now, we might want to make a variant of the assuming op at some point too (but that is more about conciseness than representational ability).
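
As a rough sketch of the second point applied to case #2 above (exact op syntax aside), the equality could be recorded with the existing shape-dialect constraint ops and consumed via an assuming region:

  %s1 = shape.shape_of %1 : tensor<?x?xf32> -> !shape.shape
  %s2 = shape.shape_of %2 : tensor<?x?xf32> -> !shape.shape
  // a witness that the two shapes are equal; it persists as ordinary IR
  %w = shape.cstr_eq %s1, %s2 : !shape.shape, !shape.shape
  %r = shape.assuming %w -> tensor<?x?xf32> {
    // code in here may rely on shape(%1) == shape(%2)
    %v = ...
    shape.assuming_yield %v : tensor<?x?xf32>
  }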

The context then also addresses the need to span tensor/memref. The downsides of it are:

  1. there are (optionally) operations there that interfere with optimization patterns (less of an issue with the assuming regions, but those could also become unwieldy, and they can only represent additive structure easily; again, fine for most cases though),
  2. one needs to query the analysis & potentially update it (now, updating is required with types too, and types require that all builders know how to propagate the new layouts, so this just moves the responsibility to one element).

I think it would also be helpful in this context to think about what the shape functions for the slice operation would look like. The current approach we take with the shape dialect is based on the assumption that we can derive the knowledge we want from what the shape functions provide.

Can you provide a bit of context on this @wyzero? I’d expect the shape function to depend not only on the shape of %0 but also on the other arguments to slice in this example. So what do those look like?

I am also looking at a similar problem. I want to do shape reification (i.e. lowering shape functions into shape-dialect ops) in an independent pass without doing bufferization. (The current uses of shape reification all live in some bufferization pass and connect the reified shape to an alloc op.) I think the common theme is that there isn’t a solution for using the shape dialect and some tensor-based dialect in a mix, and then performing some other transformations before buffer allocation.

2. extending the shape dialect with some “annotation” IR to represent these shapes (but how do we keep such annotations alive during transformations?).

Currently, I am actually exploring this route. What I have been playing with is defining a foo.shape_and_eval op with two regions that provides a structural association between the two results: the actual value and the shape of that value.

  %output:2 = "foo.shape_and_eval"() (
    {
      // Output shape calculation
      %input_shape = shape.shape_of %input : tensor<?x?xf32> -> !shape.shape
      // shape calculation with shape dialect ops...
      %inferred_output_shape = ...
      foo.yield %inferred_output_shape : !shape.shape
    },
    {
      // Output value calculation
      %output_value = "bar.slice"(%input) {...} : (tensor<?x?xf32>) -> tensor<?x?xf32>
      foo.yield %output_value : tensor<?x?xf32>
    }
  ) : () -> (!shape.shape, tensor<?x?xf32>)

I imagine some tensor type based transformations are possible (e.g. merging several foo.shape_and_eval ops and their regions), but they do need to be aware of this new op, so it doesn’t meet your requirement 2.

Thanks for your reply. Currently we are working on dynamic-shape graph compilation. The slice op looks like the following in my case:

// data_input, start_indices, limit_indices and stride_indices are all values not attributes.
%out = slice(%data_input, %start_indices, %limit_indices, %stride_indices)

The current approach we take with the shape dialect is based on the assumption that we can derive the knowledge we want from what the shape functions provide.

This assumption may be fine if we could define enough ops to capture the high-level intention; for example, using a split op (sketched below) instead of a series of slice ops to represent the same semantics. However, this approach may not be scalable, especially in a multi-level IR situation.
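
For instance, a hypothetical high-level op (foo.split is an illustrative name, not an existing op) carries the equal-shape property in its contract, whereas the lowered pair of slice ops loses it:

  // one op, and "both results have the same shape" is part of its semantics;
  // %res#0 and %res#1 correspond to %1 and %2 in case #2 above
  %res:2 = "foo.split"(%0) {dimension = 0 : i64, num_splits = 2 : i64}
      : (tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)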

This design reminds me a bit of what IREE is doing with their tie_shape operation (I might misremember the name). That one does not use regions but otherwise serves the same purpose. Maybe @stellaraccident has some pointers.

I agree that tying shapes to values is more difficult in tensor land, as we do not have an explicit operation to do so. When using buffers (independently of whether this happens during bufferization or not), the alloc operation provides a natural connection point where the shape computation is used to define the output value before it is written.

When using tensors, the question is why one would want these operations in the IR in the first place. If it is only for the sake of analysis, one could also have an analysis pass that interprets the shape computations in some form. So for every operation, you could reify the shape computation and evaluate it under some partial value domain or do symbolic analysis. Some caching of the shape IR might help performance and would ultimately give you the same as the above (maybe a side module with all the shape functions) but with a looser tie between ops and these functions.

If you want to use CSE and canonicalization for rewriting shape computations, then materializing them in the IR is useful, but the region-based approach might hinder that, as you would need to lift computations out of the regions to make them visible to CSE.

Definitely an interesting design space to explore!

I agree. The higher the level of the operation’s semantics, the easier it is to derive knowledge, because it can essentially be hard-coded. But let’s play a little. Assume we had a shape function like the one below (no claim of correctness):

slice_shape(%data_input, %start_indices, %limit_indices, %stride_indices) {
  %shp_lb = shape.from_extent_tensor(%start_indices)
  %shp_ub = shape.from_extent_tensor(%limit_indices)
  %shp_st = shape.from_extent_tensor(%stride_indices)
  %size = shape.subtract %shp_ub, %shp_lb
  %result = shape.floordiv %size, %shp_st
  return %result
}

and two calls

%data = ...
%shape = shape.shape_of %data // we need the shape as a vector, missing in std
%rank = rank %data
%zeros = dyn_splat 0, %rank // we need a way to express dynamic splats 
%ones = dyn_splat 1, %rank 
%twos = dyn_splat 2, %rank
%half = floor_div %shape, %twos
%split_l = slice(%data, %zeros, %half, %ones)
%split_r = slice(%data, %half, %shape, %ones)

if we now substitute the call arguments into the shape function at both call sites, we get

%shape = shape.shape_of %data // we need the shape as a vector, missing in std
%rank = rank %data
%zeros = dyn_splat 0, %rank // we need a way to express dynamic splats 
%ones = dyn_splat 1, %rank 
%twos = dyn_splat 2, %rank
%half = floor_div %shape, %twos
// First shape.
%shp_lb1 = shape.from_extent_tensor(%zeros)
%shp_ub1 = shape.from_extent_tensor(%half)
%shp_st1 = shape.from_extent_tensor(%ones)
%size1 = shape.subtract %shp_ub1, %shp_lb1
%result1 = shape.floordiv %size1, %shp_st1
// Second shape.
%shp_lb2 = shape.from_extent_tensor(%half)
%shp_ub2 = shape.from_extent_tensor(%shape)
%shp_st2 = shape.from_extent_tensor(%ones)
%size2 = shape.subtract %shp_ub2, %shp_lb2
%result2 = shape.floordiv %size2, %shp_st2

We need the from_extent_tensor ops to canonicalize away, i.e., we want these operations to be replaced by their shape-dialect counterparts. So the ones, zeros and twos should become constants in the shape dialect. With that, the shape.floordiv by the strides would be folded, giving us

%shape = shape.shape_of %data
%rank = shape.rank %shape
%zeros = shape.splat 0, %rank
%twos = shape.splat 2, %rank
%half = shape.floordiv %shape, %twos
%result1 = shape.subtract %half, %zeros
%result2 = shape.subtract %shape, %half

The first shape, %result1, trivially becomes %half. The second shape is defined as shape - (shape floordiv 2). That can only be rewritten to %half if we know that shape is divisible by 2 (e.g. for an extent of 5, %half is 2 but the second result has extent 3).

That would be the same property we would need on a split operation. We could have used a different formulation that is independent of that property. For example, we could have written

%data = ...
%shape = shape.shape_of %data // we need the shape as a vector, missing in std
%rank = rank %data
%zeros = dyn_splat 0, %rank // we need a way to express dynamic splats 
%ones = dyn_splat 1, %rank 
%twos = dyn_splat 2, %rank
%half = floor_div %shape, %twos
%two_halves = mul %half, %twos
%split_l = slice(%data, %zeros, %half, %ones)
%split_r = slice(%data, %half, %two_halves, %ones)

Now this always returns two results of the same shape, and filling it into the above formulation of the shape function, we would get

%result2 = shape.subtract %two_halves, %half

and we would need to simplify (mul %half, %twos) - %half into %half, which seems reasonable.

So, as this shows, one needs to be careful about how the semantics of operations are defined, e.g., for a split operation that produces equally sized chunks. But such an op can be lowered to slice if one is careful and considers the shape computations in the process.

Whether this scales to more complex operations is a good question. It is a similar question to how far one can take affine modelling. In the end, I believe the important property of the system is whether it degrades gracefully when knowledge is lacking.

Ah, for me, I want this because I am trying to do loop tiling on tensors instead of on buffers. Deferring shape reification until after tiling would therefore be riskier (I am not sure anyone has done something similar in MLIR land) and more work (it would need shape computation/analysis across loop nests).

Side comment: (not to divert this thread) I suspect that you really want to sync with @nicolasvasilache, @MaheshRavishankar, @_sean_silva, and @benvanik, who are working on this right now in the context of linalg in tensors.

This ended up as a new post after it grew more than I intended, in a first cut: Linalg and Shapes