Overview
Asynchronicity is exposed in many programming models and runtime environments. When driving an accelerator (like a GPU), latency is hidden by having the device run asynchronously from the host. This is achieved by exposing asynchronous APIs from the driver/framework. We aim to represent programs where task-level concurrency spans both host and device computations.
Our main current motivation and first deliverable is the integration of TensorFlow with TFRT to target GPUs, but the design and implementation of the representation of asynchronous computation in MLIR are intended to be generic and reusable beyond this particular environment.
We have started to experiment with a model for this asynchronous behavior for GPUs, using GPU-specific abstractions that we plan to introduce into MLIR core. During the design, it has become clear that the GPU model will need to compose and interoperate well with other encodings of asynchronicity at other levels of the TensorFlow stack, including across host tasks (e.g. core to TFRT).
A unified model would enable us to apply optimizations that are developed in the context of MLIR core (like the generic buffer assignment, concurrency heuristics[1] and stream assignment) to TensorFlow. It would also drive convergence in approaches for different target platforms (e.g. GPU, CPU, TPU) by providing a common framework.
The goal of this RFC is to drive a discussion on whether we want such a shared model before diving deeper into what such a model would look like.
Advantages of an explicit MLIR model for asynchronous execution
The graph below depicts a coarse compilation pipeline for TensorFlow. Compiling TensorFlow has three main stages of representation, each of which enables different optimizations. Initially, we have a mostly data-flow representation of the program (blue), where we compute on values with limited, explicit side effects. We use this representation for classical compiler optimizations. Eventually, we lower into a form that makes memory explicit, so that we can reason about memory use. Finally, when targeting TFRT or GPUs, we need to express the program in an asynchronous form. When using GPUs with TFRT, asynchronous execution can also nest.
The translations between the different forms are not independent. Optimizations in the blue stage might impact memory use and buffer assignment in the yellow stage (in particular buffer reuse and optimizing for peak memory use), and they might impact the legal choices for asynchronous execution in the green phase. In particular, buffer assignment has to make assumptions about later asynchronous execution when reasoning about buffer use or about the effect its assignment has on concurrency. Similarly, the translation to TFRT's asynchronous model has to reason about buffer aliasing when introducing asynchronicity. Both dependencies create a tight coupling between the two lowerings.
By adding an early, higher-level model of asynchronous execution, we could reduce this coupling. We would introduce asynchronous behavior at the data-flow level, thereby avoiding the need to analyze buffers. Buffer assignment, now working on an explicitly asynchronous program, could reason more easily about concurrent memory use. It could still undo earlier decisions by rewriting the explicitly asynchronous IR. For example, buffer assignment could sequentialize two asynchronous regions by adding a new dependency in order to enable buffer reuse between the two, as sketched below. Lastly, when lowering to TFRT or asynchronous devices, the program (and its effects) would already be asynchronous by construction, thereby reducing the need for analysis.
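To make this concrete, here is a minimal, hypothetical before/after sketch using the async.region and async.token notation introduced later in this document; the surrounding values (%r0, %buf1, %arg0, %arg2) are illustrative assumptions, not part of the proposal itself.

// Before: %r1 has no dependency on the earlier region %r0 and may run
// concurrently with it, so the two regions must write into distinct buffers.
%r1:2 = "async.region"() ( {
"xla_lhlo.dot"(%arg0, %arg2, %buf1) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%buf1) : (memref<32x32xf32>) -> ()
}) : () -> (!async.token, memref<32x32xf32>)

// After: buffer assignment adds %r0's completion token as an operand,
// sequentializing %r1 behind %r0. With this ordering guaranteed, it can
// consider reusing one buffer for both results, trading concurrency for
// lower peak memory use.
%r1:2 = "async.region"(%r0#0) ( {
"xla_lhlo.dot"(%arg0, %arg2, %buf1) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%buf1) : (memref<32x32xf32>) -> ()
}) : (!async.token) -> (!async.token, memref<32x32xf32>)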
To reap these benefits, we would need to make the passes that run after the introduction of the asynchronous IR aware of its existence. Primarily these would be the buffer allocation optimizations and the device compilers (like GPU and CPU code generation). We expect at least some of these to live in MLIR core or in their own repositories outside of TensorFlow. Therefore, to enable a joint model and development of the asynchronous IR, it should be located in MLIR core[2].
Context: GPU perspective on host-side dialect and async model
The GPU host-side dialect in essence consists of two parts: a way to encode the placement of buffers on devices and the transfers between host and device (memory management), and a model for asynchronous computation. In the CUDA API, most operations on the device are asynchronous and the host is frequently expected to run ahead (e.g. already enqueuing a memory copy before the kernel has completed computing). The GPU's ability to concurrently execute independent kernels (by using multiple streams in the case of CUDA) further complicates this. However, especially for smaller workloads, exploiting this extra asynchronicity is deemed crucial for performance.
Hence, modeling GPU computations and their interaction with the host requires a model for asynchronous execution. In our initial design, we have converged on a model where we use explicit dependency tokens between asynchronous operations to model their dependencies; an instance of the classical control-as-data approach commonly used in compilers.
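As a minimal, hypothetical illustration of this token-based encoding (the gpux.* op names are placeholders for the not-yet-existing asynchronous GPU ops also used in the lowering example further below, %host_in is an assumed host buffer, and the token type is the one introduced in the next section): each asynchronous op returns a token and consumes the tokens of the ops it has to wait for, while the host merely enqueues work and never blocks in between.

// Allocate a device buffer and copy the input to it; each op returns a token.
%out:2 = "gpux.alloc"() : () -> (memref<32xf32, 3>, !async.token)
%in:2 = "gpux.transfer"(%host_in) : (memref<32xf32>) -> (memref<32xf32, 3>, !async.token)
// The launch consumes both tokens, encoding a happens-after relationship.
%done = "gpux.async_launch"(%in#0, %out#0, %in#1, %out#1) : (memref<32xf32, 3>, memref<32xf32, 3>, !async.token, !async.token) -> !async.token
// Copying the result back starts only once the kernel's token is signalled.
%res:2 = "gpux.transfer"(%out#0, %done) : (memref<32xf32, 3>, !async.token) -> (memref<32xf32>, !async.token)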
Beyond GPU
The GPU proposal is only applicable to the already asynchronous operations of the GPU dialect. To be able to apply the same concepts to non-asynchronous dialects (like HLO or Linalg), we introduce a new async.region operation. The new operation itself is asynchronous, but its region gets executed synchronously. We use the async.token type to denote dependency values that are returned by asynchronous operations. These dependency tokens can also be consumed by asynchronous operations to denote a happens-after relationship. This gives the syntax
<results> = async.region(<dependencies>) <body region>
The terminator of the async.region is an async.yield. Its operands are the values returned by the surrounding async.region. Optionally, an async.yield may take dependency tokens as operands, which encodes that the corresponding async.region only completes once all of these tokens signal completion. The tokens are not passed through as results; only a single new token for the async.region is generated implicitly by the operation.
Here is an example of what the potential IR could look like. The actual computations in the example are just for illustrative purposes. In essence, we have two matrix multiplies and two additions that have been fused into a single linalg operation. The fused additions have a data dependency on the matrix multiplies, but the matrix multiplies are independent of each other.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>
module {
async.func @f(%arg0: tensor<32x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<8x32xf32>, %arg3: tensor<32xf32>) -> (tensor<32x32xf32>, !async.token) {
%0:2 = "async.region"() ( {
%3 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
"async.yield"(%3) : (tensor<32x32xf32>) -> ()
}) : () -> (!async.token, tensor<32x32xf32>)
%1:2 = "async.region"() ( {
%3 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
"async.yield"(%3) : (tensor<32x32xf32>) -> ()
}) : () -> (!async.token, tensor<32x32xf32>)
%2:2 = "async.region"(%0#0, %1#0) ( {
%3 = linalg.generic {args_in = 3 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %0#1, %1#1, %arg3 {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32): // no predecessors
%4 = addf %arg4, %arg5 : f32
%5 = addf %4, %arg6 : f32
linalg.yield %5 : f32
}: tensor<32x32xf32>, tensor<32x32xf32>, tensor<32xf32> -> tensor<32x32xf32>
"async.yield"(%3) : (tensor<32x32xf32>) -> ()
}) : (!async.token, !async.token) -> (!async.token, tensor<32x32xf32>)
"async.yield"(%2#1, %2#2) : (tensor<32x32xf32>, !async.token) -> ()
}
}
We have wrapped the computations into async.region operations. The region containing the addition consumes the two async tokens from the matrix multiplications, as it can only be executed once the other two have completed. Similarly, the function now returns the dependency token of the addition using an async.yield, as the function completes once the addition is complete. We have also changed the function to be explicitly asynchronous by using async.func.
This representation still operates on values, and the asynchronous behavior is correct as it does not violate the data-flow properties of the original program. We can lower this to buffers as follows.
module {
async.func @f(%arg0: memref<32x8xf32>, %arg1: memref<8x32xf32>, %arg2: memref<8x32xf32>, %arg3: memref<32xf32>) -> (memref<32x32xf32>, !async.token) {
%0:2 = "async.region"() ( {
%3 = alloc() : memref<32x32xf32>
"xla_lhlo.dot"(%arg0, %arg1, %3) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%3) : (memref<32x32xf32>) -> ()
}) : () -> (!async.token, memref<32x32xf32>)
%1:2 = "async.region"() ( {
%3 = alloc() : memref<32x32xf32>
"xla_lhlo.dot"(%arg0, %arg2, %3) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%3) : (memref<32x32xf32>) -> ()
}) : () -> (!async.token, memref<32x32xf32>)
%2:2 = "async.region"(%0#0, %1#0) ( {
%3 = alloc() : memref<32x32xf32>
linalg.generic {args_in = 3 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %0#1, %1#1, %arg3, %3 {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32): // no predecessors
%4 = addf %arg4, %arg5 : f32
%5 = addf %4, %arg6 : f32
linalg.yield %5 : f32
}: memref<32x32xf32>, memref<32x32xf32>, memref<32xf32>, memref<32x32xf32>
"async.yield"(%3) : (memref<32x32xf32>) -> ()
}) : (!async.token, !async.token) -> (!async.token, memref<32x32xf32>)
"async.yield"(%2#1, %2#2) : (tensor<32x32xf32>, !async.token) -> ()
}
}
The main change is that we now see allocations and memrefs. As a next step, this would be lowered to a device dialect, e.g. GPU (using gpux for a yet-to-be-created asynchronous version of the GPU dialect).
module {
async.func @f(%arg0: memref<32x8xf32>, %arg1: memref<8x32xf32>, %arg2: memref<8x32xf32>, %arg3: memref<32xf32>) -> (memref<32x32xf32>, !async.token) {
%0:2 = "async.region"() ( {
%3 = alloc() : memref<32x32xf32>
"tf.matmul"(%arg0, %arg1, %3) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%3) : (memref<32x32xf32>) -> ()
}) : () -> (!async.token, memref<32x32xf32>)
%1:2 = "async.region"() ( {
%3 = alloc() : memref<32x32xf32> // If we know this runs on GPU, allocate there to avoid transfers.
"xla_lhlo.dot"(%arg0, %arg2, %3) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%3) : (memref<32x32xf32>) -> ()
}) : () -> (!async.token, memref<32x32xf32>)
%2:2 = "async.region"(%0#0, %1#0) ( {
%3:2 = "gpux.alloc"() : () -> (memref<32x32xf32, 3>, !async.token)
%4:2 = "gpux.transfer"(%0#1) : (memref<32x32xf32>) -> (memref<32x32xf32, 3>, !async.token)
%5:2 = "gpux.transfer"(%1#1) : (memref<32x32xf32>) -> (memref<32x32xf32, 3>, !async.token)
%6:2 = "gpux.transfer"(%arg3) : (memref<32xf32>) -> (memref<32xf32, 3>, !async.token)
%7 = "gpux.async_launch"(%4#0, %5#0, %6#0, %3#0, %3#1, %4#1, %5#1, %6#1) : (memref<32x32xf32, 3>, memref<32x32xf32, 3>, memref<32xf32, 3>, memref<32x32xf32, 3>, !async.token, !async.token, !async.token, !async.token) -> !async.token
%8:2 = "gpux.transfer"(%3#0, %7) : (memref<32x32xf32, 3>, !async.token) -> (memref<32x32xf32>, !async.token)
"async.yield"(%8#0, %8#1) : (memref<32x32xf32>, !async.token) -> ()
}) : (!async.token, !async.token) -> (!async.token, memref<32x32xf32>)
"async.yield"(%2#1, %2#2) : (tensor<32x32xf32>, !async.token) -> ()
}
}
We now get nested asynchronicity, as the GPU dialect itself has asynchronous operations. Those use the same encoding (the !async.token type) as the async.region operation, which allows us to compose them. The async.region completes when the final gpux.transfer completes, encoded by the additional token input to the async.yield.
Using this explicit encoding would allow us to further optimize this code without reasoning about memory accesses and buffer aliasing. For example, we could eliminate the transfer at the end of the region if all later uses are transfers back to the GPU. We can also decide to combine async.region operations to avoid overhead, or split them to make results available earlier, as sketched below.
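As a hypothetical sketch of the merging rewrite (the op mix and shapes are illustrative), two chained regions can be fused into one, which removes the intermediate token and one host-visible completion point at the cost of coarser-grained dependencies.

// Before: the second region waits on the first region's token.
%a:2 = "async.region"() ( {
%t = alloc() : memref<32x32xf32>
"xla_lhlo.dot"(%arg0, %arg1, %t) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%t) : (memref<32x32xf32>) -> ()
}) : () -> (!async.token, memref<32x32xf32>)
%b:2 = "async.region"(%a#0) ( {
%t = alloc() : memref<32x32xf32>
"xla_lhlo.dot"(%a#1, %a#1, %t) : (memref<32x32xf32>, memref<32x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%t) : (memref<32x32xf32>) -> ()
}) : (!async.token) -> (!async.token, memref<32x32xf32>)

// After: the two bodies are concatenated; only one token remains and the
// intermediate result no longer needs to be published asynchronously.
%ab:2 = "async.region"() ( {
%t0 = alloc() : memref<32x32xf32>
"xla_lhlo.dot"(%arg0, %arg1, %t0) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
%t1 = alloc() : memref<32x32xf32>
"xla_lhlo.dot"(%t0, %t0, %t1) : (memref<32x32xf32>, memref<32x32xf32>, memref<32x32xf32>) -> ()
"async.yield"(%t1) : (memref<32x32xf32>) -> ()
}) : () -> (!async.token, memref<32x32xf32>)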
Open Questions
During offline discussions, two questions have already come up that I want to repeat here.
Should async.region capture explicitly?
This is mostly a matter of lowering convenience vs. making rewrites of the IR harder. When capturing implicitly, it is easier to move code into async regions. However, when lowering, one needs to identify the values that the region uses from above (its free uses). As the latter is available as a helper function, not capturing explicitly seems preferable.
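For illustration, here is a hypothetical sketch of the two capture styles for the same computation; the block-argument encoding of the explicit variant is an assumption, not a settled design.

// Implicit capture (as in the examples above): %arg0 and %arg1 are simply
// used from the enclosing scope; lowering has to discover them as free uses.
%r0:2 = "async.region"() ( {
%0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
"async.yield"(%0) : (tensor<32x32xf32>) -> ()
}) : () -> (!async.token, tensor<32x32xf32>)

// Explicit capture: every value used inside is passed as an operand and
// re-bound as a block argument, which makes moving code into or out of the
// region more intrusive but keeps the captured values visible to lowering.
%r1:2 = "async.region"(%arg0, %arg1) ( {
^bb0(%c0: tensor<32x8xf32>, %c1: tensor<8x32xf32>):
%0 = "xla_hlo.dot"(%c0, %c1) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
"async.yield"(%0) : (tensor<32x32xf32>) -> ()
}) : (tensor<32x8xf32>, tensor<8x32xf32>) -> (!async.token, tensor<32x32xf32>)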
Should an async.region be allowed to return multiple tokens?
In its current form, the async.region synchronizes all contained async operations before completing. However, there are use cases where one would want a region to complete partially once some of its results are available. Taken to the limit, this yields a representation where every returned or written-to memref has its own token, creating a large number of dependency values. A more natural way to represent this might be to split the async.region into pieces so that intermediate results are returned by their own region, as sketched below.
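A hypothetical sketch of this splitting: instead of one region that publishes both results behind a single token (forcing consumers of the first result to also wait for the second), each result is produced by its own region with its own token.

// Before: both dot products complete before the single token is signalled,
// so a consumer of only %r#1 also waits for the second dot.
%r:3 = "async.region"() ( {
%0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
%1 = "xla_hlo.dot"(%arg0, %arg2) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
"async.yield"(%0, %1) : (tensor<32x32xf32>, tensor<32x32xf32>) -> ()
}) : () -> (!async.token, tensor<32x32xf32>, tensor<32x32xf32>)

// After: one region per result; each result becomes available as soon as
// its own region (and therefore its own token) completes.
%p:2 = "async.region"() ( {
%0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
"async.yield"(%0) : (tensor<32x32xf32>) -> ()
}) : () -> (!async.token, tensor<32x32xf32>)
%q:2 = "async.region"() ( {
%0 = "xla_hlo.dot"(%arg0, %arg2) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
"async.yield"(%0) : (tensor<32x32xf32>) -> ()
}) : () -> (!async.token, tensor<32x32xf32>)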
Notes
1. Users of the GPU dialect have expressed interest in researching how to extract concurrency from their workloads at a high level in a way that is beneficial for performance.
2. Technically, the asynchronous IR might not be a set of actual operations but could instead consist of a couple of interfaces that optimizations can query and that the dialects using it (like TFRT, GPU, OpenMP, etc.) can implement.