[RFC] Concurrent updates with `scf.foreach_thread`

`ForeachThreadOp` is the only operation in MLIR Core that models parallel execution on tensors. It has a `perform_concurrently` terminator that specifies how to combine the partial results of all parallel invocations into a full value, in some unspecified order.

```
%result = scf.foreach_thread (%threadId) in (%numThreads)
    shared_outs(%out_ = %out) -> (tensor<64xf32>) {
  %offset = affine.apply ...
  %size = affine.min ...
  %inTile = tensor.extract_slice %in[%offset][%size][1]
      : tensor<64xf32> to tensor<?xf32>
  %outTile = tensor.extract_slice %out_[%offset][%size][1]
      : tensor<64xf32> to tensor<?xf32>
  %resultTile = linalg.elemwise_unary
      ins(%inTile : tensor<?xf32>)
      outs(%outTile : tensor<?xf32>) -> tensor<?xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %resultTile into %out_[%offset][%size][1]
        : tensor<?xf32> into tensor<64xf32>
  }
}
```

Notice that `tensor.parallel_insert_slice` looks like `tensor.insert_slice`, but it has no results: the resulting tensor is constructed implicitly.
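The ordering claim can be checked with a small executable model. The sketch below is a Python model of the semantics, not MLIR API. Each "thread" computes its tile, and the `parallel_insert_slice` writes are then applied in a shuffled order to stand in for the unspecified order. The function name `foreach_thread_elementwise` and the choice of elementwise function are hypothetical. Because the tiles never overlap, the result comes out the same for every order.

```python
import random


def foreach_thread_elementwise(inp, out, num_threads, fn, seed=0):
    """Hypothetical model of scf.foreach_thread over a 1-D tensor.

    Each "thread" computes one tile. The parallel_insert_slice writes are
    collected and then applied in a shuffled order, modeling the
    unspecified order of the perform_concurrently terminator.
    """
    n = len(inp)
    tile = -(-n // num_threads)  # ceildiv(n, num_threads)
    pending = []
    for tid in range(num_threads):
        offset = tid * tile                      # stands in for affine.apply
        size = max(0, min(tile, n - offset))     # stands in for affine.min
        in_tile = inp[offset:offset + size]      # tensor.extract_slice %in
        result_tile = [fn(x) for x in in_tile]   # linalg.elemwise_unary
        pending.append((offset, result_tile))    # tensor.parallel_insert_slice
    result = list(out)  # value semantics: the caller's %out is not mutated
    random.Random(seed).shuffle(pending)  # insertion order is unspecified
    for offset, result_tile in pending:
        result[offset:offset + len(result_tile)] = result_tile
    return result
```

For example, `foreach_thread_elementwise([float(i) for i in range(64)], [0.0] * 64, 5, lambda x: 2.0 * x)` doubles every element for any `seed`.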

## Reductions

The abstraction above can model distribution along the parallel/batch dimensions of ops. The problem arises when we want to tile reductions and similar ops that require accumulation of partial results.

```
%init = linalg.init_tensor [128] : tensor<128xf32>
%out = linalg.fill ins(%cst : f32)
    outs(%init : tensor<128xf32>) -> tensor<128xf32>
%sum = linalg.reduce
    ins(%in : tensor<128x512xf32>) outs(%out : tensor<128xf32>)
    dimensions = [1]
    (%in_elem: f32, %out_elem: f32) {
      %sum_elem = arith.addf %in_elem, %out_elem : f32
      linalg.yield %sum_elem : f32
    }
```
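As a point of reference, here is a scalar Python model of what this `linalg.reduce` computes. Each `out[i]` starts from the filled value and folds every element of row `i` into it through the combiner region. The names `reduce_dim1` and `addf` are hypothetical helpers written for this sketch.

```python
def reduce_dim1(inp, init, combiner):
    """Hypothetical model of linalg.reduce with dimensions = [1] on a 2-D tensor."""
    out = list(init)  # %out, e.g. the result of linalg.fill
    for i, row in enumerate(inp):
        for in_elem in row:
            # Body of the reduce region: (%in_elem, %out_elem) -> %sum_elem
            out[i] = combiner(in_elem, out[i])
    return out


def addf(in_elem, out_elem):
    # Mirrors arith.addf in the region above.
    return in_elem + out_elem
```

With `%cst = 0.0` this gives row sums. Any associative and commutative combiner can be swapped in, for example Python's built-in `max` with an initial value of `-inf`.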

`linalg.reduce` has a region that specifies the combiner operation(s).

After tiling, the body of the resulting loop should contain two important parts. The first is the original `linalg.reduce` op, applied to tiles of the original operands. The second specifies how to combine overlapping partial results.
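Why the second part matters shows up in a sequential Python sketch of a tiled row sum, where the loop order stands in for the unspecified thread order. Tiles that differ only in the reduction index `j` write to the same output slice. A plain overwrite therefore keeps only one partial result, and an accumulating combine keeps all of them. The function names are hypothetical, and the sketch assumes the tile sizes evenly divide the tensor.

```python
def tiled_reduce(inp, init, tile_i, tile_j, combine_partials):
    """Hypothetical sketch of a tiled row sum with tile_i x tile_j tiles.

    Assumes tile_i divides the row count and tile_j divides the column count.
    """
    rows, cols = len(inp), len(inp[0])
    out = list(init)
    for i in range(0, rows, tile_i):
        for j in range(0, cols, tile_j):
            # Part 1: the original reduction on one tile, into a temporary
            # initialized with the neutral element 0.0.
            partial = [0.0] * tile_i
            for r in range(tile_i):
                for c in range(tile_j):
                    partial[r] += inp[i + r][j + c]
            # Part 2: merge the partial result into the shared output slice.
            current = out[i:i + tile_i]
            out[i:i + tile_i] = combine_partials(partial, current)
    return out


def overwrite(new, current):
    # What a plain parallel_insert_slice does: the last writer wins.
    return list(new)


def accumulate(new, current):
    # Elementwise add, matching the reduction's combiner.
    return [n + c for n, c in zip(new, current)]
```

On a 4x4 input with 2x2 tiles, `accumulate` yields the true row sums. `overwrite` yields only the sums over the last column tile.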

By adding an *optional* region to `tensor.parallel_insert_slice`, we can model concurrent read-modify-write updates.

```
%neutralElement = arith.constant 0.0 : f32
%sum = scf.foreach_thread (%threadId) in (%numThreads)
    shared_outs(%out_ = %out) -> (tensor<128xf32>) {
  // Offset computations based on %threadId.
  %i = ...
  %j = ...
  // Initialize the temporary output.
  %initTmp = linalg.init_tensor [8] : tensor<8xf32>
  %outTmp = linalg.fill ins(%neutralElement : f32)
      outs(%initTmp : tensor<8xf32>) -> tensor<8xf32>
  // Materialize slices.
  %inTile = tensor.extract_slice %in[%i, %j] [8, 16] [1, 1]
      : tensor<128x512xf32> to tensor<8x16xf32>
  // Compute the reduction on the tile.
  %sumTile = linalg.reduce
      ins(%inTile : tensor<8x16xf32>) outs(%outTmp : tensor<8xf32>)
      dimensions = [1]
      (%in_elem: f32, %out_elem: f32) {
        %sum_elem = arith.addf %in_elem, %out_elem : f32
        linalg.yield %sum_elem : f32
      }
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %sumTile into %out_[%i] [8] [1]
      // Combine the reduced tile with the values in the shared output %out_.
      acc (%new: tensor<8xf32>, %current: tensor<8xf32>) {
        %combined = linalg.generic {
            indexing_maps = [#id_1d, #id_1d],
            iterator_types = ["parallel"]}
            ins(%new : tensor<8xf32>)
            outs(%current : tensor<8xf32>) {
          ^bb0(%new_elem: f32, %current_elem: f32):
            %s = arith.addf %new_elem, %current_elem : f32
            linalg.yield %s : f32
        } -> tensor<8xf32>
        linalg.yield %combined : tensor<8xf32>
      } : tensor<8xf32> into tensor<128xf32>
  }
}
```

In this example, the *accumulator* region contains a `linalg.generic` op that combines two block arguments. `%new` corresponds to the partial result computed in the current thread. `%current` corresponds to the values of the current slice stored in the shared output.
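The roles of `%new` and `%current` can be modeled in Python with real threads. The sketch assumes that each combine on a slice executes atomically with respect to the other combines, and it models this with a single lock. `SharedOut`, `add_acc` and `threaded_row_sum` are hypothetical names, and the sketch does not describe how MLIR would guarantee atomicity. When `acc` is `None`, the insert is the plain, non-overlapping `parallel_insert_slice`.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class SharedOut:
    """Hypothetical model of a shared_outs tensor."""

    def __init__(self, values):
        self.values = list(values)
        self._lock = threading.Lock()

    def parallel_insert_slice(self, tile, offset, acc=None):
        # acc(new, current) plays the role of the optional accumulator region.
        with self._lock:  # the read-modify-write of the slice is atomic
            end = offset + len(tile)
            if acc is None:
                self.values[offset:end] = tile
            else:
                current = self.values[offset:end]        # %current
                self.values[offset:end] = acc(tile, current)


def add_acc(new, current):
    # Mirrors the linalg.generic in the acc region: elementwise addf.
    return [n + c for n, c in zip(new, current)]


def threaded_row_sum(inp, tile_i, tile_j):
    num_rows, num_cols = len(inp), len(inp[0])
    out = SharedOut([0.0] * num_rows)

    def thread_body(i, j):
        # %new: the partial row sums of one tile.
        partial = [sum(inp[i + r][j:j + tile_j]) for r in range(tile_i)]
        out.parallel_insert_slice(partial, i, acc=add_acc)

    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(thread_body, i, j)
                   for i in range(0, num_rows, tile_i)
                   for j in range(0, num_cols, tile_j)]
        for f in futures:
            f.result()  # re-raise any worker exception
    return out.values
```

Calling `threaded_row_sum` on a 128x512 tensor of ones with 8x16 tiles returns `512.0` for every row, whatever order the threads finish in.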

## ParallelInsertSliceOp changes

We can add an optional region to the `tensor.parallel_insert_slice` OpDef:

```
let regions = (region VariadicRegion<SizedRegion<1>>:$accumulators);
```

An alternative is to make this region required and omit it from the printed form when it is trivial. With that approach, however, there is no clear way to distinguish two cases: one where there are no overlapping writes, and one where writes do overlap but we intentionally want concurrent overwrites. In the latter case, we would need atomic writes to avoid tearing.
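A small Python model shows why a trivial default region hides the intent. With a "trivial" overwrite combiner, two inserts into the same slice produce different results depending on their order, which is the concurrent-overwrite case. A genuine accumulator produces a single result. All names here are hypothetical, and the model ignores tearing, because Python list writes here are not torn.

```python
import itertools


def apply_inserts(dest, inserts, acc):
    """Apply (offset, tile) inserts in the given order through combiner acc."""
    out = list(dest)
    for offset, tile in inserts:
        end = offset + len(tile)
        out[offset:end] = acc(tile, out[offset:end])
    return out


def overwrite(new, current):
    # What a required-but-trivial region would do.
    return list(new)


def add(new, current):
    return [n + c for n, c in zip(new, current)]


def outcomes(dest, inserts, acc):
    # Every result reachable under some ordering of the inserts.
    return {tuple(apply_inserts(dest, list(order), acc))
            for order in itertools.permutations(inserts)}
```

For example, `outcomes([0.0, 0.0], [(0, [1.0, 2.0]), (0, [10.0, 20.0])], overwrite)` has two possible results, while the same call with `add` has exactly one.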

## What to lower it to?

It can be lowered directly to `memref.atomic_rmw` (formerly `std.atomic_rmw`) or some vector atomic ops, or to an intermediate op in the MemRef dialect.

```
memref.rmw(%new : memref<8xf32>, %current : memref<8xf32>) {
^bb0(%new_: memref<8xf32>, %current_: memref<8xf32>):
  linalg.generic {
      indexing_maps = [#id_1d, #id_1d],
      iterator_types = ["parallel"]}
      ins(%new_ : memref<8xf32>)
      outs(%current_ : memref<8xf32>) {
    ^bb0(%new_elem: f32, %current_elem: f32):
      %s = arith.addf %new_elem, %current_elem : f32
      linalg.yield %s : f32
  }
  memref.yield
}
```
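Because the combiner in this example is elementwise, one possible lowering turns the slice-level `rmw` into a scalar atomic read-modify-write per element. The Python sketch below models this with one lock per element. `AtomicBuffer`, `rmw_slice` and `concurrent_accumulate` are hypothetical names, and the locks only stand in for hardware atomics. The model does not say how any MLIR op is implemented.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class AtomicBuffer:
    """Hypothetical memref-like buffer with one lock per element."""

    def __init__(self, size):
        self.data = [0.0] * size
        self._locks = [threading.Lock() for _ in range(size)]

    def atomic_rmw(self, index, value, combine):
        # A scalar atomic read-modify-write on a single element.
        with self._locks[index]:
            self.data[index] = combine(value, self.data[index])


def rmw_slice(buf, offset, new, combine):
    # Valid only because combine is elementwise: each element can be
    # updated independently of the others.
    for k, value in enumerate(new):
        buf.atomic_rmw(offset + k, value, combine)


def addf(new_elem, current_elem):
    return new_elem + current_elem


def concurrent_accumulate(num_updates, size):
    buf = AtomicBuffer(size)
    with ThreadPoolExecutor(max_workers=8) as pool:
        futures = [pool.submit(rmw_slice, buf, 0, [1.0] * size, addf)
                   for _ in range(num_updates)]
    for f in futures:
        f.result()  # re-raise any worker exception
    return buf.data
```

This per-element decomposition depends on the region being elementwise. If a combiner mixed elements across the slice, the whole slice update would have to be atomic.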

## TilingInterface

Tiling reductions would require changes to the `TilingInterface`.

It should be possible to query the op on whether any accumulation is needed, and also to let the op itself populate the accumulator region.
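Here is one way such an extension could look, sketched as a Python `Protocol` instead of the real C++ interface. The method names `needs_accumulation` and `build_accumulator`, as well as the classes, are hypothetical. Accumulation is only needed when a reduction dimension is tiled. If only parallel dimensions are tiled, each output slice is written exactly once.

```python
from typing import Callable, List, Optional, Protocol

Combiner = Callable[[List[float], List[float]], List[float]]


class TilableOp(Protocol):
    # Hypothetical additions to TilingInterface.
    def needs_accumulation(self) -> bool: ...
    def build_accumulator(self) -> Optional[Combiner]: ...


class ElementwiseOp:
    """Parallel-only op: output tiles never overlap."""

    def needs_accumulation(self) -> bool:
        return False

    def build_accumulator(self) -> Optional[Combiner]:
        return None


class SumReduceOp:
    """Sum reduction; partial results overlap only if the reduced dim is tiled."""

    def __init__(self, tiled_reduction_dim: bool):
        self.tiled_reduction_dim = tiled_reduction_dim

    def needs_accumulation(self) -> bool:
        return self.tiled_reduction_dim

    def build_accumulator(self) -> Optional[Combiner]:
        if not self.needs_accumulation():
            return None
        # The op populates the region from its own combiner (addf here).
        return lambda new, current: [n + c for n, c in zip(new, current)]


def terminator_kind(op: TilableOp) -> str:
    # What a tiling driver might decide to emit in perform_concurrently.
    if op.needs_accumulation():
        return "parallel_insert_slice with acc"
    return "plain parallel_insert_slice"
```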

## Questions

Do we have any atomic vector operations in MLIR?

Do we need the same changes for the upcoming `tensor.parallel_scatter`?