[RFC] TOSA-to-Linalg lowering of element-wise ops

Summary

The current implementation of the TOSA-to-Linalg pass lowers element-wise arithmetic operations incorrectly for certain combinations of input tensor shapes. While some issues are simply isolated bugs, others stem from the lack of a well-defined specification for broadcast semantics, especially when dynamic dimensions are involved.

In this document, we identify scenarios for which the current implementation breaks. We then propose an interpretation of the TOSA standard for broadcast semantics when applied to MLIR tensors with dynamic dimensions. Finally, we propose a general TOSA-to-Linalg lowering strategy for unary, binary, and ternary element-wise arithmetic ops.

An implementation for the modifications presented in this document is available for review at https://reviews.llvm.org/D153291.

Current lowering strategy

The existing TOSA-to-Linalg lowering for element-wise arithmetic ops mishandles broadcast semantics and dynamic dimensions in several cases. The following examples show correct and incorrect conversions of the tosa.add op for different combinations of 2D tensor shapes.

Combination <?x?>, <?x?> (correct)

TOSA code

func.func @main(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

converts to

#map = affine_map<(d0, d1) -> (d0, d1)>

module {
  func.func @main(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %2 = arith.addf %in, %in_1 : f32
      linalg.yield %2 : f32
    } -> tensor<?x?xf32>
    return %1 : tensor<?x?xf32>
  }
}

This conversion is correct. It assumes that tensors %arg0 and %arg1 have the same shape, and produces undefined behavior otherwise. The conversion does not account for broadcast semantics, i.e., the possibility that a given dimension is 1 in one tensor and larger than 1 in the other. We may consider this acceptable behavior; see the section "Broadcast semantics for dynamic dimensions" below.

Combination <1x?>, <?x?> (correct)

TOSA code

func.func @main(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

converts to

#map = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>

module {
  func.func @main(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
    %dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
    %0 = tensor.empty(%dim_0, %dim) : tensor<?x?xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %arg1 : tensor<?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %2 = arith.addf %in, %in_1 : f32
      linalg.yield %2 : f32
    } -> tensor<?x?xf32>
    return %1 : tensor<?x?xf32>
  }
}

This lowering honors the broadcast semantics of %arg0.dims[0]. The broadcast dimension is eliminated with a tensor.collapse_shape op, and the affine map used to access %arg0 then omits d0 on its right-hand side.

Combination <2x3>, <4x3> (correct)

TOSA code

func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

fails verification with the following message:

error: 'tosa.add' op operands don't have broadcast-compatible shapes
%0 = "tosa.add"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<?x?xf32>
^
add.tosa.mlir:3:10: note: see current operation: %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<?x?xf32>

This is the expected behavior.

Combination <1x5>, <3x5> (correct)

TOSA code

func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x5xf32>, tensor<3x5xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

converts to

#map = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<?x?xf32> {
    %0 = tensor.empty() : tensor<3x5xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x5xf32> into tensor<5xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %arg1 : tensor<5xf32>, tensor<3x5xf32>) outs(%0 : tensor<3x5xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %2 = arith.addf %in, %in_0 : f32
      linalg.yield %2 : f32
    } -> tensor<3x5xf32>
    %cast = tensor.cast %1 : tensor<3x5xf32> to tensor<?x?xf32>
    return %cast : tensor<?x?xf32>
  }
}

The broadcast semantics of the static dimension %arg0.dims[0] are honored and lowered correctly.

Combination <3x5>, <3x5> (correct)

TOSA code

func.func @main(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

converts to

#map = affine_map<(d0, d1) -> (d0, d1)>

module {
  func.func @main(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<?x?xf32> {
    %0 = tensor.empty() : tensor<3x5xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<3x5xf32>, tensor<3x5xf32>) outs(%0 : tensor<3x5xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %2 = arith.addf %in, %in_0 : f32
      linalg.yield %2 : f32
    } -> tensor<3x5xf32>
    %cast = tensor.cast %1 : tensor<3x5xf32> to tensor<?x?xf32>
    return %cast : tensor<?x?xf32>
  }
}

The lowering is correct when all input dimensions are static and both input tensors have matching shapes.

Combination <2x?>, <?x?> (incorrect)

TOSA code

func.func @main(%arg0: tensor<2x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

converts to

#map = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>

module {
  func.func @main(%arg0: tensor<2x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c1 : tensor<2x?xf32>
    %dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
    %0 = tensor.empty(%dim_0, %dim) : tensor<?x?xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %arg1 : tensor<?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %2 = arith.addf %in, %in_1 : f32
      linalg.yield %2 : f32
    } -> tensor<?x?xf32>
    return %1 : tensor<?x?xf32>
  }
}

This lowering is wrong. Elements %arg0[0, 0], %arg0[0, 1], … %arg0[0, n - 1], %arg0[1, 0], %arg0[1, 1], …, %arg0[1, n - 1] are collapsed into %collapsed[0], %collapsed[1], … %collapsed[2n - 1]. The iteration space in linalg.generic is then ambiguous. According to #map, d1 = [0:2n]; according to #map1, d1 = [0:n]. If #map is honored, %arg1 is accessed out of bounds; if #map1 is honored, half of the elements of %collapsed are never accessed.

Combination <2x2>, <?x?> (incorrect)

Attempting to convert this TOSA code

func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

produces the following compilation error:

add.tosa.mlir:3:10: error: 'tosa.reshape' op Cannot reshape 4 elements into 1
%0 = "tosa.add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
^
add.tosa.mlir:3:10: note: see current operation: %5 = "tosa.reshape"(%arg0) <{new_shape = array<i64>}> : (tensor<2x2xf32>) -> tensor<f32>

// -----// IR Dump After TosaToLinalg Failed (tosa-to-linalg) //----- //

"func.func"() <{function_type = (tensor<2x2xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "main"}> ({
^bb0(%arg0: tensor<2x2xf32>, %arg1: tensor<?x?xf32>):
  %0 = "arith.constant"() <{value = 0 : index}> : () -> index
  %1 = "tensor.dim"(%arg1, %0) : (tensor<?x?xf32>, index) -> index
  %2 = "arith.constant"() <{value = 1 : index}> : () -> index
  %3 = "tensor.dim"(%arg1, %2) : (tensor<?x?xf32>, index) -> index
  %4 = "tensor.empty"(%1, %3) : (index, index) -> tensor<?x?xf32>
  %5 = "tosa.reshape"(%arg0) <{new_shape = array<i64>}> : (tensor<2x2xf32>) -> tensor<f32>
  ...

The lowering pass attempts to produce a tosa.reshape op with an invalid target shape.

Combination <?x2>, <2x?> (incorrect)

TOSA code

func.func @main(%arg0: tensor<?x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x2xf32>, tensor<2x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

converts to

#map = affine_map<(d0, d1) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d1)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>

module {
  func.func @main(%arg0: tensor<?x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<?x?xf32> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x2xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<2x?xf32>
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x2xf32> into tensor<?xf32>
    %collapsed_1 = tensor.collapse_shape %arg1 [[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %collapsed_1 : tensor<?xf32>, tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %2 = arith.addf %in, %in_2 : f32
      linalg.yield %2 : f32
    } -> tensor<?x?xf32>
    return %1 : tensor<?x?xf32>
  }
}

If the use of tosa.add is valid, it can be deduced that both %arg0 and %arg1 have dimensions 2x2. This is the shape of the output tensor %0 returned by tensor.empty. According to #map2, the iteration space for linalg.generic is then d0 = [0:2] and d1 = [0:2]. However, according to #map, d0 = [0:collapsed.size] = [0:4]. According to #map1, d1 = [0:collapsed_1.size] = [0:4]. The iteration space for linalg.generic is therefore ambiguous and its use leads to undefined behavior.

Enforcement of rank equality

The TOSA spec requires input tensors in element-wise operations to have the same rank. Op verifiers currently do not enforce this requirement, which allows the following tosa.add op to be parsed and verified successfully.

%result = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>

This lenience has enabled a current bug in the Tensorflow project, where similar element-wise operations in the tf (Tensorflow) and tfl (Tensorflow Lite) dialects support implicit rank expansion. The tosa.add op above could then be the result of applying a TFL-to-TOSA conversion pass to the following valid tfl.add op:

%result = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} :
    (tensor<3x4xf32>, tensor<2x3x4xf32>) ->
    tensor<2x3x4xf32>

This RFC proposes adding the trait AllRanksMatchIfKnown to element-wise TOSA operations in order to enforce equal ranks for all of an operation's operands and its result, when ranked. The RFC also anticipates simultaneous fixes in the TF/TFL lowering passes in the Tensorflow repository to adjust to this requirement. Once this implementation lands in LLVM, lowering passes will no longer need to account for incompatible ranks when lowering an element-wise TOSA op.

For now, the TOSA-to-Linalg pass aligns with the existing verification logic and supports inputs of different ranks. It relies on an initial rank-equalization step that prepends dimensions of size 1 to the input tensor shapes until they all reach a common rank, using the tensor.expand_shape operation. In the example above, the 2D tensor %arg0 of type tensor<3x4xf32> is first expanded into a 3D tensor as follows:

%reshaped_arg0 = tensor.expand_shape %arg0 [[0, 1], [2]] :
    tensor<3x4xf32> into tensor<1x3x4xf32>

The Broadcastable trait (a.k.a. ResultsBroadcastableShape)

Description

The Broadcastable op trait is currently used only by TOSA element-wise ops in the LLVM repository. It is also used in the Tensorflow repository for the definition of element-wise ops with broadcasting semantics in the tf and tfl dialects.

The trait serves three main purposes: input shape verification, result shape inference, and result shape verification. The properties and functionality provided by this trait are described in detail in a new documentation file at mlir/docs/Traits/Broadcastable.md in the proposed patch (see Phabricator link above). While most of the file content is a formalization of previously existing functionality provided by the Broadcastable op trait, some of its features have been modified, as described in the following sections.
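To make the per-dimension rules concrete, the input shape verification and result shape inference performed by the trait can be sketched in Python. This is a hypothetical illustration of the rules described above, not the actual C++ implementation; DYNAMIC and all function names are illustrative, and equal ranks are assumed (see "Enforcement of rank equality").

```python
# Hypothetical sketch of the per-dimension broadcast rules provided by the
# Broadcastable op trait. DYNAMIC stands for a '?' (dynamic) dimension.
DYNAMIC = -1

def infer_dim(a: int, b: int) -> int:
    """Infer one result dimension from two broadcast-compatible input dims."""
    if a == 1:
        return b          # 'a' broadcasts into 'b' (static or dynamic)
    if b == 1:
        return a
    if a == DYNAMIC:
        return b          # runtime size must be compatible with 'b'
    if b == DYNAMIC:
        return a
    if a == b:
        return a
    raise ValueError(f"dimensions {a} and {b} are not broadcast-compatible")

def infer_result_shape(lhs: list[int], rhs: list[int]) -> list[int]:
    """Infer the result shape dimension by dimension; ranks must match."""
    if len(lhs) != len(rhs):
        raise ValueError("operand ranks must match")
    return [infer_dim(a, b) for a, b in zip(lhs, rhs)]
```

For example, infer_result_shape([1, -1], [3, 5]) yields [3, 5], mirroring how an inferred type such as tensor<3x5xf32> arises in the examples below.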

Modification 1: Removing broadcast semantics for result dimensions

The current implementation of the Broadcastable trait assigns broadcast semantics to the inferred result type. For example, operation

%0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<5xf32>

is considered legal. The inferred result type is tensor<1xf32>, while the actual return type is tensor<5xf32>. The expected behavior is that the single result element is broadcast to the 5 positions of the actual result tensor. This implicit behavior adds a level of obscurity to the operation semantics, and is therefore disabled in the proposed implementation.

A programmer may attain a similar outcome by leveraging the broadcast semantics of input dimensions. If the desired behavior is to use the scalar result of the tosa.add operation above in a subsequent 5-element tensor operation, the consumer may look like this:

%1 = "tosa.sub"(%arg2, %0) : (tensor<5xf32>, tensor<1xf32>) -> tensor<5xf32>

Instead of having the original tosa.add producer op materialize a 5-element tensor, the tosa.sub consumer op is implicitly broadcasting the scalar value for every subtraction.

If broadcasting a scalar value into a tensor is deemed a desirable stand-alone feature by the TOSA community, it might be worth considering the introduction of a dedicated tosa.broadcast operation for this purpose.

Modification 2: Forbidding implicit dynamic-to-static dimension cast in result dimensions

The current implementation of the Broadcastable trait allows for an inferred result dimension to be dynamic while the actual result dimension is given a static size. The following op would then be considered legal:

%0 = "tosa.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<5xf32>

The inferred result type for this op is tensor<?xf32>, as a consequence of all input dimensions being dynamic. If the result is known at compile time to have 5 elements, the programmer must necessarily know that at least one input dimension has size 5, too. It is therefore counterintuitive to allow all input dimensions to be dynamic in this case. The proposed implementation disallows an implicit dynamic-to-static conversion for inferred results.
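Taken together, Modifications 1 and 2 reduce the compatibility check between an inferred result dimension and the actual result dimension to a simple rule, sketched here in Python (a hypothetical illustration; DYNAMIC and the function name are not part of the actual implementation):

```python
# Hypothetical sketch of result-dimension verification after Modifications 1
# and 2. DYNAMIC denotes a '?' (dynamic) dimension.
DYNAMIC = -1

def result_dim_ok(inferred: int, actual: int) -> bool:
    # The actual dimension must equal the inferred one, or relax a static
    # inferred size to dynamic (later resolved with a tensor.cast).
    return actual == inferred or actual == DYNAMIC
```

Under this rule, inferred tensor<5xf32> vs. actual tensor<?xf32> remains legal, while inferred tensor<1xf32> vs. actual tensor<5xf32> (Modification 1) and inferred tensor<?xf32> vs. actual tensor<5xf32> (Modification 2) are both rejected.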

List of operations

This is the full list of element-wise TOSA operations affected by the changes proposed in this document. All of these operations now rely on the proposed version of the Broadcastable op trait for input and result shape verification. The common lowering strategy proposed below applies to all of them for any valid combination of input and result shapes, excluding unranked tensors.

  • Unary: tosa.abs, tosa.bitwise_not, tosa.cast, tosa.ceil, tosa.clamp, tosa.clz, tosa.erf, tosa.exp, tosa.floor, tosa.log, tosa.logical_not, tosa.negate, tosa.reciprocal, tosa.rsqrt, tosa.sigmoid, tosa.tanh

  • Binary: tosa.add, tosa.arithmetic_right_shift, tosa.bitwise_and, tosa.bitwise_or, tosa.bitwise_xor, tosa.div, tosa.equal, tosa.greater, tosa.greater_equal, tosa.logical_and, tosa.logical_left_shift, tosa.logical_or, tosa.logical_right_shift, tosa.logical_xor, tosa.maximum, tosa.minimum, tosa.mul, tosa.pow, tosa.sub

  • Ternary: tosa.select

Lowering strategy

The following examples illustrate the proposed lowering process for different combinations of input and result types, and for operations with different numbers of arguments. The lowering strategy aims to support any legal form of an element-wise TOSA op according to the verification criteria enforced by the Broadcastable op trait, with the exception of unranked tensors.

Static dimensions of equal size

The simplest scenario is one in which all dimensions are static and match in size between input operands and result. The following binary TOSA op

func.func @test_add_1d_matching_static(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

converts to

#map = affine_map<(d0) -> (d0)>
func.func @test_add_1d_matching_static(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
  %0 = tensor.empty() : tensor<3xf32>
  %1 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel"]
  } ins(%arg0, %arg1 : tensor<3xf32>, tensor<3xf32>) outs(%0 : tensor<3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<3xf32>
  return %1 : tensor<3xf32>
}

The generated code has the following components:

  • A set of affine maps defines the iteration space and the mapping between input and output values. The set contains one affine map per input operand plus one additional affine map for the result. Identical affine maps are combined into a single declaration in the code. Here, #map is shared by all input operands and the result.
  • A tensor.empty operation creates an output tensor whose shape is obtained by the result inference logic provided by the Broadcastable op trait.
  • A linalg.generic op specifies the arithmetic computation yielding individual output values.

Dynamic dimension in unary operation

TOSA code

func.func @test_abs_1d_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

converts to

#map = affine_map<(d0) -> (d0)>
func.func @test_abs_1d_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {

  // Create output tensor
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
  %0 = tensor.empty(%dim) : tensor<?xf32>

  // Perform element-wise computation
  %1 = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel"]
  } ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %2 = math.absf %in : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  return %1 : tensor<?xf32>
}

The inferred result type for this tosa.abs op is tensor<?xf32>. This type is used by the tensor.empty op to create the output tensor. The occurrence of a dynamic dimension in the inferred output type forces us to provide its runtime size to the tensor.empty op. This is done by an additional tensor.dim op, which derives the result size from the size of %arg0.

Inferred vs actual result type mismatch

The tensor type used in the tensor.empty op, and later in the linalg.generic op, is inferred purely from the input operands, and does not necessarily match the actual result type used in the original TOSA op. In this example, TOSA code

func.func @test_abs_1d_cast_result(%arg0: tensor<5xf32>) -> tensor<?xf32> {
  %0 = "tosa.abs"(%arg0) : (tensor<5xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

converts to

#map = affine_map<(d0) -> (d0)>
func.func @test_abs_1d_cast_result(%arg0: tensor<5xf32>) -> tensor<?xf32> {

  // Perform element-wise computation
  %0 = tensor.empty() : tensor<5xf32>
  %1 = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel"]
  } ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %2 = math.absf %in : f32
    linalg.yield %2 : f32
  } -> tensor<5xf32>

  // Cast result
  %cast = tensor.cast %1 : tensor<5xf32> to tensor<?xf32>
  return %cast : tensor<?xf32>
}

The result type inference performed by the Broadcastable trait deduces an output tensor of type tensor<5xf32>, while the original op uses the more general type tensor<?xf32> as its result type. The result type compatibility rules specified by the Broadcastable trait tolerate this combination, and the lowering pass resolves this type difference by introducing an additional tensor.cast op.

Fully static broadcasting

The tosa.add op below uses two 1D tensors with static dimensions as input operands. According to the broadcast semantics specified by the Broadcastable trait, the dimension of size 1 in %arg0 must be broadcast to all elements of a tensor matching the size of %arg1 before the element-wise computation can proceed. The following code

func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<3xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

converts to

#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
  %0 = tensor.empty() : tensor<3xf32>
  %1 = linalg.generic {
    indexing_maps = [#map, #map1, #map1],
    iterator_types = ["parallel"]
  } ins(%arg0, %arg1 : tensor<1xf32>, tensor<3xf32>) outs(%0 : tensor<3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<3xf32>
  return %1 : tensor<3xf32>
}

Instead of explicitly expanding %arg0 into a tensor of type tensor<3xf32>, the lowering strategy relies on the use of two different affine maps (#map and #map1) in the linalg.generic op. The former indicates that every execution of the linalg.generic body should access element 0 of tensor %arg0, while the latter accesses all 3 elements of %arg1, effectively honoring the desired broadcast behavior.

Dynamic-to-static broadcasting

The tosa.add op below has two 1D tensors as input operands, where one has a static dimension greater than 1 and the other has a dynamic dimension. The lowering pass must consider the following runtime sizes for the dynamic dimension of %arg1:

  • If it is 1, broadcast semantics apply. Tensor %arg1 must be expanded into a 5-element tensor before the element-wise computation can proceed.
  • If it is 5, the element-wise computation must be performed directly on the input operands.
  • If it has any other size, the behavior is undefined.

The following code

func.func @test_add_1d_broadcast_dynamic_to_static(%arg0: tensor<5xf32>, %arg1: tensor<?xf32>) -> tensor<5xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<5xf32>, tensor<?xf32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}

converts to

#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
func.func @test_add_1d_broadcast_dynamic_to_static(%arg0: tensor<5xf32>, %arg1: tensor<?xf32>) -> tensor<5xf32> {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index

  // Broadcast dimension 0 of %arg1
  %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
  %0 = arith.cmpi eq, %dim, %c1 : index
  %1 = scf.if %0 -> (tensor<?xf32>) {
    %4 = tensor.empty() : tensor<5xf32>
    %5 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel"]
    } ins(%arg1 : tensor<?xf32>) outs(%4 : tensor<5xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<5xf32>
    %cast = tensor.cast %5 : tensor<5xf32> to tensor<?xf32>
    scf.yield %cast : tensor<?xf32>
  } else {
    scf.yield %arg1 : tensor<?xf32>
  }

  // Execute element-wise operation
  %2 = tensor.empty() : tensor<5xf32>
  %3 = linalg.generic {
    indexing_maps = [#map1, #map1, #map1],
    iterator_types = ["parallel"]
  } ins(%arg0, %1 : tensor<5xf32>, tensor<?xf32>) outs(%2 : tensor<5xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %4 = arith.addf %in, %in_0 : f32
    linalg.yield %4 : f32
  } -> tensor<5xf32>
  return %3 : tensor<5xf32>
}

The lowering pass introduces a new section in the generated code whose purpose is broadcasting dimension 0 of %arg1 if it is determined to have a runtime size of 1. Operations tensor.dim and arith.cmpi evaluate this condition, while scf.if introduces the appropriate control flow divergence based on the outcome. If broadcasting is confirmed, an additional linalg.generic op replicates the single element in %arg1 5 times. If broadcasting is not applicable, tensor %arg1 is used as is.

Static-to-dynamic broadcasting

The tosa.add op below uses two 1D tensors as inputs, where one has a static dimension of size 1 and the other has a dynamic dimension. One might conceive of two distinct execution scenarios: one in which the dynamic dimension has a runtime size of 1, and one in which it is greater than 1. The former involves a direct element-wise computation, while the latter requires a broadcasting step beforehand. However, both scenarios can be expressed as one in which the single element in the static dimension is broadcast to every position of the dynamic dimension, whether that dimension has size 1 or more.

The following code

func.func @test_add_1d_broadcast_static_to_dynamic(%arg0: tensor<1xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

converts to

#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
func.func @test_add_1d_broadcast_static_to_dynamic(%arg0: tensor<1xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {

  // Create output tensor
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
  %0 = tensor.empty(%dim) : tensor<?xf32>

  // Perform element-wise computation
  %1 = linalg.generic {
    indexing_maps = [#map, #map1, #map1],
    iterator_types = ["parallel"]
  } ins(%arg0, %arg1 : tensor<1xf32>, tensor<?xf32>) outs(%0 : tensor<?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  return %1 : tensor<?xf32>
}

The resulting code does not require the introduction of control flow, as it behaves correctly for any runtime size of %arg1. The generated code is similar to the fully static broadcasting example, except that the size of the output tensor must be computed at runtime with an additional tensor.dim op.

Fully dynamic broadcasting

The tosa.add op below uses two 1D tensors as input operands, both of which have dynamic dimensions. These are the possible execution scenarios to consider:

  • In the straightforward case that both operands have the same runtime size, the element-wise computation is carried out with the original tensors.
  • If the size of %arg0 is 1 and the size of %arg1 is greater than 1, broadcast semantics apply to %arg0. Its single element must be replicated as many times as elements are present in %arg1 before the element-wise computation can proceed.
  • Conversely, if the size of %arg1 is 1 and the size of %arg0 is greater than 1, broadcast semantics apply to %arg1.
  • If both operands have a runtime size greater than 1 but these sizes are not equal, the operation causes undefined behavior.
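The runtime semantics enumerated above can be modeled in Python, with 1D tensors represented as plain lists. This is a hypothetical sketch of the intended behavior, not of the generated IR; broadcast_to and add_1d are illustrative names.

```python
# Hypothetical Python model of fully dynamic 1D broadcast semantics.
# Tensors are modeled as plain lists.
def broadcast_to(operand: list, size: int) -> list:
    # Mirrors the scf.if guard in the lowering: expand only when the
    # runtime size is 1; otherwise use the operand as is.
    return operand * size if len(operand) == 1 else operand

def add_1d(lhs: list, rhs: list) -> list:
    size = max(len(lhs), len(rhs))   # analogue of arith.maxui on the dims
    lhs = broadcast_to(lhs, size)
    rhs = broadcast_to(rhs, size)
    if len(lhs) != len(rhs):
        # Mismatched sizes greater than 1: undefined behavior in TOSA.
        raise ValueError("incompatible runtime sizes")
    return [a + b for a, b in zip(lhs, rhs)]
```

For instance, add_1d([2.0], [1.0, 2.0, 3.0]) broadcasts the left operand and returns [3.0, 4.0, 5.0], while two operands of sizes 2 and 3 raise an error.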

The following code

func.func @test_add_1d_all_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

converts to

#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
func.func @test_add_1d_all_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {

  // Calculate maximum dimension size
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
  %dim_0 = tensor.dim %arg1, %c0 : tensor<?xf32>
  %0 = arith.maxui %dim, %dim_0 : index

  // Broadcast dimension 0 of %arg0
  %c1 = arith.constant 1 : index
  %dim_1 = tensor.dim %arg0, %c0 : tensor<?xf32>
  %1 = arith.cmpi eq, %dim_1, %c1 : index
  %2 = scf.if %1 -> (tensor<?xf32>) {
    %7 = tensor.empty(%0) : tensor<?xf32>
    %8 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel"]
    } ins(%arg0 : tensor<?xf32>) outs(%7 : tensor<?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<?xf32>
    scf.yield %8 : tensor<?xf32>
  } else {
    scf.yield %arg0 : tensor<?xf32>
  }

  // Broadcast dimension 0 of %arg1
  %dim_2 = tensor.dim %arg1, %c0 : tensor<?xf32>
  %3 = arith.cmpi eq, %dim_2, %c1 : index
  %4 = scf.if %3 -> (tensor<?xf32>) {
    %7 = tensor.empty(%0) : tensor<?xf32>
    %8 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel"]
    } ins(%arg1 : tensor<?xf32>) outs(%7 : tensor<?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<?xf32>
    scf.yield %8 : tensor<?xf32>
  } else {
    scf.yield %arg1 : tensor<?xf32>
  }

  // Perform element-wise computation
  %5 = tensor.empty(%0) : tensor<?xf32>
  %6 = linalg.generic {
    indexing_maps = [#map1, #map1, #map1],
    iterator_types = ["parallel"]
  } ins(%2, %4 : tensor<?xf32>, tensor<?xf32>) outs(%5 : tensor<?xf32>) {
  ^bb0(%in: f32, %in_3: f32, %out: f32):
    %7 = arith.addf %in, %in_3 : f32
    linalg.yield %7 : f32
  } -> tensor<?xf32>
  return %6 : tensor<?xf32>
}

The strategy here consists of first calculating the maximum runtime size of dimension 0 across both input operands. Each operand is then broadcast to this maximum size if necessary. Once both operands have been processed, they are guaranteed to have matching runtime dimensions, and the element-wise computation may proceed.

Multi-dimensional fully dynamic broadcasting

The previous example is extended here by using 2D tensors as input operands with all dimensions dynamic. The following code

func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

converts to

#map = affine_map<(d0, d1) -> (0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0, 0)>
func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  // Calculate maximum dimension 0
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %0 = arith.maxui %dim, %dim_0 : index

  // Calculate maximum dimension 1
  %c1 = arith.constant 1 : index
  %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %dim_2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %1 = arith.maxui %dim_1, %dim_2 : index

  // Broadcast dimension 0 of %arg0
  %dim_3 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %2 = arith.cmpi eq, %dim_3, %c1 : index
  %3 = scf.if %2 -> (tensor<?x?xf32>) {
    %dim_7 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32>
    %13 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg0 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<?x?xf32>
    scf.yield %13 : tensor<?x?xf32>
  } else {
    scf.yield %arg0 : tensor<?x?xf32>
  }

  // Broadcast dimension 1 of %arg0
  %dim_4 = tensor.dim %3, %c1 : tensor<?x?xf32>
  %4 = arith.cmpi eq, %dim_4, %c1 : index
  %5 = scf.if %4 -> (tensor<?x?xf32>) {
    %dim_7 = tensor.dim %3, %c0 : tensor<?x?xf32>
    %12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32>
    %13 = linalg.generic {
      indexing_maps = [#map2, #map1],
      iterator_types = ["parallel", "parallel"]
    } ins(%3 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<?x?xf32>
    scf.yield %13 : tensor<?x?xf32>
  } else {
    scf.yield %3 : tensor<?x?xf32>
  }

  // Broadcast dimension 0 of %arg1
  %dim_5 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %6 = arith.cmpi eq, %dim_5, %c1 : index
  %7 = scf.if %6 -> (tensor<?x?xf32>) {
    %dim_7 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
    %12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32>
    %13 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg1 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<?x?xf32>
    scf.yield %13 : tensor<?x?xf32>
  } else {
    scf.yield %arg1 : tensor<?x?xf32>
  }

  // Broadcast dimension 1 of %arg1
  %dim_6 = tensor.dim %7, %c1 : tensor<?x?xf32>
  %8 = arith.cmpi eq, %dim_6, %c1 : index
  %9 = scf.if %8 -> (tensor<?x?xf32>) {
    %dim_7 = tensor.dim %7, %c0 : tensor<?x?xf32>
    %12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32>
    %13 = linalg.generic {
      indexing_maps = [#map2, #map1],
      iterator_types = ["parallel", "parallel"]
    } ins(%7 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<?x?xf32>
    scf.yield %13 : tensor<?x?xf32>
  } else {
    scf.yield %7 : tensor<?x?xf32>
  }

  // Perform element-wise computation
  %10 = tensor.empty(%0, %1) : tensor<?x?xf32>
  %11 = linalg.generic {
    indexing_maps = [#map1, #map1, #map1],
    iterator_types = ["parallel", "parallel"]
  } ins(%5, %9 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%10 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_7: f32, %out: f32):
    %12 = arith.addf %in, %in_7 : f32
    linalg.yield %12 : f32
  } -> tensor<?x?xf32>
  return %11 : tensor<?x?xf32>
}

The generated code first computes the maximum dimension size separately for dimensions 0 and 1. Conditional broadcasting is then carried out for every input operand and dimension: dimension 0 of %arg0, dimension 1 of %arg0, dimension 0 of %arg1, and dimension 1 of %arg1. Once all runtime sizes are guaranteed to be equal, the element-wise computation proceeds.

A subtlety to keep under control in fully dynamic multi-dimensional broadcasts is that every broadcast operation must affect exactly one dimension of one operand, while preserving the current intermediate sizes of all other dynamic dimensions in that operand. This is why every tensor.empty op preceding a broadcasting linalg.generic op enforces the new maximum size only for the broadcast dimension, and preserves the size of its neighboring dimension by first querying it with a tensor.dim op.
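A Python sketch of this two-step process, assuming nested lists as 2D tensors (helper names are illustrative): the target shape is the per-dimension maximum, and each size-1 dimension is expanded one at a time, preserving the other dimension's current size.

```python
def runtime_broadcast_shape(shape_a, shape_b):
    """Per-dimension maximum, as computed by the arith.maxui ops. Assumes the
    shapes are broadcast-compatible (sizes match or one side is 1)."""
    return [max(da, db) for da, db in zip(shape_a, shape_b)]

def broadcast_2d(t, target):
    """Expand each size-1 dimension of a nested-list 2D 'tensor' separately,
    mirroring the two scf.if/linalg.generic pairs emitted per operand."""
    if len(t) == 1:                   # broadcast dimension 0
        t = [list(row) for row in t * target[0]]
    if len(t[0]) == 1:                # broadcast dimension 1
        t = [row * target[1] for row in t]
    return t
```

Note that each step reads the operand's current intermediate shape rather than the original one, exactly as the tensor.dim queries do in the generated IR.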

Zero-dimensional tensors

The tosa.add op below uses zero-dimensional tensors for its input operands and result; the generated code behaves equivalently to a scalar binary operation. The following code

func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

converts to

#map = affine_map<() -> ()>
func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
  %0 = tensor.empty() : tensor<f32>
  %1 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = []
  } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<f32>
  return %1 : tensor<f32>
}

Rank expansion

The tosa.add op below has input tensors with different ranks. A rank equalization stage must first reshape %arg0 of type tensor<3x4xf32> into a 3D tensor of type tensor<1x3x4xf32>. Dimension 0 of size 1 in the expanded 3D tensor must then be broadcast. The expanded, broadcast version of %arg0 is finally of type tensor<2x3x4xf32>, which is ready for the element-wise computation with %arg1.

The following code

func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
  return %0 : tensor<2x3x4xf32>
}

converts to

#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {

  // Expand shape of %arg0 to match %arg1
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32>

  // Perform element-wise operation
  %0 = tensor.empty() : tensor<2x3x4xf32>
  %1 = linalg.generic {
    indexing_maps = [#map, #map1, #map1],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%expanded, %arg1 : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%0 : tensor<2x3x4xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<2x3x4xf32>
  return %1 : tensor<2x3x4xf32>
}

In the generated code, rank equalization is carried out by an additional tensor.expand_shape op. The broadcast of the new 1-sized dimension of %arg0 is accomplished through the modified affine map #map, whose left-most affine expression is set to the constant 0.
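NumPy's rank-alignment rule matches the intent here, so the example can be checked with a short sketch (assuming NumPy semantics for the leading-dimension expansion):

```python
import numpy as np

a = np.arange(12, dtype=np.float32).reshape(3, 4)  # tensor<3x4xf32>
b = np.ones((2, 3, 4), dtype=np.float32)           # tensor<2x3x4xf32>

# tensor.expand_shape: 3x4 -> 1x3x4 (leading 1-sized dimension prepended)
a_expanded = a.reshape(1, 3, 4)

# The 1-sized dimension 0 is broadcast to size 2 inside the element-wise op,
# which is what the constant-0 affine expression in #map achieves.
result = a_expanded + b
```

In the Linalg lowering the broadcast costs no extra copy: the affine map simply reads index 0 of the expanded dimension for every output row.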

Ternary operation

This example illustrates the lowering of the tosa.select op, which currently acts as the sole representative of ternary element-wise ops in the TOSA dialect. The following code

func.func @test_select_2d_one_dynamic(%arg0: tensor<2x?xi1>, %arg1: tensor<2x?xf32>, %arg2: tensor<2x?xf32>) -> tensor<2x?xf32> {
  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}

converts to

#map = affine_map<(d0, d1) -> (d0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func.func @test_select_2d_one_dynamic(%arg0: tensor<2x?xi1>, %arg1: tensor<2x?xf32>, %arg2: tensor<2x?xf32>) -> tensor<2x?xf32> {

  // Compute maximum dimension 1 among operands
  %c1 = arith.constant 1 : index
  %dim = tensor.dim %arg0, %c1 : tensor<2x?xi1>
  %dim_0 = tensor.dim %arg1, %c1 : tensor<2x?xf32>
  %0 = arith.maxui %dim, %dim_0 : index
  %dim_1 = tensor.dim %arg2, %c1 : tensor<2x?xf32>
  %1 = arith.maxui %0, %dim_1 : index

  // Broadcast dimension 1 of %arg0
  %dim_2 = tensor.dim %arg0, %c1 : tensor<2x?xi1>
  %2 = arith.cmpi eq, %dim_2, %c1 : index
  %3 = scf.if %2 -> (tensor<2x?xi1>) {
    %10 = tensor.empty(%1) : tensor<2x?xi1>
    %11 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg0 : tensor<2x?xi1>) outs(%10 : tensor<2x?xi1>) {
    ^bb0(%in: i1, %out: i1):
      linalg.yield %in : i1
    } -> tensor<2x?xi1>
    scf.yield %11 : tensor<2x?xi1>
  } else {
    scf.yield %arg0 : tensor<2x?xi1>
  }

  // Broadcast dimension 1 of %arg1
  %dim_3 = tensor.dim %arg1, %c1 : tensor<2x?xf32>
  %4 = arith.cmpi eq, %dim_3, %c1 : index
  %5 = scf.if %4 -> (tensor<2x?xf32>) {
    %10 = tensor.empty(%1) : tensor<2x?xf32>
    %11 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg1 : tensor<2x?xf32>) outs(%10 : tensor<2x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<2x?xf32>
    scf.yield %11 : tensor<2x?xf32>
  } else {
    scf.yield %arg1 : tensor<2x?xf32>
  }

  // Broadcast dimension 1 of %arg2
  %dim_4 = tensor.dim %arg2, %c1 : tensor<2x?xf32>
  %6 = arith.cmpi eq, %dim_4, %c1 : index
  %7 = scf.if %6 -> (tensor<2x?xf32>) {
    %10 = tensor.empty(%1) : tensor<2x?xf32>
    %11 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg2 : tensor<2x?xf32>) outs(%10 : tensor<2x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<2x?xf32>
    scf.yield %11 : tensor<2x?xf32>
  } else {
    scf.yield %arg2 : tensor<2x?xf32>
  }

  // Perform element-wise computation
  %8 = tensor.empty(%1) : tensor<2x?xf32>
  %9 = linalg.generic {
    indexing_maps = [#map1, #map1, #map1, #map1],
    iterator_types = ["parallel", "parallel"]
  } ins(%3, %5, %7 : tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) outs(%8 : tensor<2x?xf32>) {
  ^bb0(%in: i1, %in_5: f32, %in_6: f32, %out: f32):
    %10 = arith.select %in, %in_5, %in_6 : f32
    linalg.yield %10 : f32
  } -> tensor<2x?xf32>
  return %9 : tensor<2x?xf32>
}

The generated code first obtains the maximum runtime size for the dynamic dimension (dimension 1) among the 3 input operands. Then, each input is individually broadcast into that maximum size if its dynamic dimension is identified to have a runtime size of 1. After all tensors are guaranteed to have equal sizes, the element-wise computation is performed.
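The ternary pattern is the binary one with a third operand folded into the running maximum; a NumPy sketch of the same select semantics (variable names are illustrative):

```python
import numpy as np

cond = np.array([[True, False]])                   # 1x2, dim 0 needs broadcast
on_true = np.full((2, 2), 1.0, dtype=np.float32)   # already 2x2
on_false = np.full((2, 1), 2.0, dtype=np.float32)  # 2x1, dim 1 needs broadcast

# Running maximum over all three shapes, as the chained arith.maxui ops do
shape = np.broadcast_shapes(cond.shape, on_true.shape, on_false.shape)

# Each operand is individually broadcast to the common shape (the scf.if blocks)
operands = [np.broadcast_to(t, shape) for t in (cond, on_true, on_false)]

# The final 4-input linalg.generic with arith.select in its body
result = np.where(*operands)
```

As in the binary case, operands whose dimension already matches the maximum pass through the scf.if untouched.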

Additional notes

The following issues have been addressed in the current version of this RFC since it was originally posted:

  • The RFC originally suggested to remove broadcast semantics from input dynamic dimensions for simplicity. Previous feedback suggested that it is important to support this feature, which is now reflected in the current proposal.

The following issues remain open for discussion and may involve further modifications in this RFC:

  • The proposed specification for the Broadcastable op trait removes broadcast semantics for the inferred result for simplicity (labeled as Modification 1 in the RFC). Is this change acceptable?
  • The proposed specification for the Broadcastable op trait forbids implicit dynamic-to-static cast in result dimensions (labeled Modification 2 in the RFC). Is this change acceptable?
  • The existing implementation includes two locations where result dimensions are inferred for element-wise ops. One is function OpTrait::util::getBroadcastedShape(), invoked during op verification for ops with the Broadcastable trait. The other is OP::inferReturnTypeComponents(). It seems there are opportunities to factor out some code here.

I do not work with TOSA, but I can share how we handled broadcasting on dynamic shapes in the numpy → linalg pipeline of our Python compiler (Open MLIR Meeting 6/8/2023: Update on Numba/MLIR)

We expand each individual dimension for each arg, checking dynamic values of dims:

%dim0 = dim %arg1, 0
%dim1 = dim %arg1, 1
%0 = scf.if (%dim0 == 1) {
  %1 = expand dim 0 %arg1
  yield %1
} else {
  yield %arg1
}
%broadcasted_arg1 = scf.if (%dim1 == 1) {
  %3 = expand dim 1 %0
  yield %3
} else {
  yield %0
}

... same for arg2 ...

linalg.generic(%broadcasted_arg1, %broadcasted_arg2)

(we assume that shapes are broadcastable, e.g. (5,7)x(8,9) will produce UB)
For fully static shapes, those if checks will be trivially folded, and this code is usually fused into a single linalg.generic.
For dynamic shapes, this code will still produce a correct result, but it will be very slow, as it can potentially involve multiple data copies for each arg.
To alleviate this, we have a custom pass which tries to propagate shape range information through dynamically-shaped ops (numba-mlir/mlir/lib/Transforms/ShapeIntegerRangePropagation.cpp at main · numba/numba-mlir · GitHub); this is usually enough to fold those %dim == 1, %dim != 1 checks.

Mod 1: Removing broadcast semantics for dynamic input dimensions.

As I mentioned on the Phabricator review, I think this is a rather surprising divergence in behavior between statically known and unknown dimensions.

Under this specification, erasing the static type information on the example below radically changes the set of values that are legal inputs to the function at runtime. In fact, the dynamic version cannot consume any of the inputs that were legal to the fully static version.

func.func @example_static(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// vs

func.func @example_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

A reasonable expectation would be that the set of runtime legal inputs to example_dynamic would be a strict superset of the runtime legal inputs to example_static, but that is not the case under the proposed change to the spec.
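The disjointness can be made concrete with two hypothetical predicates over runtime sizes (the function names are illustrative, and the second encodes my reading of the proposed Mod 1):

```python
def legal_static_example(a_len, b_len):
    # example_static's signature pins the shapes: only (2,) + (1,) is accepted
    return (a_len, b_len) == (2, 1)

def legal_dynamic_under_mod1(a_len, b_len):
    # Under the originally proposed Mod 1, dynamic dimensions carry no
    # broadcast semantics, so runtime sizes must match exactly.
    return a_len == b_len
```

Every input accepted by the static version is rejected by the dynamic version, instead of the dynamic set being a superset of the static one.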

Thank you all for your feedback. I have introduced broadcast semantics for input dynamic dimensions in the lowering strategy. Unfortunately, Discourse no longer allows me to edit the RFC with this change due to my beginner “trust level” (Understanding Discourse Trust Levels). If anyone has advice on how to lift this restriction, it would be appreciated.

@mehdi_amini


Hi!
I think some of the problems are caused by the fact that for most TOSA ops, the shape is not really inferred. @aviadc and I fixed some ops, but a general fix should be done.
Example of a fix: ⚙ D146132 [mlir] tosa.concat - Add InferTensorType interface
Lately we also understood that isCompatibleReturnTypes should consider which of the types is the inferred one and which is the returned one when comparing two types.
So I agree with the changes you want to make. Regarding the implementation, it might be easier to fix the shape inference issue first.
I uploaded a draft patch so you can see what I mean: ⚙ D154222 [mlir] Draft: Fix shape inference of tosa elementwise binary ops. Perhaps the new trait will not be necessary, or could have fewer things to take care of.

Hi everyone. I have updated the RFC according to the feedback received from multiple sources over the last few days.

  • The main update in this version involved adding support for broadcast semantics in dynamic dimensions. @Hardcode84: Thanks for the tip, I used the strategy you’re proposing to generate conditional broadcasting based on the runtime size of a dynamic dimension.
  • @mehdi_amini: Thank you for the upgrade - that solved the edit permissions problem.
  • @maya_amrami: Thanks for the suggestion. The updated proposed implementation (Phabricator link on RFC) includes the shape inference as part of the Broadcastable trait. The new documentation in Broadcastable.md describes this behavior in detail, which is intended to apply for unary/binary/ternary ops in a unified manner. I added a note at the end of the RFC pointing to multiple locations in the code where shape inference is currently performed, which we might want to factor out at some point. Given the magnitude of this proposal, we might want to postpone this task until after this change lands, though.

Thanks for the nice write up here, will be great for all edge cases to be handled!

Yes, we need to improve this. For now, the simplest case is where these overlap sufficiently that one can be implemented in terms of the other: one can make little shims and use OpTraitList to group traits (look at InferTensorType as an example). In this case, it sounds like anything that implements the Broadcastable trait can refer to the same type inference function; since the result could be either a vector or a tensor (per the trait doc), one would need an additional piece of information. Not a blocker indeed, but something that wouldn't be too hard to add to the ODS generator (it fits naturally with something else I've wanted).


The current version of this RFC as reflected in the main branch today forbids implicit dynamic-to-static cast in result dimensions. This is what’s currently marked as Modification 2 in today’s version of the RFC. This change turned out to be problematic on the Tensorflow side and has now been undone in the following patch:

https://reviews.llvm.org/D156714

@mehdi_amini It looks like my current Discourse trust level does not allow me to edit this post after 30 days since the original write-up. If this permission issue can be easily addressed, I’m happy to edit the RFC to reflect the latest state of the Broadcastable trait.