Based on some previous discussions (here), there is interest in making the Linalg dialect less monolithic and in graduating pieces out of it incrementally. As part of this effort, this RFC was proposed to deprecate the `linalg.init_tensor` operation. That RFC has been abandoned for now, but it did bring to the fore some issues with how `linalg.init_tensor` is currently used. In its current state, `linalg.init_tensor` cannot easily be moved out of the Linalg dialect because it is used in two modes:
- To create a tensor whose uses need only the shape and element type of the tensor.
- To create a tensor whose values are undefined, into which values/tensors are inserted using destructive updates.
These are two separate use cases at different levels of the stack, and they need to be decoupled before `linalg.init_tensor` can be deprecated.
The second use case is more relevant to transformations such as tiling on tensors and bufferization, which are typically used lower down the stack. This RFC is not concerned with that use case. Indeed, another RFC ([RFC] Promoting `linalg.init_tensor` to the `bufferize` dialect) was intended to adapt `linalg.init_tensor` to better address it. Decoupling the two use cases will make the solution proposed in that RFC viable.
This RFC is intended to address the first use case. This is typically relevant for front-end dialects (like MHLO and TOSA) and their lowering into Linalg on tensors. For example, the `linalg.pooling_nhwc_sum` operation (and its variants) is represented in IR as:
%window = linalg.init_tensor ... : tensor<?x?xf32>
%output = linalg.pooling_nhwc_sum
ins(%input, %window : tensor<?x?x?x?xf32>, tensor<?x?xf32>)
outs(%output_init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
The `%window` value here is used only to determine the loop bounds of one or more loops of the operation. Its contents are not used as part of the semantics of the operation.
Another example comes from `linalg.generic` operations with broadcasting semantics:
%operand = linalg.init_tensor ... : tensor<?x?xf32>
%result = linalg.generic {
iterator_types = ["parallel", "parallel"],
indexing_maps = [affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>]}
ins(%input : tensor<?xf32>)
outs(%operand : tensor<?x?xf32>) {
^bb0(%arg0 : f32, %arg1 : f32) :
linalg.yield %arg0 : f32
} -> tensor<?x?xf32>
Here `%result` is a tensor created by broadcasting values from `%input`. The shape of `%result` is the same as that of `%operand`, but the values in the `%operand` tensor are not used by the operation. This is indicated by the fact that `%arg1` of the region of the Linalg operation is not used. (The method `payloadUsesValueFromOperand` of the `LinalgInterface` provides an easy way to find such operands.)
In all such cases a `linalg.init_tensor` operation is used to create the `tensor` value without associating any value with its elements. A better way to model this, though, is a separate type that defines only the shape and element type of a tensor. This RFC aims to address this by adding an `abstract_tensor` type to the `builtin` types. The `builtin.abstract_tensor` type would be similar to the `builtin.tensor` type, but would represent a tensor shape without associating any data with its elements. A new operation will be added to the `tensor` dialect that creates a value of `abstract_tensor` type.
%0 = tensor.create_abstract_tensor [%d0, %d1] : builtin.abstract_tensor<?x?xf32>
creates a 2D `abstract_tensor` (of dimensions `%d0` and `%d1`) with element type `f32`.
The IR examples above would then be represented as:
%d0 = arith.constant ... : index
%d1 = arith.constant ... : index
%window = tensor.create_abstract_tensor [%d0, %d1] : abstract_tensor<?x?xf32>
%output = linalg.pooling_nhwc_sum
ins(%input, %window : tensor<?x?x?x?xf32>, abstract_tensor<?x?xf32>)
outs(%output_init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
%d0 = arith.constant ... : index
%d1 = arith.constant ... : index
%result_shape = tensor.create_abstract_tensor [%d0, %d1] : abstract_tensor<?x?xf32>
%result = linalg.generic {
iterator_types = ["parallel", "parallel"],
indexing_maps = [affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>]}
ins(%input : tensor<?xf32>)
outs(%result_shape : abstract_tensor<?x?xf32>) {
^bb0(%arg0 : f32, %arg1 : f32) :
linalg.yield %arg0 : f32
} -> tensor<?x?xf32>
Work plan
To avoid sweeping changes, the steps would be as follows:
- Add the `builtin.abstract_tensor` type and the `tensor.create_abstract_tensor` operation.
- Replace all cases where the result of `linalg.init_tensor` is used in the `ins` list of Linalg operations, i.e. allow the use of a `Value` of type `builtin.abstract_tensor` in the `ins` list of Linalg operations. This addresses uses like the `linalg.pooling_*` operations above.
- Replace uses of `linalg.init_tensor` in the `outs` list of Linalg operations, i.e. allow the use of a `Value` of type `builtin.abstract_tensor` in the `outs` list of Linalg operations. This can have an impact lower down the stack. It ties into the use case where `linalg.init_tensor` is used as a tensor with undefined values. Specifically, tiling Linalg ops on tensors using `scf.for` uses destructive updates. For example,
  %operand = linalg.init_tensor ... : tensor<?x?xf32>
  %result = linalg.generic {
  iterator_types = ["parallel", "parallel"],
  indexing_maps = [affine_map<(d0, d1) -> (d1)>,
  affine_map<(d0, d1) -> (d0, d1)>]}
  ins(%input : tensor<?xf32>)
  outs(%operand : tensor<?x?xf32>) {
  ^bb0(%arg0 : f32, %arg1 : f32) :
  linalg.yield %arg0 : f32
  } -> tensor<?x?xf32>
  gets tiled to
  %operand = linalg.init_tensor ... : tensor<?x?xf32>
  %result = scf.for %iv0 = ... iter_args(%arg0 = %operand) {
  ...
  }
  After bufferization, the `linalg.init_tensor` gets converted to a `memref.alloc`. To avoid cascading effects from the changes proposed here, given the following IR
  %result_shape = tensor.create_abstract_tensor [%d0, %d1] : abstract_tensor<?x?xf32>
  %result = linalg.generic {
  iterator_types = ["parallel", "parallel"],
  indexing_maps = [affine_map<(d0, d1) -> (d1)>,
  affine_map<(d0, d1) -> (d0, d1)>]}
  ins(%input : tensor<?xf32>)
  outs(%result_shape : abstract_tensor<?x?xf32>) {
  ^bb0(%arg0 : f32, %arg1 : f32) :
  linalg.yield %arg0 : f32
  } -> tensor<?x?xf32>
  the Linalg tiling transformation will create a `linalg.init_tensor` for `outs` operands of type `builtin.abstract_tensor` and use that as the `iter_args` operand of the `scf.for`. This will insulate the downstream from any changes.
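As a rough illustration of the last step, the tiled IR for the broadcast example might look something like the sketch below. It is only a sketch under several assumptions: `abstract_tensor` and `tensor.create_abstract_tensor` are the type and operation proposed here and do not exist yet; the names `%init`, `%c0`, `%ts`, `%sz`, `%out_tile`, `%tile`, `%updated`, `%b0` and `%b1` are made up for the example; tiling only the outer loop is an arbitrary choice; and reading the `linalg.init_tensor` sizes directly from the operands of `tensor.create_abstract_tensor` is a guess at how the transformation would recover the shape.

```mlir
// Shape-only value produced by the front-end lowering (proposed op, not upstream).
%result_shape = tensor.create_abstract_tensor [%d0, %d1] : abstract_tensor<?x?xf32>

// Materialized by tiling: a real tensor with undefined values and the same shape.
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>

// Outer loop only. %c0 (index constant 0) and %ts (tile size) are assumed
// to be defined earlier.
%result = scf.for %iv0 = %c0 to %d0 step %ts
    iter_args(%arg0 = %init) -> (tensor<?x?xf32>) {
  // %sz: number of rows in this tile (boundary computation elided).
  %sz = ... : index
  %out_tile = tensor.extract_slice %arg0[%iv0, 0] [%sz, %d1] [1, 1]
      : tensor<?x?xf32> to tensor<?x?xf32>
  // Same payload as the untiled op, now writing into the tile.
  %tile = linalg.generic {
      iterator_types = ["parallel", "parallel"],
      indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>]}
      ins(%input : tensor<?xf32>)
      outs(%out_tile : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32) :
      linalg.yield %b0 : f32
  } -> tensor<?x?xf32>
  // Destructive update: insert the tile back into the loop-carried tensor.
  %updated = tensor.insert_slice %tile into %arg0[%iv0, 0] [%sz, %d1] [1, 1]
      : tensor<?x?xf32> into tensor<?x?xf32>
  scf.yield %updated : tensor<?x?xf32>
}
```

The point of the sketch is that the loop still carries a tensor created by `linalg.init_tensor`, so bufferization and other downstream passes would see the same loop structure they see today.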
After this RFC, all remaining uses of `linalg.init_tensor` will be related to its use as a tensor with undefined values. This use can be deprecated by using ideas from this RFC.