[RFC] Concurrent updates with scf.foreach_thread
`ForeachThreadOp` is the only operation in MLIR Core that models parallel execution on tensors. It has a `perform_concurrently` terminator that specifies how to combine the partial results of all parallel invocations into a full value, in some unspecified order.
```mlir
%result = scf.foreach_thread (%threadId) in (%numThreads)
    shared_outs(%out_ = %out) -> (tensor<64xf32>) {
  %offset = affine.apply ...
  %size = affine.min ...
  %inTile = tensor.extract_slice %in[%offset][%size][1]
      : tensor<64xf32> to tensor<?xf32>
  %outTile = tensor.extract_slice %out_[%offset][%size][1]
      : tensor<64xf32> to tensor<?xf32>
  %resultTile = linalg.elemwise_unary
      ins(%inTile : tensor<?xf32>)
      outs(%outTile : tensor<?xf32>) -> tensor<?xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %resultTile into %out_[%offset][%size][1]
        : tensor<?xf32> into tensor<64xf32>
  }
}
```
Note that `tensor.parallel_insert_slice` looks like `tensor.insert_slice`, but it has no results; the construction of the resulting tensor happens implicitly.
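For comparison, here is a sketch of the same tile insertion written both ways, reusing the value names from the example above (the sequential variant is illustrative only):

```mlir
// Sequential insertion: tensor.insert_slice returns the updated tensor explicitly.
%updated = tensor.insert_slice %resultTile into %out_[%offset][%size][1]
    : tensor<?xf32> into tensor<64xf32>

// Parallel insertion: no result; the combined tensor is produced by the
// enclosing scf.foreach_thread that owns the shared output %out_.
tensor.parallel_insert_slice %resultTile into %out_[%offset][%size][1]
    : tensor<?xf32> into tensor<64xf32>
```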
## Reductions
The abstraction above can model distribution with respect to parallel/batch
dimensions of ops. The problem arises when we want to tile reductions and
similar ops that require accumulation of the partial results.
```mlir
%init = linalg.init_tensor [128] : tensor<128xf32>
%out = linalg.fill ins(%cst: f32)
    outs(%init: tensor<128xf32>) -> tensor<128xf32>
%sum = linalg.reduce
    ins(%in: tensor<128x512xf32>) outs(%out: tensor<128xf32>)
    dimensions = [1]
    (%in_elem: f32, %out_elem: f32) {
      %sum_elem = arith.addf %in_elem, %out_elem : f32
      linalg.yield %sum_elem : f32
    }
```
`linalg.reduce` has a region that specifies the combiner operation(s). After tiling, the body of the resulting loop should contain two important parts. The first part is the original `linalg.reduce` op operating on tiles of the original operands. The second part should specify how to combine overlapping partial results.
By adding an optional region to `tensor.parallel_insert_slice`, we can model concurrent read-modify-like updates.
```mlir
%neutralElement = arith.constant 0.0 : f32
%sum = scf.foreach_thread (%threadId) in (%numThreads)
    shared_outs(%out_ = %out) -> (tensor<128xf32>) {
  // Offset computations based on %threadId.
  %i = ...
  %j = ...
  // Initialize a temporary output tile with the neutral element.
  %initTmp = linalg.init_tensor [8] : tensor<8xf32>
  %outTmp = linalg.fill ins(%neutralElement: f32)
      outs(%initTmp: tensor<8xf32>) -> tensor<8xf32>
  // Materialize the input slice.
  %inTile = tensor.extract_slice %in[%i, %j] [8, 16] [1, 1]
      : tensor<128x512xf32> to tensor<8x16xf32>
  // Compute the reduction on the tile.
  %sumTile = linalg.reduce
      ins(%inTile: tensor<8x16xf32>) outs(%outTmp: tensor<8xf32>)
      dimensions = [1]
      (%in_elem: f32, %out_elem: f32) {
        %sum_elem = arith.addf %in_elem, %out_elem : f32
        linalg.yield %sum_elem : f32
      }
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %sumTile into %out_[%i] [8] [1]
      // Combine the reduced tile with the values already in %out_.
      acc (%new: tensor<8xf32>, %current: tensor<8xf32>) {
        %combined = linalg.generic {
            indexing_maps = [#id_1d, #id_1d],
            iterator_types = ["parallel"]}
            ins(%new: tensor<8xf32>)
            outs(%current : tensor<8xf32>) {
          ^bb(%new_elem: f32, %current_elem: f32):
            %s = arith.addf %new_elem, %current_elem : f32
            linalg.yield %s : f32
        } -> tensor<8xf32>
        linalg.yield %combined : tensor<8xf32>
      } : tensor<8xf32> into tensor<128xf32>
  }
}
```
In this example the accumulator region contains a `linalg.generic` op that combines two block arguments: `%new` corresponds to the partial result computed by the current thread, and `%current` corresponds to the values of the current slice stored in the shared output.
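To make the semantics concrete, one accumulated insertion behaves conceptually like the following sequential read-modify-write on the shared output (a sketch only, ignoring atomicity and reusing the names from the example above; it is not a proposed lowering):

```mlir
// 1. Read the current contents of the destination slice.
%current = tensor.extract_slice %out_[%i] [8] [1]
    : tensor<128xf32> to tensor<8xf32>
// 2. Apply the accumulator body to the partial result and the current slice.
%combined = linalg.generic {
    indexing_maps = [#id_1d, #id_1d],
    iterator_types = ["parallel"]}
    ins(%sumTile: tensor<8xf32>)
    outs(%current : tensor<8xf32>) {
  ^bb(%new_elem: f32, %current_elem: f32):
    %s = arith.addf %new_elem, %current_elem : f32
    linalg.yield %s : f32
} -> tensor<8xf32>
// 3. Write the combined slice back into the shared output.
%updated = tensor.insert_slice %combined into %out_[%i] [8] [1]
    : tensor<8xf32> into tensor<128xf32>
```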
## ParallelInsertSliceOp changes
We can add an optional region to the `tensor.parallel_insert_slice` OpDef:

```tablegen
let regions = (region VariadicRegion<SizedRegion<1>>:$accumulators);
```
An alternative is to make the region required and elide it from the printed form when it is trivial. Then, however, there is no clear way to distinguish the case where writes never overlap from the case where they do overlap but we want plain concurrent overwrites; in the latter case we would need atomic writes to avoid torn values.
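For illustration, under the proposed syntax a "trivial" accumulator would simply overwrite the destination slice; this is the region the alternative above would elide from the printed form (a sketch, reusing the types from the first example):

```mlir
// Hypothetical trivial accumulator: plain overwrite of the destination slice.
tensor.parallel_insert_slice %resultTile into %out_[%offset][%size][1]
  acc (%new: tensor<?xf32>, %current: tensor<?xf32>) {
    linalg.yield %new : tensor<?xf32>
  } : tensor<?xf32> into tensor<64xf32>
```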
## What to lower it to?
It can be lowered directly to `std.atomic_rmw` or to vector atomic ops, or to an intermediate op in the MemRef dialect:
```mlir
memref.rmw(%new: memref<8xf32>, %current: memref<8xf32>) {
^bb0(%new_: memref<8xf32>, %current_: memref<8xf32>):
  linalg.generic {
      indexing_maps = [#id_1d, #id_1d],
      iterator_types = ["parallel"]}
      ins(%new_: memref<8xf32>)
      outs(%current_ : memref<8xf32>) {
    ^bb(%new_elem: f32, %current_elem: f32):
      %s = arith.addf %new_elem, %current_elem : f32
      linalg.yield %s : f32
  }
  memref.yield
}
```
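As a point of reference for the scalar case, MLIR already has `memref.generic_atomic_rmw`; a per-element lowering of an `addf` accumulator could look roughly like this (a sketch, assuming the accumulator is applied element by element inside a loop over the slice; `%buf`, `%idx`, and `%partial` are placeholder values for the destination buffer, the element index, and the thread's partial result):

```mlir
// Atomically read %buf[%idx], add the thread's partial value, and store the sum back.
%old = memref.generic_atomic_rmw %buf[%idx] : memref<128xf32> {
^bb0(%current_value: f32):
  %updated = arith.addf %current_value, %partial : f32
  memref.atomic_yield %updated : f32
}
```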
## TilingInterface
Tiling of reductions would require changes to the `TilingInterface`. It should be possible to query an op on whether any accumulation is needed, and also to allow the op itself to populate the accumulator region.
## Questions
- Do we have any atomic vector operations in MLIR?
- Do we need the same changes for the upcoming `tensor.parallel_scatter`?