[mlir][RFC] Concurrent updates with `scf.foreach_thread`

ForeachThreadOp is the only operation in MLIR Core that models parallel
execution on tensors. It has a perform_concurrently terminator that
specifies how to combine the partial results of all parallel invocations
into a full value, in some unspecified order.

%result = scf.foreach_thread (%threadId) in (%numThreads)
    shared_outs(%out_ = %out) -> (tensor<64xf32>) {
  %offset = affine.apply ...
  %size = affine.min ...

  %inTile = tensor.extract_slice %in[%offset][%size][1]
    : tensor<64xf32> to tensor<?xf32>
  %outTile = tensor.extract_slice %out_[%offset][%size][1]
    : tensor<64xf32> to tensor<?xf32>

  %resultTile = linalg.elemwise_unary 
    ins(%inTile : tensor<?xf32>)
    outs(%outTile : tensor<?xf32>) -> tensor<?xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %resultTile into %out_[%offset][%size][1]
      : tensor<?xf32> into tensor<64xf32>
  }
}

Notice that tensor.parallel_insert_slice looks like tensor.insert_slice, but it
has no results; the resulting tensor is constructed implicitly and returned by
the enclosing scf.foreach_thread.
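
For comparison, the sequential tensor.insert_slice does produce a new tensor
value. A minimal sketch of the same update written with it, outside of any
parallel context:

// Unlike the parallel variant above, this op returns the updated tensor.
%updated = tensor.insert_slice %resultTile into %out[%offset][%size][1]
  : tensor<?xf32> into tensor<64xf32>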

Reductions

The abstraction above can model distribution with respect to parallel/batch
dimensions of ops. The problem arises when we want to tile reductions and
similar ops that require accumulation of the partial results.

%init = linalg.init_tensor [128] : tensor<128xf32>
%out = linalg.fill ins(%cst: f32)
                   outs(%init: tensor<128xf32>) -> tensor<128xf32>
%sum = linalg.reduce
  ins(%in: tensor<128x512xf32>) outs(%out: tensor<128xf32>)
  dimensions = [1]
  (%in_elem: f32, %out_elem: f32) {
    %sum_elem = arith.addf %in_elem, %out_elem : f32
    linalg.yield %sum_elem : f32
  }

linalg.reduce has a region that specifies the combiner operation(s).
After tiling, the body of the resulting loop should contain two important parts.
The first part is the original linalg.reduce op that operates on the tiles of the
original operands. The second part should specify how to combine overlapping
partial results.

By adding an optional region to tensor.parallel_insert_slice, we can model
concurrent read-modify-write updates.

%neutralElement = arith.constant 0.0 : f32

%sum = scf.foreach_thread (%threadId) in (%numThreads)
    shared_outs(%out_ = %out) -> (tensor<128xf32>) {
  // Offset computations based on %threadId.
  %i = ...
  %j = ...

  // Initialize tmp out.
  %initTmp = linalg.init_tensor [8] : tensor<8xf32>
  %outTmp = linalg.fill ins(%neutralElement: f32)
                   outs(%initTmp: tensor<8xf32>) -> tensor<8xf32>

  // Materialize the input slice.
  %inTile = tensor.extract_slice %in[%i, %j] [8, 16] [1, 1]
    : tensor<128x512xf32> to tensor<8x16xf32>

  // Compute the reduction on the tile.
  %sumTile = linalg.reduce
    ins(%inTile: tensor<8x16xf32>) outs(%outTmp: tensor<8xf32>)
    dimensions = [1]
    (%in_elem: f32, %out_elem: f32) {
      %sum_elem = arith.addf %in_elem, %out_elem : f32
      linalg.yield %sum_elem : f32
    }
  
  scf.foreach_thread.perform_concurrently {
    // Combine the reduced tile with the values already in %out_.
    tensor.parallel_insert_slice %sumTile into %out_[%i] [8] [1]
      acc (%new: tensor<8xf32>, %current: tensor<8xf32>) {
        %combined = linalg.generic {
            indexing_maps = [#id_1d, #id_1d],
            iterator_types = ["parallel"]}
            ins(%new: tensor<8xf32>)
            outs(%current : tensor<8xf32>) {
        ^bb(%new_elem: f32, %current_elem: f32) :
          %s = arith.addf %new_elem, %current_elem : f32
          linalg.yield %s : f32
        } -> tensor<8xf32>
        linalg.yield %combined : tensor<8xf32>
      } : tensor<8xf32> into tensor<128xf32>
  }
}

In this example, the accumulator region contains a linalg.generic op that
combines the two block arguments: %new corresponds to the partial result
computed in the current thread, and %current corresponds to the values of the
current slice stored in the shared output.

ParallelInsertSliceOp changes

We can add an optional region to tensor.parallel_insert_slice OpDef.

let regions = (region VariadicRegion<SizedRegion<1>>:$accumulators);

An alternative is to make this region required and simply not print it when it is trivial. Then, however, there is no clear way to distinguish the case where there are no overlapping writes from the case where there are, but we intend concurrent overwrites; in the latter case, we would need to use atomic writes to avoid tears.
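
For reference, the "trivial" accumulator under that alternative would simply yield the new value, i.e. a plain overwrite (a sketch using the proposed syntax):

tensor.parallel_insert_slice %resultTile into %out_[%offset][%size][1]
  // Trivial accumulator: take the new value, ignore the current one.
  acc (%new: tensor<?xf32>, %current: tensor<?xf32>) {
    linalg.yield %new : tensor<?xf32>
  } : tensor<?xf32> into tensor<64xf32>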

What to lower it to?

It can be lowered to memref.atomic_rmw or to vector atomic ops directly, or to
an intermediate op in the MemRef dialect:

  memref.rmw(%new: memref<8xf32>, %current: memref<8xf32>) {
  ^bb0(%new_: memref<8xf32>, %current_: memref<8xf32>):
    linalg.generic {
        indexing_maps = [#id_1d, #id_1d],
        iterator_types = ["parallel"]}
        ins(%new_: memref<8xf32>)
        outs(%current_ : memref<8xf32>) {
    ^bb(%new_elem: f32, %current_elem: f32) :
      %s = arith.addf %new_elem, %current_elem : f32
      linalg.yield %s : f32
    }
    memref.yield
  }
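
For a single scalar element, the existing memref.generic_atomic_rmw already expresses this kind of read-modify-write. A minimal sketch, assuming a partial result %partial is folded into element %i of the shared output %out:

// Atomically combine the thread's partial result with the stored value.
%res = memref.generic_atomic_rmw %out[%i] : memref<128xf32> {
^bb0(%current : f32):
  %updated = arith.addf %current, %partial : f32
  memref.atomic_yield %updated : f32
}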

TilingInterface

Tiling of reductions would require changes to the TilingInterface.
It should be possible to query the op about whether any accumulation is needed,
and to let the op itself populate the accumulator region.

Questions

Do we have any atomic vector operations in MLIR?

Do we need the same changes for the upcoming tensor.parallel_scatter?

@nicolasvasilache, @MaheshRavishankar, @apaszke

Thanks much for this proposal @pifon2a; this has been missing for a while, and I am supportive of getting parallel reductions modeled on tensors!

It is likely that different flavors of “subset insertion”-like ops need such an accumulator region to model more general reductions.

Not that I know of, but there is memref::GenericAtomicRMWOp and that should work with (some) vectors.

Do you have an e2e working prototype for this by any chance?

In general, atomic operations are implemented at the hardware level, involve cache-line-level considerations (invalidation, flushing, consistency, etc.), and likely won’t scale across multiple operations / multiple cache lines (see e.g. the LLVM Language Reference, which seems to reflect this underlying HW constraint). The MLIR memref::GenericAtomicRMWOp takes a memref and indices that resolve to the address of a single element.

It seems to me the lowering to memref may call for either a more general critical section or a direct lowering to loop form over elemental types (at which point vectors may need to be broken into pieces too).
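
For the loop-form option, one possible shape of such a lowering (a sketch, assuming the thread's partial tile has already been materialized into a local buffer %newTile and %offset is the offset of the slice in the shared 128-element output %out):

// %c0, %c1, %c8 are index constants defined elsewhere.
scf.for %k = %c0 to %c8 step %c1 {
  %v = memref.load %newTile[%k] : memref<8xf32>
  %idx = arith.addi %offset, %k : index
  // One atomic read-modify-write per scalar element of the tile.
  %res = memref.generic_atomic_rmw %out[%idx] : memref<128xf32> {
  ^bb0(%current : f32):
    %s = arith.addf %current, %v : f32
    memref.atomic_yield %s : f32
  }
}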

What is your take on this point?

Something seems off in the example: the linalg.reduce, or any linalg.generic, already carries a “read and modify” semantic. But the suggested changes to tensor.parallel_insert_slice now also carry a “read and modify” component, namely the region with the addf. So it appears that the initial values of %out, which is passed to the shared_outs parameter, will be counted twice. If you had filled it with X instead of 0, the result would be off by X relative to the original IR before the transformation.

I think currently you only want to extract from _outs if you want to bufferize in-place. But what you’re going after isn’t an in-place bufferization, because if the initial accumulator elements are not zero, then you want to accumulate into a new zero-initialized register/buffer in the tile computation region.

Here is the current equivalent IR, without modifications to the parallel insert slice op:

func.func @test(%arg0: tensor<128x128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> {
  %c16 = arith.constant 16 : index
  %r = scf.foreach_thread (%i, %j) in (%c16, %c16) shared_outs(%outArg = %arg1) -> tensor<128xf32> {
    %offset0 = affine.apply affine_map<(d0)->(d0 * 8)>(%i)
    %offset1 = affine.apply affine_map<(d0)->(d0 * 8)>(%j)
    %tile0 = tensor.extract_slice %arg0[%offset0, %offset1][8, 8][1, 1] : tensor<128x128xf32> to tensor<8x8xf32>    
    %tile2 = tensor.extract_slice %outArg[%offset0][8][1] : tensor<128xf32> to tensor<8xf32>
    %tile = linalg.generic {
      iterator_types = ["parallel", "reduction"],
      indexing_maps = [
        affine_map<(d0, d1)->(d0, d1)>,
        affine_map<(d0, d1)->(d0)>
      ]
    } ins(%tile0 : tensor<8x8xf32>) outs(%tile2 : tensor<8xf32>) {
    ^bb0(%a : f32, %b: f32):
      %add = arith.addf %a, %b : f32
      linalg.yield %add : f32
    } -> tensor<8xf32>

    scf.foreach_thread.perform_concurrently {
      tensor.parallel_insert_slice %tile into %outArg[%offset0][8][1] : tensor<8xf32> into tensor<128xf32>
    }
  }
  return %r : tensor<128xf32>
}

And after `mlir-opt %s --one-shot-bufferize="bufferize-function-boundaries allow-return-allocs"`:

#map0 = affine_map<(d0) -> (d0 * 8)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0)>
module {
  func.func @test(%arg0: memref<128x128xf32, strided<[?, ?], offset: ?>>, %arg1: memref<128xf32, strided<[?], offset: ?>>) -> memref<128xf32, strided<[?], offset: ?>> {
    %c16 = arith.constant 16 : index
    scf.foreach_thread (%arg2, %arg3) in (%c16, %c16) {
      %0 = affine.apply #map0(%arg2)
      %1 = affine.apply #map0(%arg3)
      %2 = memref.subview %arg0[%0, %1] [8, 8] [1, 1] : memref<128x128xf32, strided<[?, ?], offset: ?>> to memref<8x8xf32, strided<[?, ?], offset: ?>>
      %3 = memref.subview %arg1[%0] [8] [1] : memref<128xf32, strided<[?], offset: ?>> to memref<8xf32, strided<[?], offset: ?>>
      linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%2 : memref<8x8xf32, strided<[?, ?], offset: ?>>) outs(%3 : memref<8xf32, strided<[?], offset: ?>>) {
      ^bb0(%arg4: f32, %arg5: f32):
        %4 = arith.addf %arg4, %arg5 : f32
        linalg.yield %4 : f32
      }
      memref.copy %3, %3 : memref<8xf32, strided<[?], offset: ?>> to memref<8xf32, strided<[?], offset: ?>>
    } {thread_dim_mapping = []}
    return %arg1 : memref<128xf32, strided<[?], offset: ?>>
  }
}

If I instead replace the extraction from %outArg with bufferization.alloc_tensor inside the region, it has the wrong semantics (without your modification), but it produces:

#map0 = affine_map<(d0) -> (d0 * 8)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0)>
module {
  func.func @test(%arg0: memref<128x128xf32, strided<[?, ?], offset: ?>>, %arg1: memref<128xf32, strided<[?], offset: ?>>) -> memref<128xf32, strided<[?], offset: ?>> {
    %c16 = arith.constant 16 : index
    scf.foreach_thread (%arg2, %arg3) in (%c16, %c16) {
      %0 = affine.apply #map0(%arg2)
      %1 = affine.apply #map0(%arg3)
      %2 = memref.subview %arg0[%0, %1] [8, 8] [1, 1] : memref<128x128xf32, strided<[?, ?], offset: ?>> to memref<8x8xf32, strided<[?, ?], offset: ?>>
      %3 = memref.alloc() {alignment = 128 : i64} : memref<8xf32>
      linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%2 : memref<8x8xf32, strided<[?, ?], offset: ?>>) outs(%3 : memref<8xf32>) {
      ^bb0(%arg4: f32, %arg5: f32):
        %5 = arith.addf %arg4, %arg5 : f32
        linalg.yield %5 : f32
      }
      memref.dealloc %3 : memref<8xf32>
      %4 = memref.subview %arg1[%1] [8] [1] : memref<128xf32, strided<[?], offset: ?>> to memref<8xf32, strided<[?], offset: ?>>
      memref.copy %3, %4 : memref<8xf32> to memref<8xf32, strided<[?], offset: ?>>
    } {thread_dim_mapping = []}
    return %arg1 : memref<128xf32, strided<[?], offset: ?>>
  }
}

It definitely seems clear that replacing the memref.copy with your memref.rmw would achieve the right result. It could be hacked together already by rewriting the memref.copy with the atomics at this point.

@cbate right, I expect the bufferization will have to evolve a bit to support one of:

  1. direct concurrent updates in the original buffer
  2. local updates to a local buffer then concurrent updates to the original buffer
  3. local updates to a local buffer then synchronized (critical section to original buffer or point-to-point / tree and let the last thread update the original buffer, a-la-warp shuffle)

Which is also why I am curious to see an e2e prototype.
I don’t have a particular bias at this point but would love to see these various schemes appear as we progressively lower and bufferize from higher-level abstractions.

Agreed, it would only work if the output of the reduction is initialized with the neutral element, and probably if %out were used instead of %out_, because I want a copy of the original value that has not been modified yet. I updated the example IR; now there is a small init_tensor for the “reduction of the tile”.

It could be hacked together already by rewriting the memref.copy with the atomics at this point.

Yes; on the other hand, we need to understand what kind of combiner to replace memref.copy with. I think it’s easier to have the accumulator region, because it will be constructed during tiling, when information about the combiner is available.

Not that I know of, but there is memref::GenericAtomicRMWOp and that should work with (some) vectors.

I looked at the definition of memref.generic_atomic_rmw. It might work with vectors, but only if we have something like memref<?xvector<4xf32>> as the argument type.

Not yet. The parallel reductions were pushed further and further away in XLA because of other priorities. I think the simplest prototype would be a tiled 1D->0D reduction with tile size 1. In that case, we can emit memref.generic_atomic_rmw during bufferization. Later, memref.rmw could be introduced.
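
The shape of that prototype, roughly (a sketch; the names are hypothetical): after bufferization, each thread loads one element of the input and folds it into the 0-d output with an atomic read-modify-write.

%e = memref.load %in[%threadId] : memref<128xf32>
%res = memref.generic_atomic_rmw %out[] : memref<f32> {
^bb0(%acc : f32):
  %s = arith.addf %acc, %e : f32
  memref.atomic_yield %s : f32
}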

OpenMP is a parallel programming model with similar concepts. It supports thread parallelism and reductions, but it doesn’t specify how reductions are implemented; LLVM uses atomic operations or locks. In the middle of the MLIR pipeline, I would worry more about commutativity than about implementation details. Different devices will implement reductions differently.

Exactly. The accumulator itself just specifies how the results are combined. But the lowering will be different for different devices. We just need to connect it e2e at least for CPU.

We also model OpenMP reductions in the corresponding dialect; it may be worth taking a look (a rough sketch follows the list below).

Specifically, there are two components that are not present here:

  • a way to specify the neutral element of a reduction;
  • separate regions defining the internals of the atomic and non-atomic reduction combinators.
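
For reference, a rough sketch of how the OpenMP dialect declares a reduction, with an explicit neutral element and separate non-atomic and atomic combinator regions (the exact upstream syntax may differ slightly):

omp.reduction.declare @add_f32 : f32
init {
^bb0(%arg: f32):
  // The neutral element of the reduction.
  %zero = arith.constant 0.0 : f32
  omp.yield (%zero : f32)
}
combiner {
^bb0(%lhs: f32, %rhs: f32):
  // Non-atomic combinator.
  %sum = arith.addf %lhs, %rhs : f32
  omp.yield (%sum : f32)
}
atomic {
^bb0(%dst: !llvm.ptr<f32>, %src: !llvm.ptr<f32>):
  // Atomic combinator operating directly on memory.
  %val = llvm.load %src : !llvm.ptr<f32>
  %old = llvm.atomicrmw fadd %dst, %val monotonic : f32
  omp.yield
}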

The neutral element feels important. Without it, it is difficult to have more than one accumulator, which limits the freedom for different reduction implementations. We could argue that each part always starts with two elements and has extra logic when there are fewer than two; this is a trade-off between putting complexity in the lowering vs. in the abstraction. There is also a challenging interplay with atomics, because atomics will need memory and will have to initialize it, so we are basically restricted to using atomics with the “out” memory only.

The separation between reductions that have an atomic implementation and the ones that don’t, and therefore require an explicit critical section, may help with the discussion on vectors above. But it can also come later.

The neutral element does feel important. I was also wondering why we are not using a slice of the input to initialize the output. For example, if we need to reduce 4x8 → 8, we can always have a linalg op that takes a 3x8 tensor and reduces it onto an output initialized with the first row of the matrix.
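
A sketch of that idea (the names are hypothetical): peel off the first row as the initial accumulator and reduce the remaining rows onto it.

%acc0 = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
  : tensor<4x8xf32> to tensor<8xf32>
%rest = tensor.extract_slice %in[1, 0] [3, 8] [1, 1]
  : tensor<4x8xf32> to tensor<3x8xf32>
%sum = linalg.reduce
  ins(%rest : tensor<3x8xf32>)
  outs(%acc0 : tensor<8xf32>)
  dimensions = [0]
  (%e : f32, %acc : f32) {
    %s = arith.addf %e, %acc : f32
    linalg.yield %s : f32
  }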

For a long time we didn’t have an empty-tensor value, so the output tensors required by the destination-passing style had to be initialized with some value. Now we do.

At the same time, initializing the output with an actual value removes some options for implementing the reduction. It basically means there is one accumulator and we always reduce into it. So a classical parallel tree-reduction implementation becomes contrary to the semantics of the higher-level operation, which is arguably not what we want at this level. We have a choice between ignoring the value of the output tensor (thus “overwriting” it with the reduction result) and including it in the reduction as another element. In the latter case, we shouldn’t prescribe how exactly that element is handled by the reduction, any more than we prescribe how the “regular” elements are handled.