TOSA to Linalg lowering (tosa.scatter)

What is the correct way to lower “tosa.scatter” to Linalg for the f32 data type?
I’m trying to lower the following example from mlir/test/Dialect/Tosa:

// CHECK-LABEL: scatter
func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
  %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

mlir-opt --tosa-to-standard --tosa-to-linalg test.mlir

I don’t have the lowering, but you can get a sense of the expected behavior from the specification (the download link is on the TOSA website).

Here’s the pseudocode from the specification for the scatter operator:

// The following array is used to check compliance that an output position
// is modified at most once.
bool_t output_modified[N,K,C];

// Copy the values_in tensor to the values_out tensor.
// Values not written by the scatter operation are unchanged in the output.
for_each(0 <= n < N, 0 <= k < K, 0 <= c < C) {
    value_t value = tensor_read<value_t>(values_in, [N,K,C], [n,k,c]);
    tensor_write<value_t>(values_out, [N,K,C], [n, k, c], value);
    output_modified[n,k,c] = false;
}

// Now perform the SCATTER operation, modifying the positions from the indices tensor
for_each(0 <= n < N, 0 <= w < W, 0 <= c < C) {
    index_t k = tensor_read<index_t>(indices, [N,W], [n,w]);
    REQUIRE(0 <= k && k < K);
    REQUIRE(output_modified[n,k,c] == false);
    value_t value = tensor_read<value_t>(input, [N,W,C], [n,w,c]);
    tensor_write<value_t>(values_out, [N,K,C], [n, k, c], value);
    output_modified[n,k,c] = true;
}
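When testing a lowering, it can help to have an executable reference for the pseudocode. Below is a minimal pure-Python sketch, assuming tensors are nested lists with shapes values_in [N][K][C], indices [N][W] and input [N][W][C] (for the MLIR test: N=13, K=21, W=26, C=3). The name `tosa_scatter` is hypothetical, and raising `ValueError` when a REQUIRE fails is a choice made here; the spec only says such inputs are non-compliant.

```python
from typing import List

Tensor3 = List[List[List[float]]]
Index2 = List[List[int]]


def tosa_scatter(values_in: Tensor3, indices: Index2, input_: Tensor3) -> Tensor3:
    """Reference semantics of tosa.scatter (hypothetical helper, not an MLIR API).

    values_in: [N][K][C], indices: [N][W], input_: [N][W][C] -> values_out: [N][K][C].
    Assumes non-empty tensors with consistent shapes.
    """
    N, K, C = len(values_in), len(values_in[0]), len(values_in[0][0])
    W = len(indices[0])

    # Copy values_in to values_out; positions never written keep these values.
    values_out = [[list(row) for row in batch] for batch in values_in]
    output_modified = [[[False] * C for _ in range(K)] for _ in range(N)]

    # Scatter: row w of input_ in batch n goes to row indices[n][w] of the output.
    for n in range(N):
        for w in range(W):
            k = indices[n][w]
            if not 0 <= k < K:  # REQUIRE(0 <= k && k < K)
                raise ValueError(f"index {k} out of range [0, {K}) at (n={n}, w={w})")
            for c in range(C):
                if output_modified[n][k][c]:  # REQUIRE(output_modified[n,k,c] == false)
                    raise ValueError(f"output position ({n}, {k}, {c}) written twice")
                values_out[n][k][c] = input_[n][w][c]
                output_modified[n][k][c] = True
    return values_out


if __name__ == "__main__":
    # N=1, K=3, C=1, W=2: input row 0 goes to k=2, input row 1 goes to k=0.
    print(tosa_scatter([[[0.0], [0.0], [0.0]]], [[2, 0]], [[[1.0], [2.0]]]))
    # prints [[[2.0], [0.0], [1.0]]]
```

Output row k=1 is never written, so it keeps its value from values_in.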

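It’s hard to represent the scatter op in Linalg because Linalg ops have structural requirements: each operand is accessed through an affine map of the loop indices (indexing_maps). In scatter, the output position written at step (n, w, c) is (n, indices[n, w], c), which depends on the contents of a tensor, so it cannot be expressed as such a map.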
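To make the indexing_maps problem concrete, the short Python sketch below (helper names are hypothetical) enumerates, over the scatter iteration space (n, w, c), the positions read from `input` and the positions written in the output. The reads follow the identity map (n, w, c) regardless of the data, while the written k coordinate changes when only the contents of `indices` change, so no single fixed affine map describes the output access.

```python
from typing import List, Tuple

Pos = Tuple[int, int, int]


def input_read_positions(N: int, W: int, C: int) -> List[Pos]:
    # Reading `input` is affine: (n, w, c) -> (n, w, c), independent of tensor contents.
    return [(n, w, c) for n in range(N) for w in range(W) for c in range(C)]


def output_write_positions(indices: List[List[int]], C: int) -> List[Pos]:
    # Writing the output uses (n, w, c) -> (n, indices[n][w], c):
    # the k coordinate is loaded from a tensor, not computed from n, w, c.
    return [(n, indices[n][w], c)
            for n in range(len(indices))
            for w in range(len(indices[n]))
            for c in range(C)]


# Same shapes and loop bounds, different index data.
print(input_read_positions(1, 2, 1))        # [(0, 0, 0), (0, 1, 0)]
print(output_write_positions([[0, 1]], 1))  # [(0, 0, 0), (0, 1, 0)]
print(output_write_positions([[1, 0]], 1))  # [(0, 1, 0), (0, 0, 0)]
```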

I’ve prototyped compiling the mhlo.scatter op end-to-end in IREE. Our current solution is a LinalgExt dialect in IREE with a linalg_ext.scatter op. We implemented mhlo.scatter → linalg_ext.scatter and linalg_ext.scatter → loops lowerings, which let us run the op in IREE.

We also did some optimization on the op; for example, we are able to tile and distribute it.
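As an illustration of why tiling and distribution are legal for scatter (a sketch, not IREE's actual implementation), here is a Python version that assumes `indices` already satisfies the spec's REQUIRE conditions (in range, no output position written twice). The update space (n, w) is split into tiles along w, and each tile is submitted to a thread pool as an independent task writing into a shared output. Within a batch n, distinct w must map to distinct k, so tiles write disjoint positions and the result does not depend on scheduling. `tiled_scatter`, `scatter_tile` and `tile_w` are hypothetical names.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List

Tensor3 = List[List[List[float]]]


def scatter_tile(values_out: Tensor3, indices: List[List[int]], input_: Tensor3,
                 n: int, w_begin: int, w_end: int) -> None:
    # One tile: batch n, update rows [w_begin, w_end), all channels c.
    for w in range(w_begin, w_end):
        k = indices[n][w]
        for c in range(len(input_[n][w])):
            values_out[n][k][c] = input_[n][w][c]


def tiled_scatter(values_in: Tensor3, indices: List[List[int]], input_: Tensor3,
                  tile_w: int = 2, workers: int = 4) -> Tensor3:
    # Assumes compliant indices: in range and no output position written twice.
    N, W = len(indices), len(indices[0])
    values_out = [[list(row) for row in batch] for batch in values_in]
    tiles = [(n, w0, min(w0 + tile_w, W)) for n in range(N) for w0 in range(0, W, tile_w)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(scatter_tile, values_out, indices, input_, n, w0, w1)
                   for n, w0, w1 in tiles]
        for f in futures:
            f.result()  # re-raise any exception from a tile
    return values_out


print(tiled_scatter([[[0.0], [0.0], [0.0], [0.0]]], [[3, 1, 0, 2]],
                    [[[10.0], [11.0], [12.0], [13.0]]]))
# [[[12.0], [11.0], [13.0], [10.0]]]
```

If `indices` contained duplicates, two tiles could write the same position and the result would depend on which tile ran last, which is why the at-most-once requirement matters for distribution.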

During prototyping, I also added a lowering from mhlo.scatter to a linalg.generic op, but it only works for specific cases and it’s inefficient. I don’t think that’s the way to go; I’m just noting it here.
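For reference, here is one way scatter can be forced into affine indexing maps, shown as a Python loop nest rather than MLIR. It is an assumption for illustration, not necessarily the linalg.generic lowering mentioned above. The iteration space is (n, k, w, c) with w as a reduction dimension; the operand maps are projections (indices: (n, w), input: (n, w, c), output: (n, k, c)); and the body keeps the update only when the loaded index equals the loop index k (which a linalg.generic body would obtain with `linalg.index`). This does N*K*W*C comparisons instead of N*W*C writes: for the 13x21x3 test that is 21,294 versus 1,014, a factor of K = 21.

```python
from typing import List

Tensor3 = List[List[List[float]]]


def scatter_as_generic(values_in: Tensor3, indices: List[List[int]], input_: Tensor3) -> Tensor3:
    N, K, C = len(values_in), len(values_in[0]), len(values_in[0][0])
    W = len(indices[0])
    # Output "init" operand starts as a copy of values_in.
    out = [[list(row) for row in batch] for batch in values_in]
    # Loop dims (n, k, w, c): n, k, c are parallel, w is a reduction.
    for n in range(N):
        for k in range(K):
            for w in range(W):
                for c in range(C):
                    # Body: select(indices[n, w] == k, input[n, w, c], out[n, k, c]).
                    if indices[n][w] == k:
                        out[n][k][c] = input_[n][w][c]
    return out


print(scatter_as_generic([[[0.0], [0.0], [0.0]]], [[2, 0]], [[[1.0], [2.0]]]))
# [[[2.0], [0.0], [1.0]]]
```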
