[RFC] Parallel Abstraction For Tensors and Buffers
Co-authored by: @apaszke
There is currently no good way of representing IR that involves parallelism with explicit thread/processor ids, tensor types and subset abstractions.
For more context on the discussion and alternatives, see the prior discussion, in which alternatives involving scf.for, scf.parallel and the now-retired linalg.tile_loop have been discussed at length.
This limits the ability to write transformations that benefit from subset abstractions on tensor SSA values.
As a reminder, structured codegen on tensors has been shown to provide a low-surprise codegen path that runs at close to peak (sequential) performance.
The object of this RFC is to propose a retargetable abstraction that supports parallelism with explicit thread/processor ids and operates on either tensors or buffers.
This will allow scaling sequential structured codegen to the parallel case and be reusable with either the async or the GPU-inspired processor-grid models of execution.
Proposed abstraction
At a high level, the proposed op specifies a region that is evaluated multiple times in parallel, once per value of the associated thread_id block argument(s). It represents a target-independent parallel function application operation.
The operation is not isolated from above and captures every value implicitly.
The proposed abstraction is a combination of atoms proposed during the prior discussion, with additional insights coming from mixing them to build an end-to-end prototype.
The proposed abstraction is called xxx.foreach_thread, with xxx denoting uncertainty about the dialect in which this should live.
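As a minimal sketch of the general shape (syntax tentative; %buffer, %num_threads and the scaling below are purely illustrative), the region runs once per thread id and implicitly captures values defined in the enclosing sequential scope:
//
// Sequential context: %buffer and %num_threads are defined here.
//
xxx.foreach_thread (%thread_id) in (%num_threads) {
  // Parallel context: %buffer is captured implicitly, no explicit operand is needed.
  %x = memref.load %buffer[%thread_id] : memref<?xf32>
  %cst2 = arith.constant 2.0 : f32
  %scaled = arith.mulf %x, %cst2 : f32
  memref.store %scaled, %buffer[%thread_id] : memref<?xf32>
}
//
// Sequential context.
//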
Side Effecting Form
A parallel matmul on buffers can be expressed using xxx.foreach_thread as follows:
//
// Sequential context.
//
xxx.foreach_thread (%thread_id_1, %thread_id_2) in (%num_threads_1, %num_threads_2) {
//
// Parallel context, each thread with id = (%thread_id_1, %thread_id_2) runs its version of the code.
//
// f, g, h represent pseudo-IR, the details of which chunk of matmul is read/written by a
// particular thread are irrelevant in this example.
%sA = memref.subview %A[f((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
%sB = memref.subview %B[g((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
%sC = memref.subview %C[h((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
matmul ins(%sA, %sB) outs(%sC)
}
//
// Sequential context.
//
The details of which chunk of the matmul is read/written by a particular thread are irrelevant.
This can be further lowered to materialize parallelism, e.g. with the nested async dialect, as follows:
//
// Sequential context.
//
%group_1 = async.create_group %num_threads_1: !async.group
scf.for %thread_id_1 = 0 to %num_threads_1 step %c1 {
%token_1 = async.execute {
//
// Parallel context, each thread with id = %thread_id_1 runs its version of the code.
//
%group_2 = async.create_group %num_threads_2: !async.group
scf.for %thread_id_2 = 0 to %num_threads_2 step %c1 {
%token_2 = async.execute {
//
// Parallel context, each thread with id = %thread_id_2 runs its version of the code.
//
%sA = memref.subview %A[f((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
%sB = memref.subview %B[g((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
%sC = memref.subview %C[h((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
matmul ins(%sA, %sB) outs(%sC)
async.yield
}
async.add_to_group %token_2, %group_2 : !async.token
}
async.await_all %group_2
async.yield
}
async.add_to_group %token_1, %group_1 : !async.token
}
async.await_all %group_1
//
// Sequential context.
//
Or to a CUDA-like processor-grid execution model:
//
// Sequential context.
//
///
/// Surrounding kernel launch and kernel encapsulation IR omitted.
///
... {
//
// Parallel context, each thread with id = (%thread_id_1, %thread_id_2) runs its version of the code.
//
%thread_id_1 = gpu.thread_id x
%thread_id_2 = gpu.thread_id y
%sA = memref.subview %A[f((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
%sB = memref.subview %B[g((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
%sC = memref.subview %C[h((%thread_id_1, %thread_id_2))]: memref<?x?xT> to memref<?x?xT>
matmul ins(%sA, %sB) outs(%sC)
}
//
// Sequential context.
//
The order of reads and writes to memory is unspecified across iterations.
Note in particular that this form allows representing the simple signal/wait construct, for which it was previously noted that scf.parallel is not expressive enough:
//
// Sequential context.
//
xxx.foreach_thread (%thread_id) in (%c2) {
//
// Parallel context, each thread with id = (%thread_id) runs its version of the code.
//
scf.if %thread_id == 0 {
wait();
}
scf.if %thread_id == 1 {
signal();
}
}
//
// Sequential context.
//
Pure Form
In its pure form, the op has an extensible concurrent terminator region containing explicit parallel_insert ops. These ops do not themselves create new values; rather, the terminator yields a concurrently assembled tensor. This allows maintaining a clean separation between the subset and the full tensor. The terminator specifies how the results of all parallel invocations should be reconciled into a full value that will be returned from xxx.foreach_thread.
Multi-return values are encoded by including multiple operations inside the xxx.perform_concurrently block.
//
// Sequential context.
//
%matmul_and_pointwise:2 = xxx.foreach_thread (%thread_id_1, %thread_id_2) in (%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
//
// Parallel context, each thread with id = (%thread_id_1, %thread_id_2) runs its version of the code.
//
%sA = tensor.extract_slice %A[f((%thread_id_1, %thread_id_2))]: tensor<?x?xT> to tensor<?x?xT>
%sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]: tensor<?x?xT> to tensor<?x?xT>
%sC = tensor.extract_slice %C[h((%thread_id_1, %thread_id_2))]: tensor<?x?xT> to tensor<?x?xT>
%sD = matmul ins(%sA, %sB) outs(%sC)
%spointwise = tensor.extract_slice %pointwise[i((%thread_id_1, %thread_id_2))]: tensor<?xT> to tensor<?xT>
%sE = add ins(%sD) outs(%spointwise)
xxx.perform_concurrently {
// First op within the parallel terminator contributes to producing %matmul_and_pointwise#0.
xxx.parallel_insert_slice %sD into %C[h((%thread_id_1, %thread_id_2))]: tensor<?x?xT> into tensor<?x?xT>
// Second op within the parallel terminator contributes to producing %matmul_and_pointwise#1.
xxx.parallel_insert_slice %sE into %pointwise[i((%thread_id_1, %thread_id_2))]: tensor<?xT> into tensor<?xT>
}
}
//
// Sequential context.
//
The op does not use the notion of iter_args or init; instead, the xxx.perform_concurrently terminator captures the tensor values that are “inserted into”. Similarly to the result of tensor.insert_slice, each resulting tensor is a new tensor with the same value as the into tensor, except in the places updated by the xxx.parallel_insert_slice ops.
The order of the parallel updates is unspecified.
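For comparison, here is a hedged sketch of the sequential tensor.insert_slice behavior this mirrors (shapes, offsets and names are illustrative): the op returns a new tensor value that is identical to the destination except for the updated region, and the destination SSA value itself is left unchanged.
// %updated is a new tensor: identical to %dest except for the 4x16 region at offset [%i, 0].
%updated = tensor.insert_slice %slice into %dest[%i, 0] [4, 16] [1, 1]
    : tensor<4x16xf32> into tensor<128x16xf32>
// %dest remains a valid SSA value and still denotes the original, un-updated tensor.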
From a bufferization perspective, the op is considered to be in destination-passing style, where result #n and the destination of the #n-th parallel_insert op are “tied”.
Bufferization is guaranteed to occur in place: in the example above, %matmul_and_pointwise#0 is bufferized to the same buffer as %C. If any external conflicts occur (e.g. a RAW conflict on %C that is still needed after the end of execution of the parallel op), appropriate clones of %C will be introduced by bufferization.
Internal conflicts (e.g. overlapping xxx.parallel_insert_slice ops) are legitimate races that will turn into side-effecting races after bufferization.
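As an illustration of such an external conflict (hypothetical IR, computation elided; T, h and the value names are placeholders), a read of the original %C after the parallel op prevents in-place reuse of %C's buffer without a clone:
%res = xxx.foreach_thread (%thread_id) in (%num_threads) -> (tensor<?x?xT>) {
  // ... per-thread computation producing %partial ...
  xxx.perform_concurrently {
    xxx.parallel_insert_slice %partial into %C[h((%thread_id))] : tensor<?x?xT> into tensor<?x?xT>
  }
}
// %C is still read after the parallel op; bufferizing %res in place onto %C's buffer would
// clobber this read, so bufferization introduces a clone of %C.
%v = tensor.extract %C[%i, %j] : tensor<?x?xT>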
Nested Forms
The following illustrates a 2-level, 1-D parallel version with xxx.foreach_thread:
//
// Sequential context.
//
%7 = xxx.foreach_thread %thread_id_1 in %c125 -> (tensor<250x1020xf32>) {
%8 = affine.apply affine_map<(d0) -> (d0 * 2)>(%thread_id_1)
%10 = tensor.extract_slice %6[%8, 0] [2, 1020] [1, 1] : tensor<250x1020xf32> to tensor<2x1020xf32>
%11 = xxx.foreach_thread %thread_id_2 in %c255 -> (tensor<2x1020xf32>) {
%13 = affine.apply affine_map<(d0) -> (d0 * 4)>(%thread_id_2)
%15 = tensor.extract_slice %10[0, %13] [2, 4] [1, 1] : tensor<2x1020xf32> to tensor<2x4xf32>
%16 = tensor.extract_slice %3[%8, 0] [2, 500] [1, 1] : tensor<250x500xf32> to tensor<2x500xf32>
%17 = tensor.extract_slice %4[0, %13] [500, 4] [1, 1] : tensor<500x1020xf32> to tensor<500x4xf32>
%18 = linalg.matmul ins(%16, %17 : tensor<2x500xf32>, tensor<500x4xf32>) outs(%15 : tensor<2x4xf32>) -> tensor<2x4xf32>
xxx.perform_concurrently {
xxx.parallel_insert_slice %18 into %10[0, %13] [2, 4] [1, 1] : tensor<2x4xf32> into tensor<2x1020xf32>
}
}
xxx.perform_concurrently {
xxx.parallel_insert_slice %11 into %6[%8, 0] [2, 1020] [1, 1] : tensor<2x1020xf32> into tensor<250x1020xf32>
}
}
//
// Sequential context.
//
After in-place bufferization, the slices taken at each 1-D loop level properly compose into a single 2-D slice, and the IR becomes:
//
// Sequential context.
//
xxx.foreach_thread %thread_id_1 in %c125 -> () {
%3 = affine.apply affine_map<(d0) -> (d0 * 2)>(%thread_id_1)
%5 = memref.subview %2[%3, 0] [2, 1020] [1, 1] : memref<250x1020xf32> to memref<2x1020xf32, affine_map<(d0, d1)[s0] -> (d0 * 1020 + s0 + d1)>>
xxx.foreach_thread %thread_id_2 in %c255 -> () {
%6 = affine.apply affine_map<(d0) -> (d0 * 4)>(%thread_id_2)
%8 = memref.subview %5[0, %6] [2, 4] [1, 1] : memref<2x1020xf32, affine_map<(d0, d1)[s0] -> (d0 * 1020 + s0 + d1)>> to memref<2x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1020 + s0 + d1)>>
%9 = memref.subview %0[%3, 0] [2, 500] [1, 1] : memref<250x500xf32> to memref<2x500xf32, affine_map<(d0, d1)[s0] -> (d0 * 500 + s0 + d1)>>
%10 = memref.subview %1[0, %6] [500, 4] [1, 1] : memref<500x1020xf32> to memref<500x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1020 + s0 + d1)>>
linalg.matmul ins(%9, %10 : memref<2x500xf32, affine_map<(d0, d1)[s0] -> (d0 * 500 + s0 + d1)>>, memref<500x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1020 + s0 + d1)>>) outs(%8 : memref<2x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1020 + s0 + d1)>>)
}
}
//
// Sequential context.
//
Because memref.subview operations (and, more generally, pointer arithmetic) compose, we obtain N-D slices of data and compute after bufferization.
Multiple nested 1-D xxx.foreach_thread ops may be transformed into a single flattened n-D xxx.foreach_thread.
This process may require duplicating computation when the nested 1-D form is not perfectly nested, and it is generally not a straightforward transformation.
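A hedged sketch of this flattening, assuming a perfectly nested case (bodies elided): two nested 1-D xxx.foreach_thread ops over %c125 and %c255 threads may be rewritten as a single 2-D op over the same thread space.
// Before: two perfectly nested 1-D parallel ops.
xxx.foreach_thread (%i) in (%c125) {
  xxx.foreach_thread (%j) in (%c255) {
    // body(%i, %j)
  }
}
// After: a single flattened 2-D parallel op over the 125 x 255 thread space.
xxx.foreach_thread (%i, %j) in (%c125, %c255) {
  // body(%i, %j)
}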
Observations
Some observations surfaced during the design of these abstractions and during the prior discussion:
- Abstractions related to parallelism, distribution, loops and sub-tensor-SSA values only truly materialize post bufferization. The exercise is to invent representations on tensors and sub-tensor-SSA values that allow transformations to be specified and carried past bufferization without inconsistencies or abstraction gaps.
- Extra care about whether operands are full tensors or sub-tensors is key and has historically been limited.
- A thread-first parallel iteration model is deemed important (i.e. thread ids are explicit and visible). In the current incarnation of MLIR, the movement from data-first to thread-first happens late in the pipeline (e.g. when lowering to GPU). New abstractions expose these concepts earlier and analyze/preserve parallelism the same way early vectorization preserves vector types. This will also avoid pigeonholing into existing dense and regular abstractions; dynamic sizes, load balancing and future types (e.g. lists) must not be precluded. This requires going beyond a (lower bound, upper bound, step) representation and embracing a thread-first representation (see the sketch after this list).
- In addition to enabling more parallel codegen, we expect such an explicit thread-first parallel iteration model to be a first step towards distributed tensors.
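As the sketch referenced above (illustrative only; %c0, %c4, %ub and %num_threads are placeholders), contrast a (lower bound, upper bound, step)-style loop with the thread-first form, in which each thread derives its own, possibly irregular, data subset from its id:
// (lb, ub, step)-style: the iteration space is described by bounds and a step.
scf.parallel (%iv) = (%c0) to (%ub) step (%c4) {
  // body indexed by the induction variable %iv
}
// Thread-first: the iteration space is described by a number of threads; each thread
// derives its own data subset from %thread_id, which need not be affine or regular.
xxx.foreach_thread (%thread_id) in (%num_threads) {
  // body indexed by %thread_id
}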
Question
What xxx dialect is appropriate for such a thread-first parallel construct?