Range inference is an interesting topic. I’ll start with my personal conclusion and then unpack some of the details that led to it.
My personal TL;DR is that it is generally not worth the trouble for most of the cases we care about with “named ops” in TCP/Linalg. The implementation in https://reviews.llvm.org/D77067 expects the output shape to be specified explicitly (unlike in Tensor Comprehensions, where ranges are inferred automatically). For example:
```
def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
}
```
The output shape of C is explicit; it is parsed and saved (not yet connected to the “named” op, but it will be).
This supports affine expressions today and can be extended in the future to take e.g. piecewise-affine expressions or more general MLIR function symbols.
Note that this is on a per-op basis, and it is possible to use ceil/floor/mod to express the output shape of a strided + dilated conv, for instance. I expect this specification will make it easy to evolve the tensor → buffer translation to dynamic shapes on a per-op basis. This hints at the fact that I don’t think a Linalg_CalculateResultShapesOp is the right way to go, but I have not yet unpacked the implications of the existence of such an op.
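For illustration only (the helper name, signature, and numbers below are mine, not part of the patch), the kind of shape rule such a per-op spec encodes for a strided + dilated 1-D conv boils down to a floor expression:

```python
# Hypothetical sketch, not the actual named-op spec: the per-op output-shape
# rule for a strided + dilated 1-D "valid" conv, written as a plain function.
def conv1d_output_size(input_size, kernel_size, stride, dilation):
    # Footprint of the kernel once dilation spreads its taps apart.
    effective_kernel = dilation * (kernel_size - 1) + 1
    # The floor division is exactly where ceil/floor/mod support is needed.
    return (input_size - effective_kernel) // stride + 1

print(conv1d_output_size(16, 3, stride=2, dilation=2))  # 6
```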
This specification also hints at future extensions where we should make some memref types more structured and distinguish between, say, `memref<?x?xf32>` and `memref<N x (N / 2) x f32>`, but this is jumping the gun.
Details and Extra Considerations
Disclaimer: This may or may not be relevant to what you are interested in, but since I never really put these thoughts into words, I might as well write them down here.
On the topic of range inference, I have worked through these possible alternatives in the past. When taking a “loop-free” view of the world, this depends a lot on how one interprets the expression i + j.
Take the example of a simple 1-D conv: O(i) += I(i + j) * W(j).
Side note: once loops are introduced, these considerations are not required anymore (but at that point buffer allocation is already done too).
Triangular Interpretation
One possible interpretation is:
```
for i = 0, dim(I, 0), 1:
  for j = 0, min(dim(W, 0), dim(I, 0) - i), 1:
    O(i) += I(i + j) * W(j)
```
This would be an interpretation of:
{forall i, exists j, 0 <= i, 0 <= j, i + j <= dim(I, 0), j <= dim(W, 0)}
This is the type of loop nest you would get by using classical polyhedral tools like cloog / ISL given a set specification. This essentially does:
- project along i in the parameter space (dim(I, 0), dim(W, 0))
- project along j in the parameter space (dim(I, 0), dim(W, 0), i)
Order of projection matters: the loop bounds will be different (but will of course agree with each other) depending on whether you write `for i, for j` or `for j, for i`. This is related to loop interchange on triangular loops (see e.g. the Allen & Kennedy textbook).
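As a quick illustration (my own toy sizes, written as executable Python rather than the pseudocode above), interchanging the triangular nest changes the syntactic bounds but not the set of points:

```python
# Toy check: the interchanged (j outer, i inner) triangular nest enumerates
# exactly the same (i, j) points as the (i outer, j inner) version above.
dim_I, dim_W = 8, 3

interchanged = set()
for j in range(0, min(dim_W, dim_I)):
    for i in range(0, dim_I - j):
        interchanged.add((i, j))

original = {(i, j) for i in range(dim_I) for j in range(min(dim_W, dim_I - i))}
assert interchanged == original
```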
In this context, O has shape dim(I, 0).
HyperRectangular Interpretation
This is the interpretation we chose for the loop-free Tensor Comprehensions.
This corresponds to a “quantifier inversion”.
```
for i = 0, dim(I, 0), 1:
  for j = 0, dim(W, 0), 1:
    O(i) += I(i + j) * W(j)
```
This would be an interpretation of:
{exists i, forall j, 0 <= i, 0 <= j, i + j <= dim(I, 0), j <= dim(W, 0)}
Any order of projection gives you the same loop nest (this is why TC defines the ordering of indices to be irrelevant in the computation specification, modulo FP precision in reductions, which is conveniently ignored).
In this context, O has shape dim(I, 0) - dim(W, 0) + 1.
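To make the difference concrete, here is a small numeric check (toy values of my own, not from the discussion above) of the two shape consequences for the 1-D conv:

```python
# Toy contrast of the two interpretations of O(i) += I(i + j) * W(j).
I = [1.0, 2.0, 3.0, 4.0, 5.0]
W = [1.0, 1.0, 1.0]
dim_I, dim_W = len(I), len(W)

# Triangular interpretation: O has shape dim(I, 0); trailing outputs
# accumulate fewer products.
O_tri = [0.0] * dim_I
for i in range(dim_I):
    for j in range(min(dim_W, dim_I - i)):
        O_tri[i] += I[i + j] * W[j]

# Hyper-rectangular interpretation: O has shape dim(I, 0) - dim(W, 0) + 1
# (a "valid" convolution).
O_rect = [0.0] * (dim_I - dim_W + 1)
for i in range(dim_I - dim_W + 1):
    for j in range(dim_W):
        O_rect[i] += I[i + j] * W[j]

print(O_tri)   # [6.0, 9.0, 12.0, 9.0, 5.0]
print(O_rect)  # [6.0, 9.0, 12.0]
```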
It is interesting to note that `affine.parallel_for` can only represent hyper-rectangular sets (see my comment in this thread, in particular):
> ### IR Representation
> The key thing I see here, and that I think may have been missed, is that you have adopted a notion of **HyperRectangular, Parametric Loop Band**: within a single `parallel_for` op, you cannot read the induction variables you define.
Inference Procedure
Now, the above is a trivial example that should give an intuition. As usual, there are tradeoffs and complexities involved. These start to become annoying when multiple solutions are possible and need to agree with each other, which happens with multiple parameters and when one needs to consider different parametric contexts (see the notion of “quasts” in the Feautrier literature).
To perform inference, at least these options exist (each needs to be adapted depending on the triangular / hyper-rectangular interpretation above):
- give a set to ISL, project along dimensions (in a particular order if relevant),
- use an ILP to find the min / max of some quantity along some dimension under some constraints (needs to be adapted to the parametric “quast” case),
- choose the hyper-rectangular interpretation and implement a type of Gaussian elimination procedure, introducing min/max as appropriate. Here is a “description by example” of what happens in Tensor Comprehensions.
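As a rough, self-contained sketch of that last option (my own toy code, not the actual TC implementation), eliminating one index at a time while substituting extreme values of the already-resolved indices recovers the expected ranges for the 1-D conv:

```python
# Toy elimination sketch for O(i) += I(i + j) * W(j) under the
# hyper-rectangular interpretation; not the TC implementation.
def infer_ranges(dim_I, dim_W):
    # j appears directly as an access into W, so its range is immediate.
    j_min, j_max = 0, dim_W - 1
    # i only appears through I(i + j); it must stay in bounds for *all* j,
    # so substitute the extreme values of j (this is where min/max enters).
    i_min = 0 - j_min
    i_max = (dim_I - 1) - j_max
    return (i_min, i_max), (j_min, j_max)

# O then gets shape i_max - i_min + 1 == dim_I - dim_W + 1.
print(infer_ranges(8, 3))  # ((0, 5), (0, 2))
```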
In practice, I think any of those can lead to surprising behaviors in corner cases which makes it a difficult user experience: the correctness of your program depends on a non-trivial procedure.
Bottom line: let’s specify per-“named op” output shapes for now in TCP / Linalg and revisit later when we have concrete evidence that this is not enough. This is similar to defining a library call + a shape calculation function, except it is naturally encoded in the named op.