[RFC] Add std.atomic_rmw op

Hi everyone,

I’ve recently added llvm.atomicrmw and llvm.cmpxchg, and this proposal aims to add std.atomic_rmw. It’s mostly mechanical, but there’s an open question, and adding an op to the standard dialect seems worth a review.

Thanks!

[RFC] Add std.atomic_rmw op

Background

Atomic read-modify-write blocks have useful semantics beyond the LLVM dialect. In conjunction with affine.parallel_for, for example, they can be used to represent common machine learning reduction operations such as convolution or pooling. We believe including atomic_rmw in the standard dialect will make these operations available at a more appropriate level of abstraction than llvm.atomicrmw. It also lets us represent atomic RMWs built from operations that LLVM’s atomicrmw instruction does not support, by providing a lowering into llvm.cmpxchg.

Goals

  • Add an op to represent atomic read-modify-write sequences to the Standard dialect.
  • Add lowering from std.atomic_rmw into the appropriate op in the LLVM dialect.
    • llvm.atomicrmw for trivial cases.
    • llvm.cmpxchg for complex cases.

Proposal

IR Representation

def AtomicRMWOp : Std_Op<"atomic_rmw"> {
  let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
  let regions = (region SizedRegion<1>:$body);
}

def AtomicRMWYieldOp :
    Std_Op<"atomic_rmw.yield", [HasParent<"AtomicRMWOp">, Terminator]> {
  let summary = "terminator for atomic_rmw operations";
  let arguments = (ins AnyType:$result);
}

Lowering into llvm.atomicrmw

Pattern matching can be used to determine how to lower a particular std.atomic_rmw op. Any trivial body consisting of a single op that matches one of the AtomicBinaryOp enum values will be lowered directly into llvm.atomicrmw.

For example:

func @sum(%memref : memref<10xf32>, %i : index, %val : f32) {
  atomic_rmw %iv = %memref[%i] : memref<10xf32> {
    %local = addf %iv, %val : f32
    atomic_rmw.yield %local : f32
  }
  return
}

Lowers into:

!memref_ptr = type !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
!memref_val = type !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">

llvm.func @sum(%memref: !memref_ptr, %i : !llvm.i64, %val: !llvm.float) {
  %load = llvm.load %memref : !memref_ptr
  %buf = llvm.extractvalue %load[1] : !memref_val
  %ptr = llvm.getelementptr %buf[%i] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
  %old = llvm.atomicrmw fadd %ptr, %val acq_rel : !llvm.float
  llvm.return
}

Lowering into llvm.cmpxchg

All other lowerings make use of llvm.cmpxchg. For example, to lower a floating-point max reduction:

func @max(%memref : memref<10xf32>, %i : index, %val : f32) {
  atomic_rmw %iv = %memref[%i] : memref<10xf32> {
    %cmp = cmpf "ogt", %iv, %val : f32
    %max = select %cmp, %iv, %val : f32
    atomic_rmw.yield %max : f32
  }
  return
}

Lowers into:

!memref_ptr = type !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
!memref_val = type !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">

llvm.func @max(%memref: !memref_ptr, %i : !llvm.i64, %val: !llvm.float) {
  %load = llvm.load %memref : !memref_ptr
  %buf = llvm.extractvalue %load[1] : !memref_val
  %ptr = llvm.getelementptr %buf[%i] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
  %init_loaded = llvm.load %ptr : !llvm<"float*">
  llvm.br ^loop(%init_loaded : !llvm.float)
^loop(%loaded: !llvm.float):
  %cmp = llvm.fcmp "ogt" %loaded, %val : !llvm.float
  %max = llvm.select %cmp, %loaded, %val : !llvm.i1, !llvm.float
  %pair = llvm.cmpxchg %ptr, %loaded, %max acq_rel monotonic : !llvm.float
  %new_loaded = llvm.extractvalue %pair[0] : !llvm<"{ float, i1 }">
  %success = llvm.extractvalue %pair[1] : !llvm<"{ float, i1 }">
  llvm.cond_br %success, ^end, ^loop(%new_loaded : !llvm.float)
^end:
  llvm.return
}

Using the following logic:

      +---------------------------------+
      |   <code before the AtomicRMWOp> |
      |   <compute initial %iv value>   |
      |   br loop(%iv)                  |
      +---------------------------------+
             |
  -------|   |
  |      v   v
  |   +--------------------------------+
  |   | loop(%iv):                     |
  |   |   <body contents>              |
  |   |   %pair = cmpxchg              |
  |   |   %new = %pair[0]              |
  |   |   %ok = %pair[1]               |
  |   |   cond_br %ok, end, loop(%new) |
  |   +--------------------------------+
  |          |        |
  |-----------        |
                      v
      +--------------------------------+
      | end:                           |
      |   <code after the AtomicRMWOp> |
      +--------------------------------+

Open Questions

AtomicOrdering

This proposal uses the AtomicOrdering::acq_rel value for both the trivial llvm.atomicrmw lowering and the success ordering for llvm.cmpxchg. It also uses AtomicOrdering::monotonic for the failure ordering of llvm.cmpxchg.

Are these the proper orderings for this operation? Should an AtomicOrdering be exposed via the std.atomic_rmw op? If so, which orderings should be exposed?
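
For illustration only, one option (if an ordering were to be exposed) would be an optional attribute on the op; the attribute name below is borrowed from the LLVM dialect’s ordering enum and is just a sketch, not a settled design:

def AtomicRMWOp : Std_Op<"atomic_rmw"> {
  let arguments = (ins
      AnyMemRef:$memref,
      Variadic<Index>:$indices,
      // Hypothetical: an optional ordering attribute; when absent, the
      // lowering would use the acq_rel / monotonic pair described above.
      OptionalAttr<AtomicOrderingAttr>:$ordering);
  let regions = (region SizedRegion<1>:$body);
}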

Future Work

  • Lowering of std.atomic_rmw into the upcoming OpenMP dialect.
  • Lowering of std.atomic_rmw into the GPU dialect.
  • Determine if it makes sense for std.atomic_rmw to be used within the body of a loop.parallel.

Ping :slight_smile: I haven’t seen any comments here so I’ll go ahead and start submitting patches if that’s alright with the community. I plan to use this op as a way to lower an upcoming affine.atomic_rmw op. Ideally this op will never actually get used as it should be optimized away, but it seems good to have a ‘correct’ path that we can initially lower into.

Thanks!

Could you describe in more detail what the limitations on the body of the atomic block are? We’d need to make sure that the operations can indeed be performed atomically, not only the final cmpxchg part. For example, accessing the same value as the atomic itself within its body is a clear no-go:

atomic_rmw %arg0 = %0[%1, %2] : memref<?x?xf32> {
  %3 = constant 42.0 : f32
  store %3, %0[%1, %2] : memref<?x?xf32> // Oops...
  %4 = constant 4242.0 : f32
  atomic_rmw.yield %4 : f32
}

I don’t think we can write a verifier for the general case (e.g., an unregistered dialect might have an operation that accesses the memref through some closure in which another operation preceding the atomic captured a pointer to the memref element in question), so you’ll have to declare the behavior undefined in some cases.

This proposal uses the AtomicOrdering::acq_rel value for both the trivial llvm.atomicrmw lowering and the success ordering for llvm.cmpxchg.

This sounds reasonable. We don’t seem to have other users of atomic operations atm, so I’m fine with leaving a note in the doc and implementation with the specific rationale for choosing this ordering, and making it clear that we can further revise this decision when new use cases appear. We have such cases already.

  • Lowering of std.atomic_rmw into the GPU dialect.
  • Determine if it makes sense for std.atomic_rmw to be used within the body of a loop.parallel.

+ @herhut on both of these.

I am not sure we want a separate atomic in the GPU dialect; we can reuse the standard one unless there is a compelling reason to introduce GPU-specific semantics.

What do you think the problem would be with atomics inside the body of a parallel loop? In general, MLIR wants dialects to compose as much as possible, so restricting the set of operations in which another operation can appear should not be the first choice. E.g., somebody may have their own flavor of parallel loops in an out-of-tree dialect; should they use std.atomic_rmw or roll their own?

I think you’re right about the general case, but it seems like we can at least have a simple verification that all ops within the body are NoSideEffect (or whatever is appropriate in the upcoming effects system). Basically, any memory access in the body is probably not a useful thing to allow.

Oops, I meant to say lowering into SPIR-V. I agree that the GPU dialect probably doesn’t need any new op to support this.

This point was actually my misunderstanding of the difference between loop.reduce and std.atomic_rmw. Initially I was wondering if they were redundant, or if one could subsume the other. But after discussing with @jbruestle, I think both ops are useful and have slightly different semantics. While loop.reduce can be used to reduce into a scalar (value-like), std.atomic_rmw can be used to reduce into a memref (buffer-like). Thus having std.atomic_rmw inside the body of a loop.parallel seems like a good use case.
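
For example (a rough sketch only; the loop.parallel header is abbreviated, allocations and bounds are elided, and the atomic_rmw syntax is the region-based form from the proposal), each iteration could accumulate into a shared buffer:

loop.parallel (%i) = (%lb) to (%ub) step (%one) {
  %v = load %in[%i] : memref<?xf32>
  // Accumulate into a single-element buffer shared across all iterations.
  atomic_rmw %acc_iv = %acc[%c0] : memref<1xf32> {
    %sum = addf %acc_iv, %v : f32
    atomic_rmw.yield %sum : f32
  }
}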

Makes sense to me, thanks! Could you add some text along these lines to the op documentation?

I’m drawing a distinction between “this is useless” and “this should not be allowed”. Even if there is no good use case for atomics in a parallel loop, as long as the semantics allows them to be there, we should not tie one to the other in any way. The way I am looking at it is that there may be other dialects with their own flavor of parallel loops and, if there is a semantic problem between atomics and loop parallelism, it should be made explicit in such a way that the maintainers of those ops can decide whether to allow atomics in them or not.

This is a really interesting design; I like how it provides a generalized notion of atomic_rmw, merging compare-and-exchange into a single abstraction. I could also imagine certain cases being lowered directly into load-linked/store-conditional pairs for architectures that support them.

That said, this makes atomic_rmw much more verbose for the common case. Do you need the generality of this right now?

Another approach to consider: you could introduce two ops: a more traditional atomic_rmw that takes an enum, as well as an atomic_rmw_general that works exactly like you describe. The advantage of having two of them is that you could make the latter canonicalize towards the former in the simple case (which is what most passes would want to work on and reason about), without losing the generality. The disadvantage is a bit more complexity in the design.

WDYT?

-Chris

Oops, I meant to say lowering into SPIR-V. I agree that the GPU dialect probably doesn’t need any new op to support this.

I see SPIR-V was mentioned, so a few drive-by comments. :slight_smile: Curious: do you have a specific task to achieve by going to SPIR-V, or is it just to demonstrate lowering of this op? (Either way is very welcome, for sure!)

Actually yes, we are very much interested in the SPIR-V dialect so that the PlaidML compiler can lower to GPUs. We’ve been waiting for the Vulkan runtime to mature a little more; we can help improve and maintain this over time. Currently we’re focusing on CPU targets, but we also have another team working in the background on the GPU target. We have a bunch of optimizations we’d like to port over from our older version of PlaidML. We plan to use the affine dialect to do much of the heavy lifting and then lower to SPIR-V at some point. The earliest versions of our pre-MLIR compiler did a reasonable job of getting to cuDNN performance (+/- some cases), even though we were using OpenCL to reach the HW (i.e. it wasn’t using CUDA, although we did have a CUDA backend at one point). We’re very interested in not having to go through a shader language and instead going directly to SPIR-V.

So to answer your question, the overall task is to be able to run full networks from our EDSL → Tile dialect → Affine dialect → SPIR-V. We have a working Keras backend that is written using the EDSL. We have end-to-end functionality today with MLIR on just the CPU backend (no optimizations, just purely lowering to the LLVM dialect), and we’re currently working on porting our passes over to MLIR.

We probably don’t need the general case yet; our old compiler had a closed set where we supported only the following operations (across different data types):

  • add
  • max
  • min
  • mul

I do see value in using the concise/specific version. Now I’m imagining an enum that would specify the supported operations for the std.atomic_rmw. What should this enum be called? I was thinking ReductionKind, but maybe it should be more like AtomicRMWInnerOp or some such. I do somewhat prefer a more generic enum name so that it can be used in different dialects (I’d like to introduce an affine.atomic_rmw that would use the same enum).

I can think of two reasons to prefer a single op for now instead of doing two, beyond the one you mentioned of increased complexity:

  1. The generic implementation might miss some important use case and would need to be extended anyway when that use case comes along (this is true even now of the current proposal).
  2. If it doesn’t get used, there might be some bit rot; I usually try to avoid adding unneeded code to reduce this liability.

I think I’m starting to lean towards the specific closed set version only, since that will satisfy my immediate needs and does reduce some of the complexity of the LLVM lowering. I guess part of the reasoning for starting with the generic one was the intuition that the standard dialect has a wider audience and so generic felt more appropriate.

I’m not sure what the best name is; perhaps something like AtomicRMWReductionKind or AtomicRMWOp::ReductionKind? Keep in mind that we’d also like to have somewhat looser guarantees for some kinds of non-reassociative reductions like +.

I’m +1 on progressive and demand-driven design. Demand-driven design means that you start with a constrained thing and generalize when there is a concrete need to. Waiting until the need arises allows you to build on and learn from the experience with the simpler system, and to weigh the complexity of more advanced solutions based on that experience. It is much easier to “make a simple system more complex” than to make a “complex solution simpler” once clients and constraints are built upon it.

-Chris

Could we put it somewhere in Rationale.md as a general design principle? :slight_smile:

There are tradeoffs between YAGNI and “lack of design” / “not looking beyond the immediate problem”; I’d be careful about stating this as an absolute principle in a rationale document without a way to make it less absolute. I think we had this discussion internally at some point, between finding a “North Star” and “working towards milestones”, and this reminds me a bit of that.

What if we called it an AtomicRMWKind? For example:

def ATOMIC_RMW_KIND_ADD    : I64EnumAttrCase<"add", 0>;
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 1>;
def ATOMIC_RMW_KIND_MAX    : I64EnumAttrCase<"max", 2>;
def ATOMIC_RMW_KIND_MIN    : I64EnumAttrCase<"min", 3>;
def ATOMIC_RMW_KIND_MUL    : I64EnumAttrCase<"mul", 4>;

def AtomicRMWKindAttr : I64EnumAttr<
    "AtomicRMWKind", "",
    [ATOMIC_RMW_KIND_ADD, ATOMIC_RMW_KIND_ASSIGN, ATOMIC_RMW_KIND_MAX,
     ATOMIC_RMW_KIND_MIN, ATOMIC_RMW_KIND_MUL]> {
  let cppNamespace = "::mlir";
}

And then the AtomicRMWOp would be defined as:

def AtomicRMWOp : Std_Op<"atomic_rmw"> {
  let summary = "atomic read-modify-write operation";
  let arguments = (ins
      AtomicRMWKindAttr:$kind,
      AnyMemRef:$memref,
      Variadic<Index>:$indices);
  let results = (outs AnyType:$res);
}
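
In the textual IR this might look something like the following (just a sketch; the assembly format is not settled), where the result would presumably be the value read from the memref before the update, mirroring llvm.atomicrmw:

%old = atomic_rmw "max" %val, %memref[%i] : (f32, memref<10xf32>) -> f32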

Other naming ideas for the enum:

  • SemiAssociativeOp
  • AtomicKind

The following patch should probably be re-reviewed due to changes based on feedback here: https://reviews.llvm.org/D74401

It looks like we have an example of a lowering where the originally proposed generic version of AtomicRMWOp would be useful. TF SelectAndScatterOp in the LHLO dialect (link) has a scatter region, which after lowering should be applied to read-modify-write a memref element. It may happen that several threads are reading/writing the same index.

Are there any plans to generalize AtomicRMWOp?

Do you have some example pseudo-code for the body of your example use case? I’m wondering if it can be structured to use one of the defined enums.

This is the scatter region of SelectAndScatterOp:

  ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
      "xla_lhlo.add"(%lhs, %rhs, %out) :
          (memref<f32>, memref<f32>, memref<f32>) -> ()
      "xla_lhlo.terminator"() : () -> ()
  })

It gets transformed into:

loop.parallel
  %iv_1, ..., %iv_N = <magic_code>
  %old_value = load %out[%iv_1, ..., %iv_N]
  ...
  %old_value_buf = alloc() : memref<f32>
  store %old_value, %old_value_buf[] : memref<f32>
  "xla_lhlo.add"(%element_to_add, %old_value_buf, %scatter_buf)
      : (memref<f32>, memref<f32>, memref<f32>) -> ()
  %scatter_result = load %scatter_buf[] : memref<f32>
  store %scatter_result, %out[%iv_1, ..., %iv_N]

Note that the scatter region could already be lowered to the Standard dialect and contain std.load, std.addf, and std.store ops inside.

This particular code, of course, can be massaged into AtomicRMWKind::addf. However, it seems to me that if we had an operation that would allow users to specify the region, then the lowering to llvm.cmpxchg, which is already implemented for AtomicRMWKind::min/max, would work out of the box. If the users want to lower to llvm.atomicrmw, they would provide the necessary passes.

I’m thinking we could extend the AtomicRMWKindAttr to include an ATOMIC_RMW_KIND_GENERIC case. We could add a region to the op that is only valid when this enum kind is selected.

We are actively using the enum/closed design right now, but adding a generic case/region would allow for both styles in a compatible way. I think lowering the generic into cmpxchg should be pretty easy.
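
As an illustration, the scatter body from the earlier example might then be written along these lines for a 2-D output (the "generic" kind, its syntax, and the use of a block argument for the currently loaded value are all sketches rather than a settled design):

%new = atomic_rmw "generic" %out[%iv_1, %iv_2] : (memref<?x?xf32>) -> f32 {
^bb0(%current : f32):
  // This body corresponds to the xla_lhlo.add scatter region.
  %sum = addf %current, %element_to_add : f32
  atomic_rmw.yield %sum : f32
}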