Summary
The current implementation of the TOSA-to-Linalg pass lowers element-wise arithmetic operations incorrectly for certain combinations of input tensor shapes. While some issues are isolated bugs, others stem from the lack of a well-defined specification for broadcast semantics, especially when dynamic dimensions are involved.
In this document, we identify scenarios in which the current implementation breaks. We then propose an interpretation of the TOSA standard for broadcast semantics when applied to MLIR tensors with dynamic dimensions. Finally, we propose a general TOSA-to-Linalg lowering strategy for unary, binary, and ternary element-wise arithmetic ops.
An implementation for the modifications presented in this document is available for review at https://reviews.llvm.org/D153291.
Current lowering strategy
The existing TOSA-to-Linalg lowering for element-wise arithmetic ops mishandles broadcast semantics and dynamic dimensions in several cases. Below are examples of correct and incorrect conversions for different combinations of 2D tensor shapes in the tosa.add op.
Combination <?x?>, <?x?> (correct)
TOSA code
func.func @main(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
converts to
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @main(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%2 = arith.addf %in, %in_1 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
}
This conversion is correct. It assumes that tensors %arg0 and %arg1 have the same size, and produces undefined behavior otherwise. The conversion does not account for broadcast semantics, i.e., the possibility that a specific dimension is 1 in one tensor and larger than 1 in the other. We may consider this acceptable behavior - see Section Broadcast semantics for dynamic dimensions below.
Combination <1x?>, <?x?> (correct)
TOSA code
func.func @main(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<1x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
converts to
#map = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @main(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
%dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%0 = tensor.empty(%dim_0, %dim) : tensor<?x?xf32>
%collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %arg1 : tensor<?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%2 = arith.addf %in, %in_1 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
}
This lowering honors the broadcast semantics of %arg0.dims[0]. The broadcast dimension is eliminated through a tensor.collapse_shape op, and the affine map used to access %arg0 then omits d0 on its right-hand side.
Combination <2x3>, <4x3> (correct)
TOSA code
func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
fails verification with the following message:
error: 'tosa.add' op operands don't have broadcast-compatible shapes
%0 = "tosa.add"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<?x?xf32>
^
add.tosa.mlir:3:10: note: see current operation: %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<?x?xf32>
This is the expected behavior.
Combination <1x5>, <3x5> (correct)
TOSA code
func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<1x5xf32>, tensor<3x5xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
converts to
#map = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<?x?xf32> {
%0 = tensor.empty() : tensor<3x5xf32>
%collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x5xf32> into tensor<5xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %arg1 : tensor<5xf32>, tensor<3x5xf32>) outs(%0 : tensor<3x5xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<3x5xf32>
%cast = tensor.cast %1 : tensor<3x5xf32> to tensor<?x?xf32>
return %cast : tensor<?x?xf32>
}
}
The broadcast semantics of the static dimension %arg0.dims[0] are honored and lowered correctly.
Combination <3x5>, <3x5> (correct)
TOSA code
func.func @main(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
converts to
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @main(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<?x?xf32> {
%0 = tensor.empty() : tensor<3x5xf32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<3x5xf32>, tensor<3x5xf32>) outs(%0 : tensor<3x5xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<3x5xf32>
%cast = tensor.cast %1 : tensor<3x5xf32> to tensor<?x?xf32>
return %cast : tensor<?x?xf32>
}
}
The lowering is correct when all input dimensions are static and both input tensors have matching shapes.
Combination <2x?>, <?x?> (incorrect)
TOSA code
func.func @main(%arg0: tensor<2x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<2x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
converts to
#map = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @main(%arg0: tensor<2x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c1 : tensor<2x?xf32>
%dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%0 = tensor.empty(%dim_0, %dim) : tensor<?x?xf32>
%collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %arg1 : tensor<?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%2 = arith.addf %in, %in_1 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
}
This lowering is wrong. Elements %arg0[0, 0], %arg0[0, 1], … %arg0[0, n - 1], %arg0[1, 0], %arg0[1, 1], …, %arg0[1, n - 1] are collapsed into %collapsed[0], %collapsed[1], … %collapsed[2n - 1]. The iteration space in linalg.generic is then ambiguous. According to #map, d1 = [0:2n]; according to #map1, d1 = [0:n]. If #map is honored, %arg1 is accessed out of bounds; if #map1 is honored, half of the elements of %collapsed are never accessed.
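The conflicting iteration bounds can be sketched in Python (a hypothetical model of the two affine maps, not part of the pass):

```python
def implied_d1_bounds(n):
    """Iteration bounds implied for d1 by each affine map, given a
    runtime size n for the dynamic dimension of %arg0."""
    # %collapsed holds 2 * n elements; #map = (d0, d1) -> (d1) reads it
    # with d1 alone, so visiting every element needs d1 in [0, 2n).
    bound_from_map = 2 * n
    # %arg1 and the output are accessed via #map1 = (d0, d1) -> (d0, d1),
    # which needs d1 in [0, n).
    bound_from_map1 = n
    return bound_from_map, bound_from_map1

# The two bounds disagree for every positive runtime size, so the
# iteration space of the linalg.generic op is ambiguous.
assert all(implied_d1_bounds(n)[0] != implied_d1_bounds(n)[1]
           for n in range(1, 100))
```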
Combination <2x2>, <?x?> (incorrect)
Attempting to convert this TOSA code
func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
produces the following compilation error:
add.tosa.mlir:3:10: error: 'tosa.reshape' op Cannot reshape 4 elements into 1
%0 = "tosa.add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
^
add.tosa.mlir:3:10: note: see current operation: %5 = "tosa.reshape"(%arg0) <{new_shape = array<i64>}> : (tensor<2x2xf32>) -> tensor<f32>
// -----// IR Dump After TosaToLinalg Failed (tosa-to-linalg) //----- //
"func.func"() <{function_type = (tensor<2x2xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "main"}> ({
^bb0(%arg0: tensor<2x2xf32>, %arg1: tensor<?x?xf32>):
%0 = "arith.constant"() <{value = 0 : index}> : () -> index
%1 = "tensor.dim"(%arg1, %0) : (tensor<?x?xf32>, index) -> index
%2 = "arith.constant"() <{value = 1 : index}> : () -> index
%3 = "tensor.dim"(%arg1, %2) : (tensor<?x?xf32>, index) -> index
%4 = "tensor.empty"(%1, %3) : (index, index) -> tensor<?x?xf32>
%5 = "tosa.reshape"(%arg0) <{new_shape = array<i64>}> : (tensor<2x2xf32>) -> tensor<f32>
...
The lowering pass is attempting to produce a tosa.reshape op with an invalid target shape.
Combination <?x2>, <2x?> (incorrect)
TOSA code
func.func @main(%arg0: tensor<?x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<?x2xf32>, tensor<2x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
converts to
#map = affine_map<(d0, d1) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d1)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @main(%arg0: tensor<?x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<?x?xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x2xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<2x?xf32>
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
%collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x2xf32> into tensor<?xf32>
%collapsed_1 = tensor.collapse_shape %arg1 [[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %collapsed_1 : tensor<?xf32>, tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%2 = arith.addf %in, %in_2 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
}
If the use of tosa.add is valid, it can be deduced that both %arg0 and %arg1 have dimensions 2x2. This is the shape of the output tensor %0 returned by tensor.empty. According to #map2, the iteration space for linalg.generic is then d0 = [0:2] and d1 = [0:2]. However, according to #map, d0 = [0:collapsed.size] = [0:4]. According to #map1, d1 = [0:collapsed_1.size] = [0:4]. The iteration space for linalg.generic is therefore ambiguous and its use leads to undefined behavior.
Enforcement of rank equality
The TOSA spec requires input tensors in element-wise operations to have the same rank. Op verifiers currently do not enforce this requirement, which allows the following tosa.add op to be parsed and printed successfully.
%result = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
This leniency has enabled a current bug in the TensorFlow project, motivated by the fact that similar element-wise operations in the tf (TensorFlow) and tfl (TensorFlow Lite) dialects support implicit rank expansion. The tosa.add op above could then be the result of applying a TFL-to-TOSA conversion pass to the following correct tfl.add op:
%result = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} :
(tensor<3x4xf32>, tensor<2x3x4xf32>) ->
tensor<2x3x4xf32>
This RFC proposes to add the AllRanksMatchIfKnown trait to element-wise TOSA operations in order to enforce equal ranks for all of an operation's operands and its result, if ranked. The RFC hints at simultaneous fixes in TF/TFL lowering passes in the TensorFlow repository that will adjust to this requirement. Once this implementation lands in LLVM, we will no longer need to worry about encountering incompatible ranks when lowering an element-wise TOSA op.
For now, our TOSA-to-Linalg pass aligns with the existing verification logic and supports inputs of different ranks. The pass relies on an initial rank-equalization step that prepends dimensions of size 1 to input tensor shapes until they all reach a common rank. A tensor.expand_shape operation is used for this purpose. In the example above, the 2D tensor %arg0 of type tensor<3x4xf32> is first expanded into a 3D tensor as follows:
%reshaped_arg0 = tensor.expand_shape %arg0 [[0, 1], [2]] :
tensor<3x4xf32> into tensor<1x3x4xf32>
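The rank-equalization step amounts to the following shape computation, sketched here over plain Python shape tuples (an illustration of the behavior, not the pass implementation):

```python
def equalize_ranks(shapes):
    """Prepend size-1 dimensions until all shapes share the maximum rank."""
    max_rank = max(len(s) for s in shapes)
    return [(1,) * (max_rank - len(s)) + tuple(s) for s in shapes]

# Mirrors the tensor.expand_shape example: <3x4> is expanded to <1x3x4>
# so that it matches the rank of <2x3x4>.
shapes = equalize_ranks([(3, 4), (2, 3, 4)])
assert shapes == [(1, 3, 4), (2, 3, 4)]
```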
The Broadcastable trait (a.k.a. ResultsBroadcastableShape)
Description
The Broadcastable op trait is currently used only by TOSA element-wise ops in the LLVM repository. It is also used in the TensorFlow repository for the definition of element-wise ops with broadcasting semantics in the tf and tfl dialects.
The trait serves three main purposes: input shape verification, result shape inference, and result shape verification. The properties and functionality provided by this trait are described in detail in a new documentation file at mlir/docs/Traits/Broadcastable.md in the proposed patch (see Phabricator link above). While most of the file content is a formalization of previously existing functionality provided by the Broadcastable op trait, some of its features have been modified, as described in the following sections.
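A simplified Python model of the per-dimension inference rule, consistent with the examples later in this document (DYNAMIC stands for a `?` dimension; this is an illustrative sketch, not the C++ implementation):

```python
DYNAMIC = None  # models a '?' dimension

def infer_dim(a, b):
    """Infer the result size of one dimension from two input sizes."""
    if a == b:          # equal sizes (including both dynamic) pass through
        return a
    if a == 1:          # a static 1 broadcasts to the other dimension
        return b
    if b == 1:
        return a
    if a is DYNAMIC:    # dynamic paired with static > 1 infers the static
        return b        # size; the dynamic side may still be 1 at runtime
    if b is DYNAMIC:
        return a
    raise ValueError(f"incompatible dimensions {a} and {b}")

def infer_shape(lhs, rhs):
    """Infer the result shape of two equal-rank input shapes."""
    assert len(lhs) == len(rhs), "equal ranks required"
    return tuple(infer_dim(a, b) for a, b in zip(lhs, rhs))
```

For instance, this model infers `(5,)` for inputs `(5,)` and `(?,)`, matching the dynamic-to-static broadcasting example below, and `(?,)` when both inputs are fully dynamic.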
Modification 1: Removing broadcast semantics for result dimensions
The current implementation of the Broadcastable trait assigns broadcast semantics to the inferred result type. For example, the operation
%0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<5xf32>
is considered legal. The inferred result type is tensor<1xf32>, while the actual return type is tensor<5xf32>. The expected behavior is that the single result element is broadcast to the 5 positions of the actual result tensor. This implicit behavior adds a level of obscurity to the operation semantics, and is therefore disabled in the proposed implementation.
A programmer may attain a similar outcome by leveraging the broadcast semantics of input dimensions. If the desired behavior is using the scalar result of the tosa.add operation above in a subsequent 5-element tensor operation, the consumer may look like this:
%1 = "tosa.sub"(%arg2, %0) : (tensor<5xf32>, tensor<1xf32>) -> tensor<5xf32>
Instead of having the original tosa.add producer op materialize a 5-element tensor, the tosa.sub consumer op implicitly broadcasts the scalar value for every subtraction.
If broadcasting a scalar value into a tensor is deemed a desirable stand-alone feature by the TOSA community, it might be worth considering the introduction of a dedicated tosa.broadcast operation for this purpose.
Modification 2: Forbidding implicit dynamic-to-static dimension cast in result dimensions
The current implementation of the Broadcastable trait allows an inferred result dimension to be dynamic while the actual result dimension is given a static size. The following op would then be considered legal:
%0 = "tosa.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<5xf32>
The inferred result type for this op is tensor<?xf32>, as a consequence of all input dimensions being dynamic. If the result is known to have 5 elements at compile time, the programmer must necessarily know that at least one input dimension is of size 5, too. It is therefore counterintuitive to allow all input dimensions to be dynamic in this case. The proposed implementation disallows this implicit dynamic-to-static conversion for inferred result dimensions.
List of operations
This is the full list of element-wise TOSA operations affected by the changes proposed in this document. All of these operations now rely on the proposed version of the Broadcastable op trait for input and result shape verification. The common proposed lowering strategy is applicable to all of these operations for any valid combination of input and result shapes, excluding unranked tensors.
- Unary: tosa.abs, tosa.bitwise_not, tosa.cast, tosa.ceil, tosa.clamp, tosa.clz, tosa.erf, tosa.exp, tosa.floor, tosa.log, tosa.logical_not, tosa.negate, tosa.reciprocal, tosa.rsqrt, tosa.sigmoid, tosa.tanh
- Binary: tosa.add, tosa.arithmetic_right_shift, tosa.bitwise_and, tosa.bitwise_or, tosa.bitwise_xor, tosa.div, tosa.equal, tosa.greater, tosa.greater_equal, tosa.logical_and, tosa.logical_left_shift, tosa.logical_or, tosa.logical_right_shift, tosa.logical_xor, tosa.maximum, tosa.minimum, tosa.mul, tosa.pow, tosa.sub
- Ternary: tosa.select
Lowering strategy
The following examples illustrate the proposed lowering process for different combinations of input and result types, and for operations with a different number of arguments. The lowering strategy aims to support any legal form of an element-wise TOSA op according to the verification criteria enforced by the Broadcastable op trait, with the exception of unranked tensors.
Static dimensions of equal size
The simplest scenario is one in which all dimensions are static and match in size between input operands and result. The following binary TOSA op
func.func @test_add_1d_matching_static(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
converts to
#map = affine_map<(d0) -> (d0)>
func.func @test_add_1d_matching_static(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = tensor.empty() : tensor<3xf32>
%1 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel"]
} ins(%arg0, %arg1 : tensor<3xf32>, tensor<3xf32>) outs(%0 : tensor<3xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<3xf32>
return %1 : tensor<3xf32>
}
The generated code has the following components:
- A set of affine maps defines the iteration space and the mapping between input and output values. The set contains one affine map per input operand plus one additional affine map for the result. Identical affine maps are combined into a single declaration in the code. Here, #map is shared by all input operands and the result.
- A tensor.empty operation creates an output tensor whose shape is obtained by the result inference logic provided by the Broadcastable op trait.
- A linalg.generic op specifies the arithmetic computation yielding individual output values.
Dynamic dimension in unary operation
TOSA code
func.func @test_abs_1d_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
converts to
#map = affine_map<(d0) -> (d0)>
func.func @test_abs_1d_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// Create output tensor
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?xf32>
%0 = tensor.empty(%dim) : tensor<?xf32>
// Perform element-wise computation
%1 = linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel"]
} ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
%2 = math.absf %in : f32
linalg.yield %2 : f32
} -> tensor<?xf32>
return %1 : tensor<?xf32>
}
The inferred result type for this tosa.abs op is tensor<?xf32>. This type is used by the tensor.empty op for the creation of the output tensor. The occurrence of a dynamic dimension in the inferred output type forces us to provide additional information about its runtime size to the tensor.empty op. This work is done by an additional tensor.dim op, which infers the result size from the size of %arg0.
Inferred vs actual result type mismatch
The tensor type used in the tensor.empty and later in the linalg.generic ops is purely inferred from the input operands, and does not necessarily match the actual result type used in the original TOSA op. In this example, TOSA code
func.func @test_abs_1d_cast_result(%arg0: tensor<5xf32>) -> tensor<?xf32> {
%0 = "tosa.abs"(%arg0) : (tensor<5xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
converts to
#map = affine_map<(d0) -> (d0)>
func.func @test_abs_1d_cast_result(%arg0: tensor<5xf32>) -> tensor<?xf32> {
// Perform element-wise computation
%0 = tensor.empty() : tensor<5xf32>
%1 = linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel"]
} ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5xf32>) {
^bb0(%in: f32, %out: f32):
%2 = math.absf %in : f32
linalg.yield %2 : f32
} -> tensor<5xf32>
// Cast result
%cast = tensor.cast %1 : tensor<5xf32> to tensor<?xf32>
return %cast : tensor<?xf32>
}
The result type inference performed by the Broadcastable trait deduces an output tensor of type tensor<5xf32>, while the original op uses the more general type tensor<?xf32> as its result type. The result type compatibility rules specified by the Broadcastable trait tolerate this combination, and the lowering pass resolves the type difference by introducing an additional tensor.cast op.
Fully static broadcasting
The tosa.add op below uses two 1D tensors with static dimensions as input operands. According to the broadcast semantics specified by the Broadcastable trait, the dimension of size 1 in %arg0 must be broadcast to all elements of a tensor matching the size of %arg1 before the element-wise computation can proceed. The following code
func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
converts to
#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = tensor.empty() : tensor<3xf32>
%1 = linalg.generic {
indexing_maps = [#map, #map1, #map1],
iterator_types = ["parallel"]
} ins(%arg0, %arg1 : tensor<1xf32>, tensor<3xf32>) outs(%0 : tensor<3xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<3xf32>
return %1 : tensor<3xf32>
}
Instead of explicitly expanding %arg0 into a tensor of type tensor<3xf32>, the lowering strategy relies on the use of two different affine maps (#map and #map1) in the linalg.generic op. The former indicates that every execution of the linalg.generic body should access element 0 of tensor %arg0, while the latter accesses all 3 elements of %arg1, effectively honoring the desired broadcast behavior.
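In Python terms, the access pattern encoded by the two maps behaves like this sketch (illustrative only; lists stand in for 1D tensors):

```python
def add_broadcast_static(lhs, rhs):
    """Element-wise add where lhs has a single broadcast element."""
    # #map sends every iteration to element 0 of %arg0, while #map1
    # reads element d0 of %arg1 and writes element d0 of the result.
    return [lhs[0] + rhs[d0] for d0 in range(len(rhs))]

assert add_broadcast_static([10.0], [1.0, 2.0, 3.0]) == [11.0, 12.0, 13.0]
```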
Dynamic-to-static broadcasting
The tosa.add op below has two 1D tensors as input operands, where one has a static dimension greater than 1 and the other has a dynamic dimension. The lowering pass must consider the following runtime sizes for the dynamic dimension of %arg1:
- If it is 1, broadcast semantics apply. Tensor %arg1 must be expanded into a 5-element tensor before the element-wise computation can proceed.
- If it is 5, the element-wise computation must be performed directly on the input operands.
- If it has any other size, the behavior is undefined.
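The three runtime cases above translate into the following Python model of the generated control flow (an illustrative sketch; lists stand in for 1D tensors):

```python
def add_dynamic_to_static(static_lhs, dynamic_rhs):
    """Runtime semantics of the lowering for <5xf32> + <?xf32>."""
    size = len(static_lhs)
    if len(dynamic_rhs) == 1:
        # Case 1: broadcast the single element (the scf.if branch).
        dynamic_rhs = dynamic_rhs * size
    elif len(dynamic_rhs) != size:
        # Case 3: incompatible runtime sizes are undefined behavior.
        raise RuntimeError("undefined behavior: incompatible runtime sizes")
    # Case 2 (and post-broadcast case 1): direct element-wise add.
    return [a + b for a, b in zip(static_lhs, dynamic_rhs)]
```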
The following code
func.func @test_add_1d_broadcast_dynamic_to_static(%arg0: tensor<5xf32>, %arg1: tensor<?xf32>) -> tensor<5xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<5xf32>, tensor<?xf32>) -> tensor<5xf32>
return %0 : tensor<5xf32>
}
converts to
#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
func.func @test_add_1d_broadcast_dynamic_to_static(%arg0: tensor<5xf32>, %arg1: tensor<?xf32>) -> tensor<5xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
// Broadcast dimension 0 of %arg1
%dim = tensor.dim %arg1, %c0 : tensor<?xf32>
%0 = arith.cmpi eq, %dim, %c1 : index
%1 = scf.if %0 -> (tensor<?xf32>) {
%4 = tensor.empty() : tensor<5xf32>
%5 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel"]
} ins(%arg1 : tensor<?xf32>) outs(%4 : tensor<5xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<5xf32>
%cast = tensor.cast %5 : tensor<5xf32> to tensor<?xf32>
scf.yield %cast : tensor<?xf32>
} else {
scf.yield %arg1 : tensor<?xf32>
}
// Execute element-wise operation
%2 = tensor.empty() : tensor<5xf32>
%3 = linalg.generic {
indexing_maps = [#map1, #map1, #map1],
iterator_types = ["parallel"]
} ins(%arg0, %1 : tensor<5xf32>, tensor<?xf32>) outs(%2 : tensor<5xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%4 = arith.addf %in, %in_0 : f32
linalg.yield %4 : f32
} -> tensor<5xf32>
return %3 : tensor<5xf32>
}
The lowering pass introduces a new section in the generated code whose purpose is broadcasting dimension 0 of %arg1 if it is determined to have a runtime size of 1. Operations tensor.dim and arith.cmpi evaluate this condition, while scf.if introduces the appropriate control-flow divergence based on the outcome. If broadcasting is confirmed, an additional linalg.generic op replicates the single element in %arg1 5 times. If broadcasting is not applicable, tensor %arg1 is used as is.
Static-to-dynamic broadcasting
The tosa.add op below uses two 1D tensors as inputs, where one has a static dimension of size 1 and the other has a dynamic dimension. One might conceive two distinct execution scenarios: one in which the dynamic dimension has a runtime size of 1, and one in which it is greater than 1. The former involves a direct element-wise computation, while the latter needs to apply a broadcasting action beforehand. However, both scenarios may be expressed as one in which the single element in the static dimension is broadcast to every position of the dynamic dimension, whether that dimension is of size 1 or larger.
The following code
func.func @test_add_1d_broadcast_static_to_dynamic(%arg0: tensor<1xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
converts to
#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
func.func @test_add_1d_broadcast_static_to_dynamic(%arg0: tensor<1xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// Create output tensor
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg1, %c0 : tensor<?xf32>
%0 = tensor.empty(%dim) : tensor<?xf32>
// Perform element-wise computation
%1 = linalg.generic {
indexing_maps = [#map, #map1, #map1],
iterator_types = ["parallel"]
} ins(%arg0, %arg1 : tensor<1xf32>, tensor<?xf32>) outs(%0 : tensor<?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<?xf32>
return %1 : tensor<?xf32>
}
The resulting code does not require the introduction of control flow, as it behaves correctly for any runtime size of the dynamic dimension of %arg1. The generated code is similar to the fully static broadcasting example, except that the size of the output tensor must be computed at runtime with an additional tensor.dim op.
Fully dynamic broadcasting
The tosa.add op below uses two 1D tensors as input operands, both of which have dynamic dimensions. These are the possible execution scenarios to consider:
- In the straightforward case that both operands have the same runtime size, the element-wise computation is carried out with the original tensors.
- If the size of %arg0 is 1 and the size of %arg1 is greater than 1, broadcast semantics apply to %arg0. Its single element must be replicated as many times as there are elements in %arg1 before the element-wise computation can proceed.
- Conversely, if the size of %arg1 is 1 and the size of %arg0 is greater than 1, broadcast semantics apply to %arg1.
- If both operands have a runtime size greater than 1 but these sizes are not equal, the operation causes undefined behavior.
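The generated control flow behaves like this Python sketch, where lists stand in for 1D tensors (an illustrative model, not the pass itself):

```python
def broadcast_if_needed(t, target_size):
    # Mirrors the scf.if guard: replicate only when the runtime size is 1.
    return t * target_size if len(t) == 1 else t

def add_all_dynamic(lhs, rhs):
    """Runtime semantics of the lowering for <?xf32> + <?xf32>."""
    # Maximum runtime size of dimension 0 (the arith.maxui op).
    target = max(len(lhs), len(rhs))
    lhs = broadcast_if_needed(lhs, target)
    rhs = broadcast_if_needed(rhs, target)
    # After both operands are visited, their sizes match in all
    # defined cases and the element-wise computation may proceed.
    return [a + b for a, b in zip(lhs, rhs)]
```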
The following code
func.func @test_add_1d_all_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
converts to
#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
func.func @test_add_1d_all_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// Calculate maximum dimension size
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?xf32>
%dim_0 = tensor.dim %arg1, %c0 : tensor<?xf32>
%0 = arith.maxui %dim, %dim_0 : index
// Broadcast dimension 0 of %arg0
%c1 = arith.constant 1 : index
%dim_1 = tensor.dim %arg0, %c0 : tensor<?xf32>
%1 = arith.cmpi eq, %dim_1, %c1 : index
%2 = scf.if %1 -> (tensor<?xf32>) {
%7 = tensor.empty(%0) : tensor<?xf32>
%8 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel"]
} ins(%arg0 : tensor<?xf32>) outs(%7 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?xf32>
scf.yield %8 : tensor<?xf32>
} else {
scf.yield %arg0 : tensor<?xf32>
}
// Broadcast dimension 0 of %arg1
%dim_2 = tensor.dim %arg1, %c0 : tensor<?xf32>
%3 = arith.cmpi eq, %dim_2, %c1 : index
%4 = scf.if %3 -> (tensor<?xf32>) {
%7 = tensor.empty(%0) : tensor<?xf32>
%8 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel"]
} ins(%arg1 : tensor<?xf32>) outs(%7 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?xf32>
scf.yield %8 : tensor<?xf32>
} else {
scf.yield %arg1 : tensor<?xf32>
}
// Perform element-wise computation
%5 = tensor.empty(%0) : tensor<?xf32>
%6 = linalg.generic {
indexing_maps = [#map1, #map1, #map1],
iterator_types = ["parallel"]
} ins(%2, %4 : tensor<?xf32>, tensor<?xf32>) outs(%5 : tensor<?xf32>) {
^bb0(%in: f32, %in_3: f32, %out: f32):
%7 = arith.addf %in, %in_3 : f32
linalg.yield %7 : f32
} -> tensor<?xf32>
return %6 : tensor<?xf32>
}
The strategy here consists of first calculating the maximum runtime size of dimension 0 across both input operands. Each operand is then broadcast to that maximum size if necessary. After both operands are processed, they are guaranteed to have matching runtime dimensions, and the element-wise computation may proceed.
Multi-dimensional fully dynamic broadcasting
The previous example is extended here by using 2D tensors as input operands with all dynamic dimensions. The following code
func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
converts to
#map = affine_map<(d0, d1) -> (0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0, 0)>
func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// Calculate maximum dimension 0
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%0 = arith.maxui %dim, %dim_0 : index
// Calculate maximum dimension 1
%c1 = arith.constant 1 : index
%dim_1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%dim_2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%1 = arith.maxui %dim_1, %dim_2 : index
// Broadcast dimension 0 of %arg0
%dim_3 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = arith.cmpi eq, %dim_3, %c1 : index
%3 = scf.if %2 -> (tensor<?x?xf32>) {
%dim_7 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32>
%13 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%arg0 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?xf32>
scf.yield %13 : tensor<?x?xf32>
} else {
scf.yield %arg0 : tensor<?x?xf32>
}
// Broadcast dimension 1 of %arg0
%dim_4 = tensor.dim %3, %c1 : tensor<?x?xf32>
%4 = arith.cmpi eq, %dim_4, %c1 : index
%5 = scf.if %4 -> (tensor<?x?xf32>) {
%dim_7 = tensor.dim %3, %c0 : tensor<?x?xf32>
%12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32>
%13 = linalg.generic {
indexing_maps = [#map2, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%3 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?xf32>
scf.yield %13 : tensor<?x?xf32>
} else {
scf.yield %3 : tensor<?x?xf32>
}
// Broadcast dimension 0 of %arg1
%dim_5 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%6 = arith.cmpi eq, %dim_5, %c1 : index
%7 = scf.if %6 -> (tensor<?x?xf32>) {
%dim_7 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32>
%13 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%arg1 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?xf32>
scf.yield %13 : tensor<?x?xf32>
} else {
scf.yield %arg1 : tensor<?x?xf32>
}
// Broadcast dimension 1 of %arg1
%dim_6 = tensor.dim %7, %c1 : tensor<?x?xf32>
%8 = arith.cmpi eq, %dim_6, %c1 : index
%9 = scf.if %8 -> (tensor<?x?xf32>) {
%dim_7 = tensor.dim %7, %c0 : tensor<?x?xf32>
%12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32>
%13 = linalg.generic {
indexing_maps = [#map2, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%7 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?xf32>
scf.yield %13 : tensor<?x?xf32>
} else {
scf.yield %7 : tensor<?x?xf32>
}
// Perform element-wise computation
%10 = tensor.empty(%0, %1) : tensor<?x?xf32>
%11 = linalg.generic {
indexing_maps = [#map1, #map1, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%5, %9 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%10 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_7: f32, %out: f32):
%12 = arith.addf %in, %in_7 : f32
linalg.yield %12 : f32
} -> tensor<?x?xf32>
return %11 : tensor<?x?xf32>
}
First, the maximum dimension size is computed separately for dimensions 0 and 1. Then, conditional broadcasting actions are carried out for every input operand and dimension: dimension 0 of %arg0, dimension 1 of %arg0, dimension 0 of %arg1, and dimension 1 of %arg1. Once all runtime sizes are guaranteed to be equal, the element-wise computation proceeds.
A subtlety of fully dynamic multi-dimensional broadcasts is that every broadcast operation must affect exactly one dimension of one operand, while preserving the current intermediate sizes of all other dynamic dimensions of that operand. This is why every tensor.empty op preceding a broadcasting linalg.generic op enforces the new maximum size only for the broadcast dimension, and preserves the size of the neighboring dimension by first querying it with a tensor.dim op.
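The per-dimension broadcast cascade applied to each operand can be summarized with a small Python model. This is a sketch of the shape logic only, with hypothetical function names; in the generated code each loop iteration corresponds to one scf.if-guarded linalg.generic copy.

```python
def broadcast_operand(shape, target_sizes):
    # Model of the broadcast cascade emitted for one operand: each
    # dimension is handled by its own conditional (an scf.if in the
    # lowering), and the sizes of all other dimensions are preserved
    # in their current intermediate state, mirroring the tensor.dim
    # queries that feed each tensor.empty op.
    current = list(shape)
    for d, target in enumerate(target_sizes):
        if current[d] == 1:  # the scf.if condition on the runtime size
            current[d] = target  # one broadcasting linalg.generic
    return current

# Per-dimension targets, mirroring the arith.maxui computations.
targets = [max(a, b) for a, b in zip([1, 4], [3, 1])]
print(broadcast_operand([1, 4], targets), broadcast_operand([3, 1], targets))
```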
Zero-dimensional tensors
The tosa.add op below uses zero-dimensional tensors as its input operands and result; the generated code behaves equivalently to a scalar binary operation. The following code
func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
converts to
#map = affine_map<() -> ()>
func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = tensor.empty() : tensor<f32>
%1 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = []
} ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<f32>
return %1 : tensor<f32>
}
Rank expansion
The tosa.add op below has input tensors of different ranks. A rank equalization stage must first reshape %arg0 of type tensor<3x4xf32> into a 3D tensor of type tensor<1x3x4xf32>. The new dimension 0 of size 1 must then be broadcast. The expanded, broadcast version of %arg0 is finally of type tensor<2x3x4xf32>, which is ready for the element-wise computation with %arg1.
The following code
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
converts to
#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// Expand shape of %arg0 to match %arg1
%expanded = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32>
// Perform element-wise operation
%0 = tensor.empty() : tensor<2x3x4xf32>
%1 = linalg.generic {
indexing_maps = [#map, #map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]
} ins(%expanded, %arg1 : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%0 : tensor<2x3x4xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<2x3x4xf32>
return %1 : tensor<2x3x4xf32>
}
In the generated code, rank equalization is carried out by an additional tensor.expand_shape op. The broadcast of the new size-1 dimension of %arg0 is accomplished through the modified affine map #map, whose left-most affine expression is set to the constant 0.
Ternary operation
This example illustrates the lowering of the tosa.select op, which is currently the sole ternary element-wise op in the TOSA dialect. The following code
func.func @test_select_2d_one_dynamic(%arg0: tensor<2x?xi1>, %arg1: tensor<2x?xf32>, %arg2: tensor<2x?xf32>) -> tensor<2x?xf32> {
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
converts to
#map = affine_map<(d0, d1) -> (d0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func.func @test_select_2d_one_dynamic(%arg0: tensor<2x?xi1>, %arg1: tensor<2x?xf32>, %arg2: tensor<2x?xf32>) -> tensor<2x?xf32> {
// Compute maximum dimension 1 among operands
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c1 : tensor<2x?xi1>
%dim_0 = tensor.dim %arg1, %c1 : tensor<2x?xf32>
%0 = arith.maxui %dim, %dim_0 : index
%dim_1 = tensor.dim %arg2, %c1 : tensor<2x?xf32>
%1 = arith.maxui %0, %dim_1 : index
// Broadcast dimension 1 of %arg0
%dim_2 = tensor.dim %arg0, %c1 : tensor<2x?xi1>
%2 = arith.cmpi eq, %dim_2, %c1 : index
%3 = scf.if %2 -> (tensor<2x?xi1>) {
%10 = tensor.empty(%1) : tensor<2x?xi1>
%11 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%arg0 : tensor<2x?xi1>) outs(%10 : tensor<2x?xi1>) {
^bb0(%in: i1, %out: i1):
linalg.yield %in : i1
} -> tensor<2x?xi1>
scf.yield %11 : tensor<2x?xi1>
} else {
scf.yield %arg0 : tensor<2x?xi1>
}
// Broadcast dimension 1 of %arg1
%dim_3 = tensor.dim %arg1, %c1 : tensor<2x?xf32>
%4 = arith.cmpi eq, %dim_3, %c1 : index
%5 = scf.if %4 -> (tensor<2x?xf32>) {
%10 = tensor.empty(%1) : tensor<2x?xf32>
%11 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%arg1 : tensor<2x?xf32>) outs(%10 : tensor<2x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x?xf32>
scf.yield %11 : tensor<2x?xf32>
} else {
scf.yield %arg1 : tensor<2x?xf32>
}
// Broadcast dimension 1 of %arg2
%dim_4 = tensor.dim %arg2, %c1 : tensor<2x?xf32>
%6 = arith.cmpi eq, %dim_4, %c1 : index
%7 = scf.if %6 -> (tensor<2x?xf32>) {
%10 = tensor.empty(%1) : tensor<2x?xf32>
%11 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%arg2 : tensor<2x?xf32>) outs(%10 : tensor<2x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x?xf32>
scf.yield %11 : tensor<2x?xf32>
} else {
scf.yield %arg2 : tensor<2x?xf32>
}
// Perform element-wise computation
%8 = tensor.empty(%1) : tensor<2x?xf32>
%9 = linalg.generic {
indexing_maps = [#map1, #map1, #map1, #map1],
iterator_types = ["parallel", "parallel"]
} ins(%3, %5, %7 : tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) outs(%8 : tensor<2x?xf32>) {
^bb0(%in: i1, %in_5: f32, %in_6: f32, %out: f32):
%10 = arith.select %in, %in_5, %in_6 : f32
linalg.yield %10 : f32
} -> tensor<2x?xf32>
return %9 : tensor<2x?xf32>
}
The generated code first obtains the maximum runtime size of the dynamic dimension (dimension 1) among the three input operands. Each input is then individually broadcast to that maximum size if its dynamic dimension has a runtime size of 1. Once all tensors are guaranteed to have equal sizes, the element-wise computation is performed.
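The ternary case generalizes the binary one only in how the target size is computed: the maximum is folded pairwise over all three operands. A hedged Python sketch of that shape logic (the function name is hypothetical):

```python
def select_result_dim(cond_d1, lhs_d1, rhs_d1):
    # Model of the chained arith.maxui ops emitted for dimension 1:
    # the maximum runtime size is folded pairwise across the three
    # operands, mirroring %0 = maxui(dim0, dim1); %1 = maxui(%0, dim2).
    target = max(max(cond_d1, lhs_d1), rhs_d1)
    # Each operand is broadcast (by its own scf.if-guarded copy) only
    # when its runtime size in this dimension is 1.
    return [target if d == 1 else d for d in (cond_d1, lhs_d1, rhs_d1)]

print(select_result_dim(1, 7, 1))  # [7, 7, 7]
```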
Additional notes
The following issues have been addressed in the current version of this RFC since it was originally posted:
- The RFC originally suggested removing broadcast semantics from input dynamic dimensions for simplicity. Previous feedback indicated that this feature is important to support, which is now reflected in the current proposal.
The following issues remain open for discussion and may involve further modifications in this RFC:
- The proposed specification for the Broadcastable op trait removes broadcast semantics for the inferred result for simplicity (labeled Modification 1 in the RFC). Is this change acceptable?
- The proposed specification for the Broadcastable op trait forbids implicit dynamic-to-static casts in result dimensions (labeled Modification 2 in the RFC). Is this change acceptable?
- The existing implementation includes two locations where result dimensions are inferred for element-wise ops. One is the function OpTrait::util::getBroadcastedShape(), invoked during op verification for ops with the Broadcastable trait. The other is OP::inferReturnTypeComponents(). There seem to be opportunities to factor out common code here.