Here are some high-order bits that have been tricky when using Linalg on tensors.
This relates to the ongoing discussion on Discourse as well as the IREE discussion.
@stellaraccident @jpienaar @_sean_silva @benvanik @hanchung
Buffer land
All inputs and outputs are materialized SSA values of type memref<?x4x?x42xf32, layout>. At this point, transformations have enough information to create subsets of data and computations using existing infrastructure that composes and canonicalizes properly.
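For illustration, a minimal buffer-land sketch, using the op spellings mentioned elsewhere in this post (subview, linalg.generic); the exact subview result layouts and assembly syntax are indicative:

```mlir
// Take subsets of the data with subview and apply the computation to the subsets.
#strided = affine_map<(d0)[s0] -> (d0 + s0)>
#id = affine_map<(d0) -> (d0)>

func @process_tile(%A: memref<?xf32>, %O: memref<?xf32>, %i: index) {
  // Subsets of the data: 128-element windows starting at %i.
  %sA = subview %A[%i] [128] [1] : memref<?xf32> to memref<128xf32, #strided>
  %sO = subview %O[%i] [128] [1] : memref<?xf32> to memref<128xf32, #strided>
  // Subset of the computation: a pointwise op restricted to the window.
  linalg.generic {indexing_maps = [#id, #id], iterator_types = ["parallel"]}
    ins(%sA : memref<128xf32, #strided>)
    outs(%sO : memref<128xf32, #strided>) {
  ^bb0(%a: f32, %o: f32):
    %0 = mulf %a, %a : f32
    linalg.yield %0 : f32
  }
  return
}
```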
Note that certain operations use a “fake” memref / a hollow memref whose data payload is ignored: linalg.max_pooling ins(%I: memref<?x?x8x16xf32>, %W: memref<2x3xf32>), outs(%O: memref<?x?x7x14xf32>) (here %W only conveys the 2x3 pooling window). This essentially captures shape information, after observing that a pooling op is a conv op with a functional kernel.
Tensor land
Currently, all inputs are materialized but only a subset of outputs are (i.e. the init_tensor operands). These are represented as SSA values of type tensor<?x4x?x42xf32>.
Linalg ops often (but not always…) have enough information to derive/reify/infer the output shape from the inputs. Examples: linalg ops that represent the computation O(i,j) <- f(A(i, j), B(j, i)), pointwise + permute ops, certain broadcast semantics O(i, j) <- f(A(i), B(j, i, k)), etc.
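Concretely, for the O(i, j) <- f(A(i, j), B(j, i)) case, every loop dimension (and hence every output extent) can be read off an input. A minimal sketch using the std dim op (exact syntax indicative):

```mlir
// Sizes of the loops (and thus of O) recovered from the inputs only.
%c0 = constant 0 : index
%di = dim %A, %c0 : tensor<?x?xf32>  // loop i: A's dimension 0
%dj = dim %B, %c0 : tensor<?x?xf32>  // loop j: B's dimension 0 (B is indexed as B(j, i))
// O(i, j) therefore has sizes (%di, %dj), e.g. a tensor<?x?xf32> with those extents.
```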
Certain ops may look like they have enough information but they really don’t. Examples:
- linalg.constant(%0) : tensor<4x8xf32> “works” but linalg.constant(%0) : tensor<?x8xf32> lacks information.
- linalg.broadcast(%0) : tensor<4xf32> into tensor<4x8xf32> “works” but linalg.broadcast(%0) : tensor<4xf32> into tensor<4x?xf32> lacks information.

* It is left to the reader to derive the full linalg.generic semantics needed to implement constant and broadcast; these implementation details are not relevant here.
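Without spelling out a full implementation, a hedged sketch of the broadcast case via its indexing maps (maps are indicative, assuming a generic-style lowering) shows why the information is missing:

```mlir
// O(i, j) = A(i): broadcast of a 1-D tensor along a new dimension j.
indexing_maps = [
  affine_map<(i, j) -> (i)>,    // A: constrains only the extent of i
  affine_map<(i, j) -> (i, j)>  // O: the only map in which j appears
]
// With O : tensor<4x8xf32>, the extent of j is 8 by construction; with
// O : tensor<4x?xf32>, nothing in the input determines it.
```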
The above creates a landscape in which:
- many but not all ops have enough information to implement a simple inference and not worry about specifying output shapes.
- it feels unsatisfactory to have to specify, say, 10 output shapes (which can easily happen after fusion of pointwise ops) that could all be derived automatically.
- it is unreasonable to define op semantics / the verifier based on whether the shape inference procedure (a Gaussian elimination process) succeeds or not (i.e. whether the map between shapes and loops depends on any output dimension). We already had this problem in Tensor Comprehensions land and did not come to a general and unsurprising solution that could be exposed to a user.
Shape dialect
The shape dialect type is rank-, elemental-type- and dimension-erased: !shape.shape. All static information is captured via op semantics + IR structure (not via type or attribute information). This is factored out via InferTypeOpInterface.
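For illustration, a rough sketch using shape dialect op names (exact assembly syntax is indicative): a fully static and a partially dynamic shape both carry the same erased type, with the static extents living in the ops and the surrounding IR rather than in the type:

```mlir
// Both values have the same erased type !shape.shape.
%static  = shape.const_shape [4, 8] : !shape.shape              // extents encoded by the op
%dynamic = shape.shape_of %t : tensor<?x8xf32> -> !shape.shape  // extents implied by IR structure
```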
Stepping forward
linalg.generic is a common lower-level form that has to support all cases, so it must not be concerned with implicit shape inference rules: explicit is better than implicit. linalg.generic needs to require all output information to be available in order to support all cases. It is counterproductive to try to be “smart” or “nice to use” at this level of abstraction; shape inference is a concern that must already have been resolved by then. If one wants nicer user-facing semantics, it is the role of named ops to provide a form in which only the necessary information needs to be spelled out and in which custom shape inference behavior can be specified.
An easy path forward is to simply force every linalg.generic to take an output tensor for all output shapes, and to have an op that creates a tensor out of a list of SSA values representing the shape. This will allow unifying the structured ops interface and make it less surprising in tensor land.
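A minimal sketch of this path, assuming a hypothetical op (spelled linalg.init_tensor here) that creates a tensor from a list of SSA extents; op spellings and exact syntax are indicative:

```mlir
// Create the output "shape carrier" from SSA values (hypothetical op), then
// pass it explicitly: no implicit shape inference in linalg.generic.
%c0 = constant 0 : index
%d0 = dim %A, %c0 : tensor<?x4xf32>
%init = linalg.init_tensor [%d0, 4] : tensor<?x4xf32>

%res = linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>,
                     affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]}
  ins(%A : tensor<?x4xf32>)
  outs(%init : tensor<?x4xf32>) {
^bb0(%a: f32, %o: f32):
  %0 = mulf %a, %a : f32
  linalg.yield %0 : f32
} -> tensor<?x4xf32>
```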
A big advantage is that all transformations, ops and canonicalizations will work out of the box operating on tensor<?x4xf32>. This will also unify the tensor and memref behaviors, which currently have subtle OpInterface differences.
This is by no means enough: every tensor will be bufferized to a new alloc even when there is no need to. In the short term, a simple analysis can traverse the IR, determine whether allocations are indeed necessary, and try to avoid them. But this hints at making the “hollow memref” type mentioned in the Buffer land paragraph a first-class citizen so that the information is not lost.
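To illustrate the allocation issue, a hand-written sketch (not actual pass output) of what a naive bufferization of the generic above would produce:

```mlir
// The init_tensor turns into a fresh allocation, even when the result could
// have been written in place into an existing buffer.
%c0 = constant 0 : index
%d0 = dim %A, %c0 : memref<?x4xf32>
%out = alloc(%d0) : memref<?x4xf32>   // new alloc with no proven need
linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>,
                     affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]}
  ins(%A : memref<?x4xf32>)
  outs(%out : memref<?x4xf32>) {
^bb0(%a: f32, %o: f32):
  %0 = mulf %a, %a : f32
  linalg.yield %0 : f32
}
```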
Such a hollow type could be a new ShapedType, e.g. ranked_shape, or more simply an extra bit on the existing shaped types, and would ensure aligned_ptr == alloc_ptr == nullptr.
E.g. with just a bit we could have tensor<?x4xf32, shape_only=1> and memref<8x?xf32, shape_only=1>. Values of such types could be used in the shape dialect and bridge the gap to linalg / subview / subtensor / subtensor_insert. A memref<..., shape_only=1> may not be loaded from or stored into. The nice thing is that such a type extension would compose naturally with all ops, canonicalizations and transformations. But this now gets into shape dialect territory. The downside is that it is a static form of shape that the shape dialect does not have (yet), and it is unclear whether this is a desirable addition.
A better path forward at that point may be, as usual, to use OpInterfaces and factor out the properties needed for linalg op representation and transformations. It is unlikely, however, that linalg will want to evolve into supporting rank-erased ops at this point, so we would probably need something like:
```mlir
shape.to_hollow %s : !shape.shape, tensor<?x4xf32, shape_only=1>
shape.to_hollow %s : !shape.shape, memref<8x?xf32, shape_only=1> // no layout
```
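Continuing the hypothetical spelling, such a hollow value could then feed the existing subset ops directly, carrying shape information only:

```mlir
// Hypothetical continuation: the hollow memref composes with existing ops.
%h = shape.to_hollow %s : !shape.shape, memref<8x?xf32, shape_only=1>
%v = subview %h[0, %i] [8, 4] [1, 1]
    : memref<8x?xf32, shape_only=1> to memref<8x4xf32, shape_only=1>
// %v may be used for shape reasoning or as an output "shape carrier", but
// never loaded from or stored into.
```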
Many longer-term design alternatives are possible. They relate to the traditional question of representing static information with (a) types, (b) attributes, (c) op semantics, or (d) structure in the IR, as well as to canonicalizations, shape_cast propagation, how that mixes with transformations, and how much we want to traverse the IR to retrieve the information necessary for transformations vs. encode it in types.