[RFC] New dialect for modelling asynchronous execution at a higher level

Overview

Asynchronicity is exposed in many programming models and runtime environments. When driving an accelerator (such as a GPU), latency is hidden by having the device run asynchronously from the host. This is achieved by exposing asynchronous APIs from the driver or framework. We aim to represent programs where task-level concurrency spans both host and device computations.

Our main current motivation and first deliverable is the integration of TensorFlow with TFRT to target GPUs, but the design and implementation of the representation of asynchronous computation in MLIR are intended to be generic and reusable beyond this particular environment.

We have started to experiment with a model of this asynchronous behavior for GPUs, using GPU-specific abstractions that we plan to introduce into MLIR core. During the design, it has become clear that the GPU model will need to compose and interoperate well with other encodings of asynchronicity at other levels of the TensorFlow stack, including across host tasks (e.g. core to TFRT).

A unified model would enable us to apply optimizations that are developed in the context of MLIR core (like the generic buffer assignment, concurrency heuristics[1] and stream assignment) to TensorFlow. It would also drive convergence in approaches for different target platforms (e.g. GPU, CPU, TPU) by providing a common framework.

The goal of this RFC is to drive a discussion on whether we want such a shared model before diving deeper into what such a model would look like.

Advantages of an explicit MLIR model for asynchronous execution

The graph below depicts a coarse compilation pipeline for TensorFlow. Compiling TensorFlow involves three main stages of representation, each of which enables different optimizations. Initially, we have a mostly data-flow representation of the program (blue), where we compute on values with limited, explicit side effects. We use this representation for classical compiler optimizations. Eventually, we lower into a form that makes memory explicit (yellow), so that we can reason about memory use. Finally, when targeting TFRT or GPUs, we need to express the program in an asynchronous form (green). When using GPUs with TFRT, asynchronous execution can also be nested.

[Figure: coarse TensorFlow compilation pipeline (mlir-drawing-one)]
The translations between the different forms are not independent. Optimizations in the blue stage might impact memory use in the yellow stage, and buffer assignment (in particular buffer reuse and optimizing for peak memory use) impacts the legal choices for asynchronous execution in the green phase. In particular, buffer assignment has to make assumptions about later asynchronous execution when reasoning about buffer use or about the effect its assignment has on concurrency. Similarly, the translation to TFRT’s asynchronous model has to reason about buffer aliasing when introducing asynchronicity. Both dependencies create a tight coupling between the two lowerings.

By adding an early, higher-level model of asynchronous execution, we could reduce this coupling. We would introduce asynchronous behavior at the data-flow level, thereby avoiding the need to analyze buffers. Buffer assignment, now working on an explicitly asynchronous program, could more easily reason about concurrent memory use. It could still undo earlier decisions by rewriting the explicitly asynchronous IR. For example, buffer assignment could sequentialize two asynchronous regions by adding a new dependency in order to enable buffer reuse between the two. Lastly, when lowering to TFRT or asynchronous devices, the program (and its effects) would already be asynchronous by construction, thereby reducing the need for analysis.


To reap these benefits, we would need to make the passes that run after the asynchronous IR is introduced aware of its existence. Primarily, these would be the buffer allocation optimizations and device compilers (like GPU and CPU code generation). We expect at least some of these to live in MLIR core or in their own repositories outside of TensorFlow. Therefore, to enable a joint model and joint development of the asynchronous IR, it should be located in MLIR core[2].

Context: GPU perspective on host-side dialect and async model

The GPU host-side dialect in essence consists of two parts: a way to encode the placement of buffers on devices and the transfer between host and device (memory management), and a model for asynchronous computation. In the CUDA API, most operations on the device are asynchronous and the host is frequently expected to run ahead (e.g. already enqueuing a memory copy before the kernel has finished computing). The GPU's ability to concurrently execute independent kernels (by using multiple streams in the case of CUDA) further complicates this. However, especially for smaller workloads, exploiting this extra asynchronicity is deemed crucial for performance.

Hence, modeling GPU computations and their interaction with the host requires a model for asynchronous execution. In our initial design, we have converged on a model where we use explicit dependency tokens between asynchronous operations to model their dependencies; an instance of the classical control-as-data approach commonly used in compilers.

Beyond GPU

The GPU proposal is only applicable to operations from the GPU dialect that are already asynchronous. To apply the same concepts to non-asynchronous dialects (like HLO or LinAlg), we introduce a new async.region operation. The operation itself is asynchronous, but its region is executed synchronously. We use the async.token type to denote dependency values that are returned by asynchronous operations. These dependency tokens can also be consumed by asynchronous operations to denote a happens-after relationship. This gives the syntax

<results> = async.region(<dependencies>) <body region>

The terminator of the async.region is an async.yield. Its operands are the values returned by the surrounding async.region. Optionally, an async.yield may also take dependency tokens as operands, which encodes that the corresponding async.region only completes once all of those tokens signal completion. These tokens are not passed on as results; instead, the async.region implicitly generates a single new token.
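To make the completion rule concrete, here is a minimal host-side sketch in C++. It is not MLIR code and not part of any proposed API: tokens are modelled as std::shared_future<void>, and Token, RegionResult and asyncRegion are hypothetical names. The body runs only after all dependency tokens are ready, and the region's single token completes only after the body and any extra tokens passed to async.yield have completed.

```cpp
#include <cassert>
#include <chrono>
#include <future>
#include <utility>
#include <vector>

// Hypothetical runtime model of !async.token (illustrative only).
using Token = std::shared_future<void>;

template <typename T>
struct RegionResult {
  Token token;                  // the single token the region produces implicitly
  std::shared_future<T> value;  // the value passed to async.yield
};

// `body` runs synchronously once all `deps` have completed. It returns the
// yielded value plus any extra tokens passed to async.yield; the region only
// completes after those extra tokens have completed as well.
template <typename T, typename Body>
RegionResult<T> asyncRegion(std::vector<Token> deps, Body body) {
  std::shared_future<T> value =
      std::async(std::launch::async, [deps, body]() {
        for (const Token &dep : deps) dep.wait();  // happens-after all dependencies
        std::pair<T, std::vector<Token>> yielded = body();
        for (const Token &t : yielded.second) t.wait();  // tokens given to async.yield
        return yielded.first;
      }).share();
  // The region token becomes ready together with the yielded value.
  Token token = std::async(std::launch::async, [value]() { value.wait(); }).share();
  return {token, value};
}
```

The second half of the test models an inner asynchronous operation whose token is handed to async.yield: the region's token stays pending until that inner token is ready.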

Here is an example of what the potential IR could look like. The actual computations in the example are just for illustrative purposes. In essence, we have two matrix multiplies and two additions that got fused into a single linalg operation. The fused additions have a data dependency on the matrix multiplies, but the matrix multiplies are independent of each other.

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>


module {
  async.func @f(%arg0: tensor<32x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<8x32xf32>, %arg3: tensor<32xf32>) -> (tensor<32x32xf32>, !async.token) {
    %0:2 = "async.region"() ( {
      %3 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
      "async.yield"(%3) : (tensor<32x32xf32>) -> ()
    }) : () -> (!async.token, tensor<32x32xf32>)
    %1:2 = "async.region"() ( {
      %3 = "xla_hlo.dot"(%arg0, %arg2) : (tensor<32x8xf32>, tensor<8x32xf32>) -> tensor<32x32xf32>
      "async.yield"(%3) : (tensor<32x32xf32>) -> ()
    }) : () -> (!async.token, tensor<32x32xf32>)
    %2:2 = "async.region"(%0#0, %1#0) ( {
      %3 = linalg.generic {args_in = 3 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %0#1, %1#1, %arg3 {
      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):  // no predecessors
        %4 = addf %arg4, %arg5 : f32
        %5 = addf %4, %arg6 : f32
        linalg.yield %5 : f32
      }: tensor<32x32xf32>, tensor<32x32xf32>, tensor<32xf32> -> tensor<32x32xf32>
      "async.yield"(%3) : (tensor<32x32xf32>) -> ()
    }) : (!async.token, !async.token) -> (!async.token, tensor<32x32xf32>)
    "async.yield"(%2#1, %2#0) : (tensor<32x32xf32>, !async.token) -> ()
  }
}

We have wrapped the computations into async.region operations. The region containing the additions consumes the two async tokens from the matrix multiplications, as it can only be executed once the other two have completed. Similarly, the function now returns the dependency token of the addition region using an async.yield, as the function completes once the addition is complete. We have also made the function explicitly asynchronous by using async.func.
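The rule that a region reading a value must wait on the producer's token can be checked mechanically. The following C++ sketch uses a hypothetical toy encoding (region indices instead of SSA values; RegionInfo and tokensRespectDataFlow are illustrative names) and verifies that every cross-region value use is covered by a transitive token dependency.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Toy model (not MLIR): for each async.region, the regions whose tokens it
// consumes and the regions whose yielded values it reads.
struct RegionInfo {
  std::vector<int> tokenDeps;
  std::vector<int> valueUses;
};

// Adds every region that is transitively ordered before region `i` to `before`.
void collectPredecessors(const std::vector<RegionInfo> &regions, int i, std::set<int> &before) {
  for (int dep : regions[i].tokenDeps)
    if (before.insert(dep).second) collectPredecessors(regions, dep, before);
}

// True if each value read by a region comes from a region that is guaranteed
// to complete first, i.e. the tokens never let a use run ahead of its definition.
bool tokensRespectDataFlow(const std::vector<RegionInfo> &regions) {
  for (std::size_t i = 0; i < regions.size(); ++i) {
    std::set<int> before;
    collectPredecessors(regions, static_cast<int>(i), before);
    for (int producer : regions[i].valueUses)
      if (before.count(producer) == 0) return false;
  }
  return true;
}
```

In the test, regions 0 and 1 stand for the two dots and region 2 for the fused additions; dropping one token makes the check fail.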

The wrapped IR is still on values, and its async behavior is correct, as it does not violate the data-flow properties of the original program. We can lower this to buffers as follows.

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>

module {
  async.func @f(%arg0: memref<32x8xf32>, %arg1: memref<8x32xf32>, %arg2: memref<8x32xf32>, %arg3: memref<32xf32>) -> (memref<32x32xf32>, !async.token) {
    %0:2 = "async.region"() ( {
      %3 = alloc() : memref<32x32xf32>
      "xla_lhlo.dot"(%arg0, %arg1, %3) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
      "async.yield"(%3) : (memref<32x32xf32>) -> ()
    }) : () -> (!async.token, memref<32x32xf32>)
    %1:2 = "async.region"() ( {
      %3 = alloc() : memref<32x32xf32>
      "xla_lhlo.dot"(%arg0, %arg2, %3) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
      "async.yield"(%3) : (memref<32x32xf32>) -> ()
    }) : () -> (!async.token, memref<32x32xf32>)
    %2:2 = "async.region"(%0#0, %1#0) ( {
      %3 = alloc() : memref<32x32xf32>
      linalg.generic {args_in = 3 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %0#1, %1#1, %arg3, %3 {
      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32):  // no predecessors
        %4 = addf %arg4, %arg5 : f32
        %5 = addf %4, %arg6 : f32
        linalg.yield %5 : f32
      }: memref<32x32xf32>, memref<32x32xf32>, memref<32xf32>, memref<32x32xf32>
      "async.yield"(%3) : (memref<32x32xf32>) -> ()
    }) : (!async.token, !async.token) -> (!async.token, memref<32x32xf32>)
    "async.yield"(%2#1, %2#0) : (memref<32x32xf32>, !async.token) -> ()
  }
}

The main change is that we now see allocations and memrefs. In a next step, this would be lowered to a device dialect, e.g., GPU (using gpux for a yet-to-be-created async version of the GPU dialect).

module {
  async.func @f(%arg0: memref<32x8xf32>, %arg1: memref<8x32xf32>, %arg2: memref<8x32xf32>, %arg3: memref<32xf32>) -> (memref<32x32xf32>, !async.token) {
    %0:2 = "async.region"() ( {
      %3 = alloc() : memref<32x32xf32>
      "tf.matmul"(%arg0, %arg1, %3) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
      "async.yield"(%3) : (memref<32x32xf32>) -> ()
    }) : () -> (!async.token, memref<32x32xf32>)
    %1:2 = "async.region"() ( {
      %3 = alloc() : memref<32x32xf32> // If we know this runs on GPU, allocate there to avoid transfers.
      "xla_lhlo.dot"(%arg0, %arg2, %3) : (memref<32x8xf32>, memref<8x32xf32>, memref<32x32xf32>) -> ()
      "async.yield"(%3) : (memref<32x32xf32>) -> ()
    }) : () -> (!async.token, memref<32x32xf32>)
    %2:2 = "async.region"(%0#0, %1#0) ( {
      %3:2 = "gpux.alloc"() : () -> (memref<32x32xf32, 3>, !async.token)
      %4:2 = "gpux.transfer"(%0#1) : (memref<32x32xf32>) -> (memref<32x32xf32, 3>, !async.token)
      %5:2 = "gpux.transfer"(%1#1) : (memref<32x32xf32>) -> (memref<32x32xf32, 3>, !async.token)
      %6:2 = "gpux.transfer"(%arg3) : (memref<32xf32>) -> (memref<32xf32, 3>, !async.token)
      %7 = "gpux.async_launch"(%4#0, %5#0, %6#0, %3#0, %3#1, %4#1, %5#1, %6#1) : (memref<32x32xf32, 3>, memref<32x32xf32, 3>, memref<32xf32, 3>, memref<32x32xf32, 3>, !async.token, !async.token, !async.token, !async.token) -> !async.token
      %8:2 = "gpux.transfer"(%3#0, %7) : (memref<32x32xf32, 3>, !async.token) -> (memref<32x32xf32>, !async.token)
      "async.yield"(%8#0, %8#1) : (memref<32x32xf32>, !async.token) -> ()
    }) : (!async.token, !async.token) -> (!async.token, memref<32x32xf32>)
    "async.yield"(%2#1, %2#0) : (memref<32x32xf32>, !async.token) -> ()
  }
}

We now get nested asynchronicity, as the GPU dialect itself has asynchronous operations. Those use the same encoding (the !async.token type) as the async.region operation, which allows us to compose them. The async.region completes when the gpux.transfer completes, encoded by the additional token input to the async.yield.

Using this explicit encoding would allow us to further optimize this code without reasoning about memory accesses and buffer aliasing. For example, we could eliminate the transfer at the end of the region if all later uses are transfers back to the GPU. We could also decide to combine async.region operations to avoid overhead, or split them to make results available earlier.
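As a sketch of the first optimization, assuming a toy op list rather than real IR (ToyOp, Kind and isRedundantToHostTransfer are hypothetical names), the legality check only inspects use-def edges and op kinds, with no alias analysis:

```cpp
#include <cassert>
#include <vector>

// Toy IR (not MLIR): each op has a kind and lists the indices of the ops whose
// results it uses. Return models uses that leave the region, e.g. async.yield.
enum class Kind { Compute, ToHost, ToDevice, Return };

struct ToyOp {
  Kind kind;
  std::vector<int> operands;
};

// A device-to-host transfer is redundant if every user is a transfer straight
// back to the device: those users could read the original device buffer.
// (A transfer with no users at all is dead and trivially removable.)
bool isRedundantToHostTransfer(const std::vector<ToyOp> &ops, int idx) {
  if (ops[idx].kind != Kind::ToHost) return false;
  for (const ToyOp &op : ops)
    for (int operand : op.operands)
      if (operand == idx && op.kind != Kind::ToDevice) return false;
  return true;
}
```

The actual rewrite would then forward the device buffer to the users of the ToDevice ops; only the legality condition is shown here.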

Open Questions

During offline discussions, two questions already came up that I want to repeat here.

Should async.region capture explicitly?

This is mostly a trade-off between convenience when lowering and ease of rewriting the IR. When capturing implicitly, it is easier to move code into async regions. However, when lowering, one needs to identify the values that are used in the region but defined outside of it. As a helper function for the latter is already available, not capturing explicitly seems preferable.
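For reference, finding the implicitly captured values is a simple walk over the region. MLIR provides a utility of this kind (getUsedValuesDefinedAbove); the sketch below is a standalone toy version over SSA names, assuming a single block in which definitions precede uses (SsaOp and capturedValues are hypothetical names).

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy single-block region (not MLIR): ops in order, with SSA names for
// their results and operands.
struct SsaOp {
  std::vector<std::string> results;
  std::vector<std::string> operands;
};

// Returns the values a region uses but does not define, i.e. what an
// implicitly capturing async.region would have to pass in when lowered.
std::set<std::string> capturedValues(const std::vector<SsaOp> &region) {
  std::set<std::string> defined, captured;
  for (const SsaOp &op : region) {
    for (const std::string &v : op.operands)
      if (defined.count(v) == 0) captured.insert(v);
    for (const std::string &r : op.results) defined.insert(r);
  }
  return captured;
}
```

Applied to the first region of the tensor example (%3 = xla_hlo.dot(%arg0, %arg1) followed by async.yield(%3)), this yields %arg0 and %arg1.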

Should an async.region be allowed to return multiple tokens?

In its current form, the async.region synchronizes all contained async operations before completing. However, there are use cases where one would want a region to complete partially once some of its results are available. At the limit, this yields a representation where every returned or written-to memref has its own token, creating a lot of dependency values. A more natural way to represent this might be to split the async.region into pieces so that intermediate results are returned by their own region.

Notes


  1. Users of the GPU dialect have expressed interest in researching how to extract concurrency from their workloads at a high level in a way that is beneficial for performance.

  2. Technically, the asynchronous IR might not be a set of actual operations but rather consist of a couple of interfaces that optimizations can query and that dialects using it (like TFRT, GPU, OpenMP, etc.) can implement.

This is precisely the concurrency model I have in IREE, so I’m a fan :slight_smile:
In IREE we effectively allow async.regions to be nested, so you can model async tasks that themselves have subtasks that may execute in a different concurrency domain (the mapping is that flow.stream is the outer concurrency scope and dispatch regions are the inner concurrency scopes).

FYI, we’ve been looking at a dialect which is similar, but which would model multiple devices with different dispatch queues.

I’d be interested in hearing more about that @stephenneuendorffer! The intent with the IREE flow dialect is that it could be a target for that but I haven’t fully fleshed it out yet and have been hoping that there was a higher level dialect that could be lowered into it (whether this proposal, what you’re working on, or some hybrid). Device placement, queue placement, and concurrency are all intertwined and modeling things separately creates a bunch of nasty phase ordering and cost modeling issues.

Agreed that there are challenges handling multiple devices. My sense is that the flow dialect doesn’t have an explicit notion of devices, so it would have to be lowered into another dialect that does represent devices and their event dependencies. Our particular interest is in modeling data movement explicitly. For example, three jobs might be enqueued to three devices: the first device computes into buffer A, the second device (a DMA) moves the data from buffer A to buffer B, and the third device computes on it. The ordering dependency between these jobs needs to be represented explicitly.

Awesome!
We’re working on TFRT as our current target, but wanted something that generalizes at a layer above, and having your input on how this can map to Flow is critical to me.

What if, instead of returning tuples of value + token

... -> (tensor<32x32xf32>, !async.token)

async regions returned async values?

... -> async.value<tensor<32x32xf32>>

and async.token were an “alias” for async.value<>. Is there any fundamental reason not to do that?

Then, instead of returning multiple tokens, an async region could return multiple async values that potentially become ready at different points in time. In the TFRT world this is required, for example, for the FusedBatchNorm kernel: it has three outputs that all share a small computation at the beginning, but after that they are independent.
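A minimal sketch of that idea, with each result modelled as its own std::shared_future so that a consumer of the first result need not wait for the others. ThreeResults and regionWithThreeResults are hypothetical names, and the arithmetic is a placeholder, not the actual FusedBatchNorm computation.

```cpp
#include <cassert>
#include <future>
#include <utility>

// Hypothetical sketch: a region whose three results are separate async values
// that become ready one by one, instead of all being guarded by one token.
struct ThreeResults {
  std::shared_future<int> first, second, third;
  std::shared_future<void> worker;  // keeps the producing task joined on destruction
};

ThreeResults regionWithThreeResults(int input) {
  std::promise<int> p1, p2, p3;
  ThreeResults out{p1.get_future().share(), p2.get_future().share(),
                   p3.get_future().share(), {}};
  out.worker = std::async(std::launch::async,
                          [input, p1 = std::move(p1), p2 = std::move(p2),
                           p3 = std::move(p3)]() mutable {
                            int common = input * 2;    // shared prefix computation
                            p1.set_value(common + 1);  // consumers of `first` may start now
                            p2.set_value(common + 2);
                            p3.set_value(common + 3);
                          }).share();
  return out;
}
```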

Great! I hope we can evolve this proposal so it can target IREE as a runtime, too. After all, that is why I’d like to have this in MLIR core and not be TFRT specific.

This is modeled in this proposal by having the outer scope in async.region and the inner scope handled by the device dialect. However, there is no reason why async.region could not be nested. We also have async.func, which provides yet another level of asynchronicity. I’d expect that different runtimes will allow for different nesting levels.

I have been wondering whether device placement should be orthogonal or whether it has to be built into the async model. Abstractly, it does not matter on which device code runs when reasoning about whether it potentially runs concurrently. But I agree with @benvanik that one likely wants to take resource constraints into account when planning concurrency and doing device mapping. So these are not separate in that sense.

Still, would it be a good enough solution if device placement were, e.g., an attribute on async.region, or maybe even a concept in the lower-level dialects (e.g. the GPU dialect could have a notion of device), and later passes took this into account and rewrote the IR? Or is it preferable to create the device mapping early on?

One could also envision device mapping at the very high level, where the user splits the program into functions that run on different devices and makes it explicit that way. So, as you see, I have more questions than answers.

Maybe we should have an ODM-like meeting to talk about these issues?

What do you mean by explicit here? Will the shape checks (with the associated ThrowError etc.) have been pulled out of the ops by this time? Even so, the remaining ops (that do not implicitly do shape checks) will have conditional UB, right?

Did we consider a representation that does not return the tensor from the async.region, but makes these values available only in the async region that waits on the corresponding !async.token? Not sure how we’d do this syntactically; perhaps we could make the tensors available as implicit arguments to the region attached to the consumer async.region?

If we want to make this a purely syntactic representation without any semantic meaning then maybe we should get rid of the tokens?

An alternative is to not have async regions in a graph that operates on tensors and instead to have two steps that do the following:

  1. Have a trivial bufferization pass that does not do any buffer coalescing (so the result is maximally concurrent). This still happens in a representation without async.region ops.
  2. Follow it with a pass that inserts async regions. This pass has to be conservatively correct, but we can write it without trying to be clever about coalesced buffers (so if the input IR has coalesced buffers, this pass won’t work “well” but will still be correct).
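A toy sketch of step 2, assuming each op has already been bufferized and annotated with the buffers it reads and writes (BufOp and insertAsyncDeps are hypothetical names). Every op becomes its own region and gets a token dependency on each earlier op it may conflict with. It is conservative in exactly the sense described: coalesced buffers only add dependencies, never remove them.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Toy bufferized op (not MLIR): the buffers it reads and writes.
struct BufOp {
  std::set<std::string> reads, writes;
};

static bool intersects(const std::set<std::string> &a, const std::set<std::string> &b) {
  for (const std::string &x : a)
    if (b.count(x)) return true;
  return false;
}

// deps[i] lists the earlier ops whose token op i's region must wait on.
std::vector<std::vector<int>> insertAsyncDeps(const std::vector<BufOp> &ops) {
  std::vector<std::vector<int>> deps(ops.size());
  for (std::size_t i = 0; i < ops.size(); ++i)
    for (std::size_t j = 0; j < i; ++j)
      if (intersects(ops[j].writes, ops[i].reads) ||   // read after write
          intersects(ops[j].reads, ops[i].writes) ||   // write after read
          intersects(ops[j].writes, ops[i].writes))    // write after write
        deps[i].push_back(static_cast<int>(j));
  return deps;
}
```

With distinct buffers A and B the two producers stay independent; if they had been coalesced into A, the second producer is serialized after the first, which is correct but less concurrent.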

Just to be clear, it won’t literally do a host/device sync when running on a GPU, right?

(On the GPU, the token should be lowered to a stream id + an optional CUDA event. Async regions depending on the token can either enqueue their work on the same stream or wait on the CUDA event on another stream, as they prefer.)
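The following sketch, assuming a hypothetical lowered token that carries a stream id and an optional event id, shows the decision described above. In CUDA terms, the cross-stream case would presumably correspond to cudaEventRecord on the producing stream and cudaStreamWaitEvent on the consuming stream; here the commands are plain strings (GpuToken and waitsBeforeEnqueue are illustrative names) so the logic runs without a GPU.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical lowered form of an !async.token on a GPU: the stream the
// producing work was enqueued on, plus an event recorded after that work
// (only needed when some consumer runs on another stream).
struct GpuToken {
  int stream;
  std::optional<int> event;
};

// Lists the device-side commands needed before enqueuing dependent work on
// `stream`. Nothing here blocks the host: same-stream dependencies are
// ordered by the stream itself, cross-stream ones become event waits.
std::vector<std::string> waitsBeforeEnqueue(int stream, const std::vector<GpuToken> &deps) {
  std::vector<std::string> cmds;
  for (const GpuToken &dep : deps) {
    if (dep.stream == stream) continue;  // stream order already gives happens-after
    assert(dep.event.has_value() && "cross-stream dependency needs an event");
    cmds.push_back("wait(stream " + std::to_string(stream) + ", event " +
                   std::to_string(*dep.event) + ")");
  }
  return cmds;
}
```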

IMO this is an important use case we should support. Otherwise, forming async regions becomes a harder problem, since larger regions mean more optimization opportunities but less concurrency.

Depending on the targeted platform, that is unavoidable (e.g., hardware that can’t communicate when a part of the program has completed, or where we can only flag N events and so can only support some number of partial results). And it is not necessarily bad, as performance is the goal, not concurrency. The best-performing version may have no concurrency at all. What I like about the approach of splitting the async.region is that one then has explicit control and can incur the overhead of tracking completions where needed, rather than accidentally or always.

An async.region expresses what can execute asynchronously from other regions. Forming them seems simple, as this is a property of the program (and coming from a higher-level data-flow abstraction simplifies it further). Merging them is also simple: it adds a constraint, so one has to consider when to do it, but how to do it seems simple given the expressed execution. Splitting would require some analysis (side-effect-free regions would be easy).

How would that work if only some of the produced values are used? E.g., say we have an async.region that consumes results from multiple async.regions, but only some of their values; then it seems we would lose use-def tracking here and have to jump through extra hoops to find out whether a produced value is actually used. Being explicit seems easier. Of course, one could consider eliding it in the pretty-printed form …

That sounds like a good idea.

Nice, thanks @herhut! Having nested async regions is useful. For the example, I can immediately map the outer async region to a vkQueueSubmit of a command buffer with semaphores for those tokens, and the inner gpux.async_launch will correspond to vkCmdDispatch, with its tokens becoming pipeline barriers. I think there are details to be worked out regarding the fine semantics of those tokens, as mapping to lower-level primitives may run into limitations, but I guess that can be figured out later, given this RFC is more about getting consensus.

Thanks to everyone who joined the ODM session on async dialects yesterday (slides and recording). There were a couple of discussion points that we could not finish yesterday, and the proposal has also evolved a little compared to what was written here. I will start with this post to update the proposal and then follow up with more posts over time addressing some of the discussion points.

We have evolved the proposal from the specific operations posted originally into two operation interfaces that are used to describe the dependencies between operations. As an example snippet, I will use the mixed async + GPU form from the slides:

%2 = "async.region"(%0#1, %1#1) ( {
  ^bb0(%t: !async.token):
  %3:2 = "gpux.alloc"(%t) : (!async.token) -> (memref<1024xf32, 5>, !async.token)
  %4 = "gpux.transfer"(%0#0, %3#0, %3#1) : (..., !async.token) -> !async.token
  %5:2 = "gpux.alloc"(%t) : (!async.token) -> (memref<1024xf32, 5>, !async.token)
  %6 = "gpux.transfer"(%1#0, %5#0, %5#1) : (..., !async.token) -> !async.token
  %7:2 = "gpux.alloc"(%t) : (!async.token) -> (memref<1024xf32, 5>, !async.token)
  %8 = "gpux.launch"(%4, %6, %7#1) : (!async.token, ...) -> !async.token
  %9 = "gpux.transfer"(%7#0, %arg3, %8) : (..., !async.token) -> !async.token
  "async.return"(%9) : (!async.token) -> ()
}) : (!async.token, !async.token) -> !async.token

In the above example, all operations are asynchronous and dependencies are encoded using tokens. For an analysis to query these dependencies, we propose a new interface, AsyncOpInterface, with the following two methods:

// Query the operands that represent async dependency tokens.
OperandRange getAsyncDependencies();
// Query the result that represents the async token for other ops to depend on.
OpResult getAsyncToken();

As an example, querying the gpux.launch operation would yield %4, %6 and %7#1 for getAsyncDependencies and %8 for getAsyncToken.
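To show how an analysis might use these two methods, here is a toy C++ sketch that mirrors the interface with plain data rather than MLIR classes (ToyAsyncOp and happensBefore are hypothetical names), applied to the gpux ops of the snippet above. Tokens that no op defines, such as the block argument %t, end the search.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Plain-data stand-in for an op implementing the proposed AsyncOpInterface.
// Not MLIR: tokens are identified by their SSA names.
struct ToyAsyncOp {
  std::string name;
  std::vector<std::string> asyncDependencies;  // what getAsyncDependencies() returns
  std::string asyncToken;                       // what getAsyncToken() returns
};

// Returns true if `before` must complete before `after` may run, i.e. the
// token of `before` is reachable from `after` through dependency tokens.
bool happensBefore(const std::vector<ToyAsyncOp> &ops, const std::string &before,
                   const std::string &after) {
  std::map<std::string, const ToyAsyncOp *> byToken, byName;
  for (const ToyAsyncOp &op : ops) {
    byToken[op.asyncToken] = &op;
    byName[op.name] = &op;
  }
  std::vector<const ToyAsyncOp *> worklist{byName.at(after)};
  std::set<std::string> visited;
  while (!worklist.empty()) {
    const ToyAsyncOp *op = worklist.back();
    worklist.pop_back();
    for (const std::string &dep : op->asyncDependencies) {
      auto it = byToken.find(dep);
      // Skip tokens not produced by an op here (e.g. %t) and ones already explored.
      if (it == byToken.end() || !visited.insert(dep).second) continue;
      if (it->second->name == before) return true;
      worklist.push_back(it->second);
    }
  }
  return false;
}
```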

While this interface suffices to query dependencies between operations, it is not enough to understand how the dependencies of operations within a region relate to the region's context. In the above example, this would be the flow of dependencies from %0#1 and %1#1 to the %t block argument of the region, and the flow of dependencies from the returned %9 to the result token %2. To be able to query this, we suggest an AsyncRegionOpInterface:

// Query the block argument that represents dependencies on the
// surrounding context or predecessor regions.
BlockArgument getAsyncRegionToken(int index);
// Query the values that represent dependencies for the surrounding
// context or successor regions.
ValueRange getAsyncRegionDependencies(int index);

For the above example, getAsyncRegionToken(0) would return %t and getAsyncRegionDependencies(0) would return %9.

The interface methods take a region index to support operations with multiple regions. The control flow between regions can already be queried using the RegionBranchOpInterface, so it is not modeled here.

We make two simplifying assumptions here.

  1. Input and output dependencies are combined into single tokens. This effectively means that regions act as a synchronization point in case of multiple dependencies. We might want to relax this but it creates extra complications.
  2. Dependencies are not control flow dependent. We do not allow cases where the forwarded dependencies differ depending on which branch is taken. This manifests in the lack of modelling successors. The idea behind this is that we want to keep the ordering constraints between operations static.

We propose to evolve this in tree, implementing an initial commit containing the two interfaces and the token type. Additionally, we would like to start using this model for GPU operations.

This looks great, thanks Stephan.

I was wondering about the AsyncRegionOpInterface method names. Wouldn’t they be more consistent with the AsyncOpInterface names if they were (basically) swapped:

// returns %t
BlockArgument getAsyncRegionDependency(int index);
// returns [%9]
ValueRange getAsyncRegionTokens(int index);

I don’t know if it’s feasible to add, but would a RegionInterface be a better fit here?

Do I understand correctly that the async.region op would implement both interfaces?

Our approach has been to model different ‘execution resources’ (i.e. devices) as objects in MLIR, rather than as attributes. I’ll try to sketch out our ideas, which are not completely fleshed out and definitely still have some holes. So we have:

%event_out = equeue.launch (…) in (%event_in,%device) { some ops }

As with your dialect, events make the dependencies between different launches explicit. Our code generation idea is that a launch can be lowered into a runtime call that pushes the region onto the ‘event queue’ of the given device. The device has all of its data dependence requirements explicitly represented in events. All data values passed into the region go through the launch (i.e. the launch operation is IsolatedFromAbove). Launches can also be nested, allowing for a design hierarchy where a device may launch work running on other devices.

We also explicitly model different physical memories using a ‘buffer’ concept. Buffers are a lot like memrefs with an address space that comes from the physical memory they are allocated in. We use special operations with event dependencies to access buffers.

%buffer = equeue.alloc(%mem, shape)
%eventout, %value = equeue.read(%eventin, %buffer)
%eventout = equeue.write(%eventin, %buffer, %value)

I think it’s also possible to have another ‘buffer future’ concept, but I think this is a higher-level concept that can be lowered into explicit buffers and events as above. I don’t think we’ve really settled on whether to have explicit events or implicit events associated with a ‘buffer future’. We also have some special launch-like operations that execute on ‘DMA’ resources. These operations copy from one buffer to another buffer.

%eventout = equeue.memcpy(%eventin, %buffer1, %buffer2, %dma)

In our usage, we’ve primarily been interested in modeling the performance of these systems, so modeling different ‘execution resources’ is important: when a resource is ‘busy’, it cannot process other launches. But we can model things like:

%eventout1 = equeue.memcpy(%eventin, %buffer1, %buffersource, %dma1)
%eventout2 = equeue.memcpy(%eventin, %buffer2, %buffersource, %dma2)

Here, two different DMAs could execute concurrently, copying data from a source buffer into two work buffers. This is basically your gpux.transfer. However, if they used the same resource:

%eventout1 = equeue.memcpy(%eventin, %buffer1, %buffersource, %dma1)
%eventout2 = equeue.memcpy(%eventin, %buffer2, %buffersource, %dma1)

Then they are guaranteed not to execute concurrently, but either order is possible.
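A rough discrete-time sketch of that resource model, with hypothetical names (Launch, simulate) and arbitrary time units. Launches are processed in program order, which is just one of the orders a runtime could choose for launches sharing a resource.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical sketch of equeue-style resources: each launch runs on one
// resource, starts once its input events have fired and the resource is free,
// and then occupies the resource for `duration` time units.
struct Launch {
  std::string resource;
  std::vector<int> eventsIn;  // indices of earlier launches this one waits on
  int duration;
};

// Returns the completion time of every launch.
std::vector<int> simulate(const std::vector<Launch> &launches) {
  std::map<std::string, int> busyUntil;  // when each resource becomes free
  std::vector<int> done;
  for (const Launch &l : launches) {
    int start = busyUntil[l.resource];
    for (int e : l.eventsIn) start = std::max(start, done[e]);
    done.push_back(start + l.duration);
    busyUntil[l.resource] = done.back();
  }
  return done;
}
```

With two DMA resources, both copies finish at the same time; with a single one, the second copy waits for the first.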

I think the key question is if we have:
1: %event = async.region(%memref) {op1}
2: async.region(%event, %memref) {op2}
3: async.region(%event, %memref) {op3}

It is clear that op1 must complete before op2 or op3. In fact, the ‘enqueuing’ operation for (1) must also happen before the enqueuing operations for (2) and (3). We would like to allow op2 and op3 to occur in either order. So we have a potential ‘transformation’ that could exchange (2) and (3) and generate:

1: %event = async.region(%memref) {op1}
2: async.region(%event, %memref) {op3}
3: async.region(%event, %memref) {op2}

Is this legal? Under normal circumstances it would not be, because async.region could have side effects on %memref. I wonder whether, instead of having an interface that says the first argument is an ‘asynchronous dependency’, we could have an interface that describes the arguments with memory semantics that can nonetheless be reordered around other uses?

Steve

Can this token be in a not-ready state? Or is it an always-ready async block argument, because the region acts as a synchronization point on its input tokens?

I feel that memref in the return type is a big lie, because it’s not really a memref; it is a promise that a memref will be available at some point in the future.

Maybe model these promises with async.value? And async.region would be the operation that can unwrap async arguments and pass “real values” to attached regions.

%0 = ... : !async.value<memref>
%1 = ... : !async.value<memref>

%2 = async.region [] ( {
  %3 = "gpux.alloc"() : () -> (memref<1024xf32, 5>) 
  async.return %3 : memref<1024xf32, 5>
}) : () -> !async.value<memref<1024xf32, 5>>

%4 = async.region [%0, %2] {
  ^bb0(%v0 : memref, %v2 : memref<1024xf32, 5>):
  "gpux.transfer"(%v0, %v2): (memref, memref) -> ()
  async.return
} : (!async.value<memref>, !async.value<memref<1024xf32, 5>>) -> !async.token

In this case, no other op has to know anything about async; async tokens and values are only used by ops from the async dialect.

Op async dependencies can be queried from the closest async region parent.

Thanks for the context. Some clarifying questions and thoughts inline.

So here eventin would be the async value that specifies the dependencies (mapping this back to the async model), while dma1 and dma2 serve at least two different purposes.

One is an encoding that drives your resource or device assignment, which in this case is dynamic, if I understand correctly. So at compile time, you do not know whether they will use the same device unless you can statically prove that the two resources are the same. Does that reflect your use case?

The second purpose is modelling mutual exclusion, so the resources also act as mutexes. I had not had that in mind at all when thinking about async.

Would it be possible to decouple these different uses? Or does that not make any sense in your use case?

Side effects are interesting indeed. I have not thought about reordering much, as it was not critical for my immediate use case. Thanks for bringing it up.

You could define your async.region variant as side-effect free (which you can, even if it contains side-effecting operations itself), which would then allow it to be reordered arbitrarily as long as the token dependencies are honored. I think that would also mean that it might be removed if its results are unused. Typically, at least the async token should have a user, as otherwise nothing would depend on the op executing.

This assumes that all dependencies are modeled with tokens. If we wanted a partial model, the op would still need some side effects. Those could be on select operands, as you suggest, or we could allow an operation to filter the side effects of the ops it contains. I think an op currently cannot aggregate the memory effects of the operations it contains.

The async region does not act as a synchronization point.

The token can be in a not-ready state, but async operations that depend on it are only allowed to run once it is complete. So, for the GPU example, these operations may already enqueue their action (like enqueuing an allocation on the device) and complete even if the async block argument is not ready. However, the action on the device then has to wait for the token to become ready.
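These semantics can be sketched with standard C++ futures, as an illustration only. The host-side enqueue returns immediately with a result token, while the code standing in for the device waits for the input token first. AsyncToken, isReady and enqueueOnDevice are hypothetical names, and a std::async thread stands in for the device queue.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <future>
#include <utility>

// Hypothetical model, not an MLIR or TFRT API: a token is a
// std::shared_future<void> that becomes ready once its producer completes.
using AsyncToken = std::shared_future<void>;

// Non-blocking check of whether a token has completed.
inline bool isReady(const AsyncToken& token) {
  return token.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
}

// Host side: "enqueue" returns a result token immediately, even if
// `dependency` is not ready yet. The stand-in for the device (a separate
// thread) waits for `dependency` before running `action`.
inline AsyncToken enqueueOnDevice(AsyncToken dependency,
                                  std::function<void()> action) {
  return std::async(std::launch::async,
                    [dependency, action = std::move(action)] {
                      dependency.wait();  // device-side wait on the token
                      action();
                    })
      .share();
}
```

The host can enqueue work against a token that is not ready yet. The returned token stays not-ready until the input token completes and the action has run.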

If you want to mix asynchronous and synchronous code, then you would need a different operation for that. It would not have an inner token to depend on, though, as there is no asynchronous execution in its region.

It is a memref; you are just not allowed to use it quite yet. :)

I get your point, though, and it is related to what @mehdi_amini asked during the presentation regarding the tensor-level modelling.

This works well in a model where the operations nested within the region are synchronous. As soon as you start nesting asynchronous operations, this unwrapping is no longer meaningful, because the nested asynchronous operations would require an async.value as well.

So I would not require it from the perspective of the async dialect. However, if we made being a token a type interface, then both async.token and async.value could be token types used to model dependencies. Async-aware analyses would then query that interface.
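Here is a toy C++ rendering of the type-interface idea. It uses a plain class hierarchy in place of MLIR's actual TypeInterface machinery, and TokenLikeType, carriesDependency and the other names are hypothetical. Both async types implement the interface, and the analysis only asks about the interface, never about the concrete types.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy model (hypothetical, not MLIR's TypeInterface machinery).
struct Type {
  virtual ~Type() = default;
};

// Marker interface: values of types implementing it carry a dependency.
struct TokenLikeType {
  virtual ~TokenLikeType() = default;
};

struct AsyncTokenType : Type, TokenLikeType {};  // models !async.token
struct AsyncValueType : Type, TokenLikeType {    // models !async.value<T>
  explicit AsyncValueType(std::string element) : elementType(std::move(element)) {}
  std::string elementType;
};
struct MemRefType : Type {};  // models a plain memref, not token-like

// The analysis queries only the interface, via a cross-cast.
inline bool carriesDependency(const Type& type) {
  return dynamic_cast<const TokenLikeType*>(&type) != nullptr;
}

inline int countDependencyOperands(
    const std::vector<std::unique_ptr<Type>>& operandTypes) {
  int count = 0;
  for (const auto& type : operandTypes)
    if (carriesDependency(*type)) ++count;
  return count;
}
```

Adding a new token-like type would then require no change to the analysis.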

When designing your dialect, you could also offer both unwrapping operations (a region but no region token) and non-unwrapping ones (a region with a token).

I think of the resource assignment in the equeue as being ‘visible to the program’, and, because of complex control flow, it might be difficult to determine exactly at compile time. In contrast, if I understand correctly, an async.region can be executed on any resource at execution time, as determined by some scheduling algorithm in a runtime library.

I thought about declaring the operation as ‘SideEffectFree’, but that does not seem to describe the operation accurately and leads to other difficulties (like dead code elimination). I think it’s worth considering which transformations a runtime scheduler could make, and providing enough information for the compiler to make similar code transformations safely.

It’s not clear to me which transformations or infrastructure would make use of the getAsyncDependencies() API you propose. I’m not wild about transformations gaining special cases like:

if (auto asyncOp = dyn_cast<AsyncOpInterface>(op)) {
  auto dependencies = asyncOp.getAsyncDependencies();
  // ... special handling for the async dependencies ...
}

I’m supportive of looking more into this direction as well.

As an alternative to having the region do the unwrapping, we could have an explicit operation for it (I suspect we’d need this anyway?). The region operand list would then model only the dependencies that gate scheduling of the region as a whole:

%0 = ... : !async.value<memref>
%1 = ... : !async.value<memref>

%token, %future = async.region [] ( {
  %3 = "gpux.alloc"() : () -> (memref<1024xf32, 5>) 
  async.return %3 : memref<1024xf32, 5>
}) : () -> !async.token, !async.value<memref<1024xf32, 5>> // Note: regions always return a token

// this region does not wait on anything
%4 = async.region { 
  ... // do something...
  // here you need to wait for the value to be available
  %v0 = async.await %0 : memref
  %v3 = async.await %future : memref<1024xf32, 5>
  "gpux.transfer"(%v0, %v3) : (memref, memref<1024xf32, 5>) -> ()
  async.return
} : !async.token

// This time the whole region waits for the async value to be ready
%5 = async.region [%0, %future] {
  ... // do something...
  // There won't be any wait since the parent region already synchronizes.
  %v0 = async.await %0 : memref
  %v3 = async.await %future : memref<1024xf32, 5>
  "gpux.transfer"(%v0, %v3) : (memref, memref<1024xf32, 5>) -> ()
  async.return
} : !async.token
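The difference between these two regions can be mimicked with standard C++ futures. This is a sketch under assumptions: asyncRegion, asyncAwait and asToken are hypothetical helpers, and a thread stands in for the runtime scheduler. asyncRegion waits for its whole operand list before starting the body, while asyncAwait blocks only at the point where it is called.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <future>
#include <utility>
#include <vector>

// Hypothetical model, not an MLIR or TFRT API.
using AsyncToken = std::shared_future<void>;
template <typename T>
using AsyncValue = std::shared_future<T>;

// Non-blocking readiness check for tokens and values.
template <typename T>
bool isReady(const std::shared_future<T>& f) {
  return f.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
}

// Turns an async value into a token that completes when the value does.
// std::launch::deferred runs the lambda in whichever thread waits on it.
template <typename T>
AsyncToken asToken(const AsyncValue<T>& value) {
  return std::async(std::launch::deferred, [value] { value.wait(); }).share();
}

// async.region [deps] { body }: the body starts only after every operand
// is ready; the returned token completes when the body finishes.
inline AsyncToken asyncRegion(std::vector<AsyncToken> deps,
                              std::function<void()> body) {
  return std::async(std::launch::async,
                    [deps = std::move(deps), body = std::move(body)] {
                      for (const AsyncToken& dep : deps) dep.wait();
                      body();
                    })
      .share();
}

// async.await inside a body: blocks here, and only if not yet ready.
template <typename T>
T asyncAwait(const AsyncValue<T>& value) {
  return value.get();
}
```

A region created with asyncRegion({}, ...) can do its "do something..." work before the value exists. A region created with asyncRegion({asToken(value)}, ...) does not start its body until the value is ready, so the await inside it never blocks.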