Computing output shapes of structured ops on tensors

When lowering linalg from tensors to buffers, one arrives at the problem of allocating the result buffer. For dynamic shapes, this involves computing the dynamic size of the linalg.generic’s results.

I currently have an op that looks like:

def Linalg_CalculateResultShapesOp
    : LinalgOp<"calculate_result_shapes", [NoSideEffect]> {
  let summary = "Calculate the result shape for a structured op.";
  let description = [{
Given a set of indexing maps as would be passed to `linalg.generic`,
calculate a shape for the resulting tensors.

As with `linalg.generic`, the indexing maps are a flat list which correspond
first to the input operands and then to the output operands.

// TODO: Are any other of the linalg.generic attrs needed to calculate the
// output shapes?
  }];
  let arguments = (ins
    Variadic<Shape_ShapeType>:$operands,
    AffineMapArrayAttr:$indexing_maps
  );
  let results = (outs Variadic<Shape_ShapeType>:$results);
}

Does this look correct? Do the iterator types matter for computing the output size?

I also have no idea how to lower this to basic math ops… I would really appreciate some help figuring out how to define and lower this op. I’m running into this with the numpy/TCP end2end prototype, but I already see a place where it is needed in this pull request adding a linalg-on-tensors to linalg-on-buffers conversion for XLA, which fundamentally assumes static shapes because of this.
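
To make the question concrete, here is roughly the arithmetic I imagine such a lowering would need to produce, as a plain-Python sketch (illustrative only: it assumes the operands are the input shapes, that the indexing maps list inputs first and then results as in linalg.generic, and that every map is a simple projected permutation; all names are made up):

def calculate_result_shapes(operand_shapes, indexing_maps, num_results):
    # operand_shapes: one shape (list of extents) per input operand.
    # indexing_maps: for each operand/result, the loop dims it indexes,
    # e.g. matmul is [[0, 2], [2, 1], [0, 1]] for (m, k), (k, n), (m, n).
    loop_extents = {}
    # Every loop dimension's extent is read off some operand dimension.
    for shape, dims in zip(operand_shapes, indexing_maps[:len(operand_shapes)]):
        for extent, loop_dim in zip(shape, dims):
            # If two operands disagree, the op is malformed (or needs a runtime check).
            assert loop_extents.setdefault(loop_dim, extent) == extent
    # Result shapes are the result maps applied to the loop extents.
    result_maps = indexing_maps[len(operand_shapes):]
    return [[loop_extents[d] for d in dims] for dims in result_maps[:num_results]]

# matmul: 4x8 times 8x16 gives 4x16.
print(calculate_result_shapes([[4, 8], [8, 16]], [[0, 2], [2, 1], [0, 1]], 1))

In this simple setting the iterator types never enter the computation, since the result maps already pick out which loop dimensions appear in the output; whether that holds in general is part of what I’m asking.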

@nicolasvasilache, @dfki-mako (@dfki-ehsa doesn’t seem to be registered here; could you loop them in?)

Range inference is an interesting topic, I’ll start with my personal conclusion and then unpack some of the details that led to this conclusion.

My personal TL;DR is that it is generally not worth the trouble for most of the cases we care about for “named ops” in TCP/Linalg. The implementation in https://reviews.llvm.org/D77067 expects the output shape to be specified explicitly (unlike in Tensor Comprehensions, where the range is inferred automatically).

def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
}

The output shape of C is explicit, parsed and saved (not yet connected to the “named” op but it will be).
This supports affine expressions today and can be extended in the future to take e.g. piecewise-affine expressions or more general MLIR function symbols.

Note that this is on a per-op basis, and it is possible to use ceil/floor/mod to express the output of a strided + dilated conv, for instance. I expect this specification will make it easy to evolve tensor → buffer translation to dynamic shapes on a per-op basis. This hints at the fact that I don’t think a Linalg_CalculateResultShapesOp is the right way to go, but I have not yet unpacked the implications of the existence of such an op.

This also hints at future extensions where we should make some memref type more structured and distinguish between say memref<?x?xf32> and memref<N x (N / 2) x f32>, but this is jumping the gun.
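
To make the strided + dilated conv point above concrete, the usual no-padding output-extent formula can be written with an explicit floor. A quick plain-Python check (illustrative only, not taken from D77067):

def conv1d_output_size(input_size, kernel_size, stride=1, dilation=1):
    # The last valid window start i satisfies i + dilation * (kernel_size - 1) <= input_size - 1.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (input_size - effective_kernel) // stride + 1  # floor division

assert conv1d_output_size(10, 3) == 8                        # plain conv: 10 - 3 + 1
assert conv1d_output_size(10, 3, stride=2, dilation=2) == 3  # strided + dilated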

Details and Extra Considerations

Disclaimer: this may or may not be relevant to what you are interested in, but since I never really put these thoughts into words, I might as well write them down here.

On the topic of range inference, I have worked through these possible alternatives in the past. When taking a “loop-free” view of the world, this depends a lot on how one interprets the expression i+j.

Take the example of a simple 1-D conv: O(i) += I(i + j) * W(j).

Side note: once loops are introduced, these considerations are not required anymore (but at that point buffer allocation is already done too).

Triangular Interpretation

One possible interpretation is:

for i = 0, dim(I, 0), 1:
  for j = 0, min(dim(W, 0), dim(I, 0) - i), 1:
    O(i) += I(i + j) * W(j)

This would be an interpretation of:
{forall i, exists j, 0 <= i, 0 <= j, 0 <= i + j <= dim(I, 0), j <= dim(W, 0)}
This is the type of loop nest you would get by using classical polyhedral tools like cloog / ISL given a set specification. This essentially does:

  1. project along i in the parameter space (dim(I, 0), dim(W, 0))
  2. project along j in the parameter space (dim(I, 0), dim(W, 0), i)

Order of projection matters: the loop bounds will be different (but will of course agree with each other) depending on whether you write for i, for j or for j, for i. This is related to loop interchange on triangular loops (see e.g. the Allen & Kennedy textbook).

In this context, O has shape dim(I, 0).

HyperRectangular Interpretation

This is the interpretation we chose for the loop-free Tensor Comprehensions.
This corresponds to a “quantifier inversion”.

for i = 0, dim(I, 0), 1:
  for j = 0, dim(W, 0), 1:
    O(i) += I(i + j) * W(j)

This would be an interpretation of:
{exists i, forall j, 0 <= i, 0 <= j, 0 <= i + j <= dim(I, 0), j <= dim(W, 0)}
Any order of projection gives you the same loop nest; this is why TC defines the ordering of indices to be irrelevant in the computation specification (modulo FP precision in reductions, which is conveniently ignored).

In this context, O has shape dim(I, 0) - dim(W, 0) + 1.
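
To make the shape difference concrete, here is the same computation written out in plain Python (illustrative only; note that the outer loop of the hyper-rectangular version is bounded by the size of O so that i + j stays within I):

I = list(range(10))   # dim(I, 0) = 10
W = [1.0, 2.0, 3.0]   # dim(W, 0) = 3

# Triangular: i ranges over all of I, j is clipped so that i + j stays in bounds.
O_tri = [0.0] * len(I)                     # O has shape dim(I, 0) = 10
for i in range(len(O_tri)):
    for j in range(min(len(W), len(I) - i)):
        O_tri[i] += I[i + j] * W[j]

# Hyper-rectangular: both loop bounds are constant, so i must stop early enough
# that i + j is in bounds for every j.
O_hyp = [0.0] * (len(I) - len(W) + 1)      # O has shape dim(I, 0) - dim(W, 0) + 1 = 8
for i in range(len(O_hyp)):
    for j in range(len(W)):
        O_hyp[i] += I[i + j] * W[j]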

It is interesting to note that affine.parallel_for can only represent hyper-rectangular sets (see my comment in this thread, in particular):

> ### IR Representation
>
> The key thing I see here, and that I think may have been missed, is that you
> have adopted a notion of **HyperRectangular, Parametric Loop Band**: within a
> single `parallel_for` op, you cannot read the induction variables you define.

Inference Procedure

Now the above is a trivial example that should give an intuition. As usual there are tradeoffs and complexities involved. These start to become annoying when multiple solutions are possible and need to agree with each other. These show up with multiple parameters and when one needs to consider different parametric contexts (see the notion of “quasts” in the Feautrier literature).

To perform inference, at least these options exist (each needs to be adapted depending on the triangular / hyper-rectangular interpretation above):

  1. give a set to ISL, project along dimensions (in a particular order if relevant),
  2. use an ILP to find the min / max of some quantity along some dimension under some constraints (this needs to be adapted to the parametric “quast” case),
  3. choose the hyper-rectangular interpretation and implement a type of Gaussian elimination procedure, introducing min/max as appropriate. Here is a “description by example” of what happens in tensor comprehensions; a toy sketch follows this list.
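
To give a flavor of option 3, here is a toy plain-Python version of hyper-rectangular inference for the 1-D conv above (illustrative only and far simpler than what Tensor Comprehensions actually does; it only handles index expressions that are sums of loop variables):

def infer_ranges(accesses, dim_sizes):
    # accesses: list of (tensor_name, [loop_vars]) where the index expression is
    # the sum of the listed loop variables; dim_sizes: tensor_name -> extent.
    # Returns an exclusive upper bound per loop variable (lower bounds are 0).
    ranges = {}
    changed = True
    while changed:
        changed = False
        for tensor, vars_ in accesses:
            unknown = [v for v in vars_ if v not in ranges]
            if len(unknown) != 1:
                continue
            # The max index must satisfy sum(vars) <= dim - 1, so the unknown's
            # exclusive upper bound is dim minus the sum of the known maxima.
            bound = dim_sizes[tensor] - sum(ranges[v] - 1 for v in vars_ if v in ranges)
            ranges[unknown[0]] = bound
            changed = True
    return ranges

# O(i) += I(i + j) * W(j): W(j) fixes j's range, then I(i + j) gives i's.
print(infer_ranges([("I", ["i", "j"]), ("W", ["j"])], {"I": 10, "W": 3}))
# -> {'j': 3, 'i': 8}, i.e. dim(W, 0) and dim(I, 0) - dim(W, 0) + 1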

In practice, I think any of those can lead to surprising behaviors in corner cases, which makes for a difficult user experience: the correctness of your program depends on a non-trivial procedure.

Bottom line: let’s specify per-“named op” output shapes for now in TCP / Linalg and revisit later when we have concrete evidence that this is not enough. This is similar to defining a library call + a shape calculation function, except it is naturally encoded in the named op.


Thanks Nicolas, that’s a lot more complex than I thought. For now I’ll handle the special cases that I see!

I’m not quite seeing the layering here, since I thought that we would do lots of fusion and stuff at the tensor level, which generally fuses named ops into generic ops, thus losing this shape information when we later want to lower to buffers.

Maybe you can explain more how you envision the larger-scale compilation pipeline to work in this regard?

Very good point, the discussion above does not cover linalg.generic.

I am unclear at this time whether this would warrant (1) adding a shape ArrayAttr to linalg.generic on tensors, (2) using a range inference algorithm as we did in the past (with all the disclaimers involved), or (3) something else.

Great. I’ll add it to the list of “things we won’t really be able to judge without a good e2e prototype” :slight_smile:

As prior art here (for folks that don’t know): in IREE we basically require that all ops in a “dispatch region” get fused into a single linalg.generic, and dispatch region formation ensures that; it also does all the buffer size calculations at the ~“named ops” level and passes pre-allocated buffers to the dispatch regions.

In this interpretation, aren’t we doing out-of-bounds accesses on I? Or is the outer loop supposed to go until dim(I, 0) - dim(W, 0) + 1?

Correct, iterator i is in the range of O; I’ll update when back on my laptop. Thanks for spotting!

If we were to use shape functions that are annotated at linalg operations, then I would expect these functions to be parameterized over the shapes and maybe values of input tensors for that operation. If they depend on the values of the input tensor, then that would likely inhibit fusion, as intermediate result values would have multiple references in your program, so fusion would duplicate computation.

If they only depend on shapes, then those computations can be composed independently of the computations of the values. So you can derive the shape of the output of a fused linalg operation by composing the shape functions for the original inputs with the shape function for the consumer linalg operation.
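
For example (a minimal sketch with made-up shape functions, just to illustrate the composition):

def matmul_shape(a_shape, b_shape):
    (m, k1), (k2, n) = a_shape, b_shape
    assert k1 == k2
    return (m, n)

def add_shape(a_shape, b_shape):
    assert a_shape == b_shape
    return a_shape

def fused_matmul_add_shape(a_shape, b_shape, c_shape):
    # Shape of the op computing matmul(A, B) + C fused into one linalg op:
    # compose the producer's shape function with the consumer's.
    return add_shape(matmul_shape(a_shape, b_shape), c_shape)

print(fused_matmul_add_shape((4, 8), (8, 16), (4, 16)))  # -> (4, 16)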

Ultimately, you end up with two separate streams of computation, one for shapes and one for the values. Buffer allocation forces these two to intersect (as do value dependencies on shapes).


Good point! My concern is that once we e.g. fuse the linalg ops into an op whose output shape cannot easily be derived later, we need to maintain the shapes associated with their respective tensors in a way that doesn’t inhibit or overcomplicate other transformations.

One thing we are recently finding in IREE is that IR which has materialized shape functions (using a tie_shape approach) is generally quite fragile, and we only want to materialize it at specific points, right before the passes that really need that information, and then remove it. Effectively, shape-annotated IR seems to be a sublanguage with invariants that are difficult to maintain in the presence of arbitrary passes.

As an example, even a simple canonicalization pattern or lowering won’t work, because the pattern doesn’t know how to materialize the shape computation for any ops it produces in order to re-establish the invariants of the “shape-annotated computations on tensors” sublanguage.

cc @stellaraccident, who did most of the work on this on the IREE side.

What Sean says. When I first started IREE’s shape work, it was with exactly the goal Stephan describes: to have two streams of computation that only merge when needed (for buffer allocation, interfaces to backends, etc.). To get going for POC things, I had one big pass at the beginning that would materialize any dynamic dimensions and associate them with producer values via a tie_shape op. However, like Sean says, this puts severe constraints on transformations (which I knew, but I wanted to get something going).

Now, what we are doing is materializing that only around the transforms that matter: for us, dispatch/stream formation and buffer allocation, all of which are fairly structured and can be taught how to reason about this “shape-tied meta-IR” correctly. Also, materializing everything around this phase gives us a chance to eagerly dedup producers for unknown dimensions and outline them into dispatch regions explicitly.

When proceeding, though, we then remove all of the tie_shapes, only leaving them at “root” producers (i.e. entry block args for outlined dispatch functions). Given the constraints on what we allow in dispatchable functions, it should always be possible to reconstitute any unknown dimensions from these root associations. Effectively, this means that we should be able to always legalize std.dim and shapex.get_ranked_shape to some IR that trivially derives from the root ties to dynamic dimension inputs.

This means that for the most part, on-demand dimension resolution is still allowed during transformations. However, once we lower past a point where these concepts no longer apply, then we can no longer introduce new std.dim/shapex.get_ranked_shape ops during transformations. I believe this point of no return is whenever we enter a conversion phase which may produce ops that do not implement the shape-analysis interface for materializing their shapes. After that point, it becomes non-deterministic as to whether we can re-associate unknown dimensions with the function inputs that ultimately provide them.

Let me know if any of that is not clear and I can prepare something with sample IR to look at which illustrates the various phase ordering/transformation issues and what we’re doing to overcome them. It’d be good to get another set of eyes on it at some point too, as I may be missing some simplification.
