New Linalg Code Generation Strategy for Innermost Reductions

Hi all,

At IREE, we are working on improving the code generated for innermost (fastest-varying dimension) reductions. We would like to get feedback and give this work some visibility so that other Linalg users can improve reductions in the same way, if they haven’t done something similar already. Our proposal aligns with what LLVM and other production compilers generate for this kind of reduction and should be more performant than the existing alternatives in Linalg today.

Running Example

We will use the following i8-to-i32 reduction example throughout this proposal for illustration purposes, where the innermost tensor dimension is reduced:

#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1) -> (d0)>

util.global private @__A {noinline} = dense<1> : tensor<384x512xi8>

func.func @main_dispatch_5() -> tensor<384xi32> {
  %c0_i32 = arith.constant 0 : i32
  %0 = util.global.load @__A : tensor<384x512xi8>
  %1 = linalg.init_tensor [384] : tensor<384xi32>
  %2 = linalg.fill ins(%c0_i32 : i32) outs(%1 : tensor<384xi32>) -> tensor<384xi32>
  %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<384x512xi8>) outs(%2 : tensor<384xi32>) {
  ^bb0(%arg2: i8, %arg3: i32):
    %4 = arith.extsi %arg2 : i8 to i32
    %5 = arith.addi %4, %arg3 : i32
    linalg.yield %5 : i32
  } -> tensor<384xi32>
  return %3 : tensor<384xi32>
}

Background

We currently have two strategies to code-generate reductions on the innermost tensor dimension in Linalg: InnerParallel and InnerReduction.

InnerParallel

The InnerParallel strategy interchanges the reduction dimension with an outer parallel dimension of the tensor so that a parallel dimension becomes innermost. Vectorization is then applied to that innermost parallel dimension such that each vector lane computes a full independent reduction. No horizontal vector reductions are required by this approach. The following snippet shows the resulting IR after applying the InnerParallel strategy to the running example:

 scf.for %arg0 = %c0 to %c32 step %c8 {
    %5 = scf.for %arg1 = %c0 to %c512 step %c64 iter_args(%arg2 = %cst) -> (vector<8xi32>) {
      %6 = vector.transfer_read %4[%arg0, %arg1], %c0_i8 {in_bounds = [true, true]} : memref<32x512xi8, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>, vector<8x64xi8>
      %7 = arith.extsi %6 : vector<8x64xi8> to vector<8x64xi32>
      %8 = vector.transpose %7, [1, 0] : vector<8x64xi32> to vector<64x8xi32>
      %9 = vector.extract %8[0] : vector<64x8xi32>
      %10 = arith.addi %9, %arg2 : vector<8xi32>
      ...
      %135 = vector.extract %8[63] : vector<64x8xi32>
      %136 = arith.addi %135, %134 : vector<8xi32>
      scf.yield %136 : vector<8xi32>
    }
    vector.transfer_write %5, %3[%arg0] {in_bounds = [true]} : vector<8xi32>, memref<32xi32, affine_map<(d0)[s0] -> (d0 + s0)>>
  }

InnerParallel may offer higher accuracy for floating-point reductions since the elements are reduced in sequential order (assuming that sequential order means higher accuracy). However, transposing the data is very expensive!

InnerReduction

The InnerReduction strategy keeps the reduction in the innermost dimension. Vectorization is applied to the reduction dimension and a scalar accumulator is used per reduction. The scalar accumulator requires one horizontal vector reduction per iteration of the innermost loop to fold that iteration’s partial result into the accumulator. The following snippet shows the resulting IR after applying the InnerReduction strategy to the running example:

 scf.for %arg0 = %c0 to %c32 step %c8 {
    %5 = scf.for %arg1 = %c0 to %c512 step %c64 iter_args(%arg2 = %cst) -> (vector<8xi32>) {
      %6 = vector.transfer_read %4[%arg0, %arg1], %c0_i8 {in_bounds = [true, true]} : memref<32x512xi8, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>, vector<8x64xi8>
      %7 = arith.extsi %6 : vector<8x64xi8> to vector<8x64xi32>
      %8 = vector.extract %7[0] : vector<8x64xi32>
      %9 = vector.extract %arg2[0] : vector<8xi32>
      %10 = vector.reduction <add>, %8, %9 : vector<64xi32> into i32
      %11 = vector.insertelement %10, %cst[%c0 : index] : vector<8xi32>
      ...
      %36 = vector.extract %7[7] : vector<8x64xi32>
      %37 = vector.extract %arg2[7] : vector<8xi32>
      %38 = vector.reduction <add>, %36, %37 : vector<64xi32> into i32
      %39 = vector.insertelement %38, %35[%c7 : index] : vector<8xi32>
      scf.yield %39 : vector<8xi32>
    }
    vector.transfer_write %5, %3[%arg0] {in_bounds = [true]} : vector<8xi32>, memref<32xi32, affine_map<(d0)[s0] -> (d0 + s0)>>
  }

InnerReduction may be more performant than InnerParallel since data transposition is not required. However, horizontal vector reductions are also somewhat expensive and they are used in every iteration of the reduction loop. This approach may not reduce the elements in sequential order so the result of floating-point reductions may be “less accurate”.

Proposal

We want to implement a new reduction strategy that avoids the overhead of transposing data but also minimizes the number of horizontal vector reductions by:

  • Using a vector accumulator along the vector iterations of the reduction dimension instead of a scalar accumulator.
  • Computing a single horizontal vector reduction per reduction, when all the elements have been reduced to the single vector accumulator.

The following snippet shows the potential resulting IR after applying the new strategy to the running example (handwritten IR, be aware of potential mistakes):

    %6 = scf.for %arg1 = %c0 to %c32 step %c8 iter_args(%arg2 = %4) -> (tensor<32xi32>) {
      // Create a new temporary that will hold the vector partial results that
      // will later be horizontally reduced.
      %tmp = linalg.init_tensor [8, 16] : tensor<8x16xi32>
      %tmp_init = vector.transfer_write %cst8x16, %tmp[%c0, %c0] {in_bounds = [true, true]} : vector<8x16xi32>, tensor<8x16xi32>
      %9 = scf.for %arg3 = %c0 to %c512 step %c64 iter_args(%arg4 = %tmp_init) -> (tensor<8x16xi32>) {
        %11 = vector.transfer_read %5[%arg1, %arg3], %c0_i8 {in_bounds = [true, true]} : tensor<32x512xi8>, vector<8x64xi8>
        %12 = vector.transfer_read %arg4[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : tensor<8x16xi32>, vector<8x16xi32>
        %13 = arith.extsi %11 : vector<8x64xi8> to vector<8x64xi32>
        // Note 'vector.multi_reduction'’s output type. Each partial reduction is not
        // reduced to a scalar but to a 16-wide vector accumulator.
        %14 = vector.multi_reduction <add>, %13, %12 [1] : vector<8x64xi32> to vector<8x16xi32>
        %15 = vector.transfer_write %14, %arg4[%c0, %c0] {in_bounds = [true, true]} : vector<8x16xi32>, tensor<8x16xi32>
        scf.yield %15 : tensor<8x16xi32>
      }
      // A final horizontal reduction reduces the vector accumulator to a scalar.
      %vred = vector.transfer_read %9[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : tensor<8x16xi32>, vector<8x16xi32>
      %hred = vector.multi_reduction <add>, %vred, %cst [1] : vector<8x16xi32> to vector<8xi32>
      %10 = vector.transfer_write %hred, %arg2[%arg1] {in_bounds = [true]} : vector<8xi32>, tensor<32xi32>
      scf.yield %10 : tensor<32xi32>
    }

As noted in the IR, a 16-element vector accumulator is created for each reduction and is used within the innermost loop to reduce the full reduction dimension to only 16 elements. We took the liberty of using the vector.multi_reduction operation with a vector<8x16xi32> return type to illustrate this idea, even though that return type is not supported right now. After the reduction loop, a single horizontal vector reduction produces the final scalar result of each reduction out of its vector accumulator.
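
For reference, the same partial reduction can be expressed with the vector.multi_reduction forms supported today by regrouping the reduction elements before multi-reducing. The following hand-written sketch (unverified) would replace the computation of %14 above:

      // Regroup the 64 reduction elements per row into 4 chunks of 16.
      %grouped = vector.shape_cast %13 : vector<8x64xi32> to vector<8x4x16xi32>
      // Reduce only the chunk dimension; each of the 16 accumulator lanes
      // accumulates independently, so no horizontal reduction is involved.
      %14 = vector.multi_reduction <add>, %grouped, %12 [1] : vector<8x4x16xi32> to vector<8x16xi32>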

This approach should be more performant than both InnerReduction and InnerParallel. Accuracy should be on par with InnerReduction since the order in which the elements are reduced is similar to some extent. Extra accuracy might be achieved if we enforced horizontal vector reductions to be computed in sequential order, at the expense of some performance.

Implementation

We plan to follow the spirit of other transformations in Linalg by decomposing the new reduction approach into simple and composable transformations. For simple cases, we will apply the following steps:

  1. Split the input linalg.generic op into two linalg.generic ops with the existing split reduction utility (thanks for the suggestion, @ThomasRaoux!). They will model the loop performing the reduction using a vector accumulator and the final horizontal vector reduction, respectively. The following snippet illustrates how the running example would look after such a split (handwritten example, not the actual output of the split reduction utility):
      %11 = scf.for %arg3 = %c0 to %c512 step %c64 iter_args(%arg4 = %10) -> (tensor<8x16xi32>) {
        %12 = tensor.extract_slice %5[%arg1, %arg3] [8, 64] [1, 1] : tensor<32x512xi8> to tensor<8x64xi8>
        %14 = tensor.expand_shape %12 [[0], [1, 2]] : tensor<8x64xi8> into tensor<8x4x16xi8>
        // First generic op: partial reduction into the vector accumulator.
        %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                               affine_map<(d0, d1, d2) -> (d0, d2)>],
                              iterator_types = ["parallel", "reduction", "parallel"]}
          ins(%14 : tensor<8x4x16xi8>) outs(%arg4 : tensor<8x16xi32>) {
          ^bb0(%arg5: i8, %arg6: i32):
            %16 = arith.extsi %arg5 : i8 to i32
            %17 = arith.addi %16, %arg6 : i32
            linalg.yield %17 : i32
        } -> tensor<8x16xi32>
        scf.yield %15 : tensor<8x16xi32>
      }
      // Second generic op: final horizontal vector reductions.
      %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                             affine_map<(d0, d1) -> (d0)>],
                            iterator_types = ["parallel", "reduction"]}
        ins(%11 : tensor<8x16xi32>) outs(%8 : tensor<8xi32>) {
        ^bb0(%arg4: i32, %arg5: i32):
          %16 = arith.addi %arg4, %arg5 : i32
          linalg.yield %16 : i32
      } -> tensor<8xi32>
  2. Vectorize both generic ops with the InnerReduction approach, which will lead to vertical vector operations for the outer-dimension reduction in the first linalg.generic op and to horizontal vector reductions for the innermost dimension in the second linalg.generic op. No major changes would be needed in the vectorizer. The resulting code would be similar to the one shown in Proposal; see the sketch below.
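
For instance, the second linalg.generic above could vectorize to something along these lines (hand-written sketch; %11 and %8 reuse the names from the split example):

      %v = vector.transfer_read %11[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : tensor<8x16xi32>, vector<8x16xi32>
      %acc = vector.transfer_read %8[%c0], %c0_i32 {in_bounds = [true]} : tensor<8xi32>, vector<8xi32>
      // Reducing the innermost dimension lowers to one horizontal
      // vector.reduction per output element (InnerReduction lowering).
      %hred = vector.multi_reduction <add>, %v, %acc [1] : vector<8x16xi32> to vector<8xi32>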

As part of this work, we will also extend Linalg’s Code Generation Strategy infrastructure to support this new reduction approach.

Thanks! Any feedback would be appreciated!

@nicolasvasilache, @ThomasRaoux, @MaheshRavishankar, @vmurali

+1, this looks great to me. The memory access patterns look much better this way, and for GPU-like targets this also allows getting extra parallelism out of the op.

The existing split reduction transformation should work for this. Another way that has been discussed would be to implement this same transformation as a tiling transformation that materializes the new parallel loop. @pifon2a had implemented it some time back; however, it was based on linalg.tiled_loop ops that were later removed.
It could be interesting to add support to the tiling interface to be able to generate the split reduction (this was suggested by @MaheshRavishankar) and to generate different kinds of loops based on uses.

SplitReduction followed by horizontal reduction LGTM.

FWIW it’s also how we are approaching targeting warp shuffle reductions on GPU.

As part of this process, can we also retire InnerReduction?

Thanks, Diego, for the proposal. I understand that using the existing split-K is the easiest path to explore, so +1 to doing this exploration.

Not related to this RFC itself, but split-K as implemented today has the limitation that the chosen split must divide the reduction dimension. Today this has to be statically provable because of an implementation detail, but fundamentally the restriction has to be satisfied dynamically. Without this restriction you cannot represent the “split reduction” part as a single generic operation (see the sketch below). We can live with this restriction for the time being, but it is limiting in general.
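
To make the restriction concrete, the single-generic form relies on reshaping the reduction dimension, which only verifies when the split factor statically divides it. A hypothetical sketch based on the running example (%A stands for the input):

    // Splitting the 512-wide reduction dimension by 16 requires
    // 512 % 16 == 0, which is provable statically here. With a dynamic
    // reduction size, the equivalent reshape cannot be verified.
    %expanded = tensor.expand_shape %A [[0], [1, 2]] : tensor<32x512xi8> into tensor<32x32x16xi8>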

The way to address this restriction, IMO, is to see the split-K transformation as tiling + distribution of the reduction dimension. So you can start with:

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %input, %c0 : tensor<?x?xf32>
%init = linalg.init_tensor [%d0] : tensor<?xf32>
%cst = arith.constant 0.0 : f32
%out = linalg.fill ins(%cst : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
%result = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%input : tensor<?x?xf32>) outs(%out : tensor<?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32) :
      %addf = arith.addf %b0, %b1 : f32
      linalg.yield %addf : f32
    } -> tensor<?xf32>

As you would do for any other tiling operation, you would tile the reduction dimension and represent the inter-tile iteration space using a loop (unlike the current split-K implementation, which puts the tiled iteration space into a linalg.generic):

%splitkVal = ... : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %input, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %input, %c1 : tensor<?x?xf32>
%splitKDim = affine.apply affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%d1, %splitkVal]
%splitKInit = linalg.init_tensor [%d0, %splitKDim] : tensor<?x?xf32>
%cst = arith.constant 0.0 : f32
%splitKFill = linalg.fill ins(%cst : f32) outs(%splitKInit : tensor<?x?xf32>) -> tensor<?x?xf32>
%r0 = scf.for %iv0 = %c0 to %d1 step %splitkVal iter_args(%arg0 = %splitKFill) -> tensor<?x?xf32> {
  %kVal = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%iv0)[%d1, %splitkVal]
  %slice_in = tensor.extract_slice %input [0, %iv0] [%d0, %kVal] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %iv0_num = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%iv0)[%splitkVal]
  %slice_out = tensor.extract_slice %arg0 [0, %iv0_num] [%d0, 1] [1, 1] : tensor<?x?xf32> to tensor<?xf32>
  %result_slice = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%slice_in : tensor<?x?xf32>) outs(%slice_out : tensor<?xf32>) {
      ^bb0(%b0 : f32, %b1 : f32) :
        %addf = arith.addf %b0, %b1 : f32
        linalg.yield %addf : f32
      } -> tensor<?xf32>
  %yield = tensor.insert_slice %result_slice into %arg0[0, %iv0_num] [%d0, 1] [1, 1] : tensor<?xf32> into tensor<?x?xf32>
  scf.yield %yield : tensor<?x?xf32>
}
%init = linalg.init_tensor [%d0] : tensor<?xf32>
%out = linalg.fill ins(%cst : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
%result = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%r0 : tensor<?x?xf32>) outs(%out : tensor<?xf32>) {
      ^bb0(%b0 : f32, %b1 : f32) :
        %addf = arith.addf %b0, %b1 : f32
        linalg.yield %addf : f32
      } -> tensor<?xf32>

The above code is very similar to tiling of the original linalg.generic, with two modifications:

  1. The use of %splitKFill is effectively an array privatization followed by hoisting.
  2. The epilogue reduces the result of tiling to get the final value.

This loop code is essentially similar to the code under Proposal, except that the example there does the same thing at the vector level.
Also note that the loop above is parallel and could be distributed (or you could use a loop construct that encodes the parallelism, like scf.foreach_thread; see the sketch below).
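
As a rough sketch of that last point, the inter-tile loop could be expressed with a parallel construct and tensor.parallel_insert_slice (hypothetical IR using scf.forall-style syntax; names reuse the example above and the body is abbreviated):

%numTiles = affine.apply affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%d1, %splitkVal]
%r0 = scf.forall (%tile) in (%numTiles) shared_outs(%arg0 = %splitKFill) -> (tensor<?x?xf32>) {
  %iv0 = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%tile)[%splitkVal]
  // ... same partial-reduction body as in the scf.for version above ...
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %result_slice into %arg0[0, %tile] [%d0, 1] [1, 1] : tensor<?xf32> into tensor<?x?xf32>
  }
}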

The loop code above I think can be generated using the TilingInterface by adding a couple of InterfaceMethods that

  1. Get the identity element of the reduction operation
  2. Get the code to implement the final reduction step.

So far I haven’t considered using the TilingInterface for generating vectorized code. That seems like a stretch, but it should be OK to duplicate the transformations done to generate vectorized code with loops as well.

I’m not sure if that would work in our case since tiling has already happened before we come to the point where we have to split, but we can definitely look into that.

How is the split reduction transformation triggered in this context?

Yes, we can evolve InnerReduction to what we need for the split approach or create a new strategy and get rid of InnerReduction if it doesn’t have value anymore.

Yes, I focused on describing what is needed for the simplest use case, but this approach should also be able to deal with more complex cases. For those cases, peeling and masking should come into play. Peeling would make the reduction dimension a multiple of the split and, for masking, I expect that we will have to extend the split reduction utility to deal with dynamic shapes by providing tile/vector sizes as input parameters (we are doing something similar for masking in the vectorizer). The code generation driver plays a critical role in communicating the tile/vector sizes and other vectorization decisions among these independent transformations. That’s why we discussed bringing some of these vectorization-related transformations into the vectorizer in the long term. This RFC, though, is unrelated to that, but it illustrates some of the dependencies that exist between vectorization-related transformations.
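
To make the masking part a bit more concrete, a dynamic reduction dimension could be handled with masked vector transfers along these lines (hypothetical sketch; %remaining stands for the number of valid reduction elements left at the current iteration):

    // Mask off the lanes beyond the dynamic trip count. The padding
    // value 0 is the identity of the addition, so masked-off lanes do
    // not affect the accumulation.
    %mask = vector.create_mask %c8, %remaining : vector<8x64xi1>
    %v = vector.transfer_read %src[%i, %j], %c0_i8, %mask : tensor<?x?xi8>, vector<8x64xi8>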

If I understand correctly, the discussion here is about who should introduce the reduction privatization for vectorization (the vector accumulator) and the code for the final horizontal vector reduction. You argue that they should be introduced as part of tiling the reduction dimension. That makes sense to me, but tiling the reduction dimension seems like an orthogonal problem to vectorization. Yes, we currently tile so that we end up with static shapes but, as you mentioned, moving more vectorization pieces to tiling might be a stretch. We should be able to generate the code that we want for reductions with static shapes without any tiling (and for dynamic shapes, with masking), right? It also seems to move vectorization-related transformations even farther away from the vectorizer, so it’s not clear to me how this approach will compose with peeling, for example. In any case, let’s discuss how to support these complex cases once we have the simple ones in place and have built more expertise. Masking will also bring more flexibility when it comes to dealing with dynamic shapes, so we may have more options to consider.

I am saying that reduction privatization is applicable both at the Linalg/Tiling level and at the Vector level, and the code generated for both is similar. My main point is that split-K, which is also doing reduction privatization in some sense, has limitations that we can address at the Linalg/Tiling level.
Also, I think reduction privatization has value both at the Linalg/Tiling level and at the Vector level. That doesn’t mean they have to share an implementation in any way. I am recommending not sharing the implementation even though they are doing similar things: they operate at different levels of the stack and with different goals, so it’s fine to have separate implementations for both.

Thanks for clarifying. I need to better understand the existing limitations and the role that peeling and masking will play there but your suggestion sounds reasonable to me.

Is this also implemented in the MLIR repo or only in IREE? If it is only in IREE, can I move it to MLIR?