[RFC] Scalable Vectorisation in Linalg

Authors: Andrzej Warzynski (Arm) and Diego Caballero (Google)

Hi everyone,

Scalable vectors are a well-established technology with broad support across multiple components of the LLVM project.

We propose to extend the Linalg vectoriser so that it can also target scalable vectors and leverage the full potential of architectures that support them (see Appendix for examples). This will require changes in the following areas:

  1. Extend tiling so that it can generate tiles with scalable sizes,
  2. Add any missing scalable vector representation to facilitate the changes proposed here,
  3. Teach Linalg’s vectoriser to generate scalable vectors,
  4. Implement the missing lowering from the Vector to the LLVM dialect (and to LLVM IR).

In the sections below we dive into details for each of the above. First, let us start with a bit of background (you can skip it if you’re already familiar with the concept of scalable vectors).

A Brief Primer on Scalable Vectors in LLVM and MLIR

Scalable vectors enable Vector Length Agnostic (VLA) programming. In this model, the developer and the compiler are free from concerns about the actual width of the available vectors (e.g. 128 vs 256 bits). This allows one, for example, to compile once and then execute on CPUs with different vector widths and still utilise full vector width in every case.

Scalable vectors are like regular vectors, except that their actual size is e.g. vscale * 128 bits rather than a plain 128 bits. The value of vscale is unknown at compile time, but known at runtime. At the LLVM IR level you can use the llvm.vscale intrinsic to retrieve this value (so that it can be used in various arithmetic operations), and when defining scalable vectors you would use vscale x 4 instead of 4 when specifying the vector length.

Here’s an example of fixed-width vs scalable reductions in LLVM IR:

; Extracted from llvm/test/CodeGen/AArch64/double_reduct.ll
define float @add_f32(<8 x float> %a, <4 x float> %b) {
  %r1 = call fast float @llvm.vector.reduce.fadd.f32.v8f32(float -0.0, <8 x float> %a)
  %r2 = call fast float @llvm.vector.reduce.fadd.f32.v4f32(float -0.0, <4 x float> %b)
  %r = fadd fast float %r1, %r2
  ret float %r
}
 
declare float @llvm.vector.reduce.fadd.f32.v8f32(float, <8 x float>)
declare float @llvm.vector.reduce.fadd.f32.v4f32(float, <4 x float>)

; Extracted from llvm/test/CodeGen/AArch64/sve-doublereduct.ll
define float @add_f32(<vscale x 8 x float> %a, <vscale x 4 x float> %b) {
  %r1 = call fast float @llvm.vector.reduce.fadd.f32.nxv8f32(float -0.0, <vscale x 8 x float> %a)
  %r2 = call fast float @llvm.vector.reduce.fadd.f32.nxv4f32(float -0.0, <vscale x 4 x float> %b)
  %r = fadd fast float %r1, %r2
  ret float %r
}
 
declare float @llvm.vector.reduce.fadd.f32.nxv8f32(float, <vscale x 8 x float>)
declare float @llvm.vector.reduce.fadd.f32.nxv4f32(float, <vscale x 4 x float>)

In MLIR, the Vector dialect already implements vector.vscale and supports scalable vectors. The example below demonstrates fixed-width vs scalable vector splat.

// Extracted from mlir/test/Target/LLVMIR/llvmir.mlir
 llvm.func @vector_splat_1d() -> vector<4xf32> {
   %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf32>) : vector<4xf32>
   llvm.return %0 : vector<4xf32>
 }
 
 llvm.func @vector_splat_1d_scalable() -> vector<[4]xf32> {
   %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
   llvm.return %0 : vector<[4]xf32>
 }

The key difference is vector<4xf32> vs vector<[4]xf32>. The square brackets around the vector size, [4], are the MLIR syntax for scalability, i.e. [4] == vscale x 4.
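
You can also materialise the runtime multiplier itself with vector.vscale, the MLIR counterpart of llvm.vscale. Here is a minimal sketch (a hypothetical helper, not taken from the examples above) computing the runtime element count of a vector<[4]xf32>:

// Runtime length, in elements, of a vector<[4]xf32>, i.e. 4 * vscale.
func.func @runtime_vector_length() -> index {
  %c4 = arith.constant 4 : index
  %vs = vector.vscale
  %len = arith.muli %c4, %vs : index
  return %len : index
}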

Finally, below is an example of a regular scf.for loop over scalable vectors:

func.func @vector_scalable_memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %vs = vector.vscale
  %step = arith.muli %c4, %vs : index

  // %step is a multiple of `vscale`
  scf.for %i0 = %c0 to %size step %step {
    %0 = vector.load %src[%i0] : memref<?xf32>, vector<[4]xf32>
    vector.store %0, %dst[%i0] : memref<?xf32>, vector<[4]xf32>
  }

  return
}

MLIR can already lower this to valid SVE code utilising scalable vectors (example extracted from vector-scalable-memcpy.mlir).

Proposal

We have split this proposal into 4 components; each section below corresponds to one of the key elements of scalable vectorisation. For each, we present what is already available, what is missing, and what changes we propose.

We will use the following example to demonstrate what we are aiming for.

// example.mlir
#map = affine_map<(d0) -> (d0)>

module {
  func.func @example(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: f32) -> tensor<?xf32> {
    %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) outs(%arg2 : tensor<?xf32>) {
    ^bb0(%in_1: f32, %in_2: f32, %out: f32):
      %1 = arith.addf %in_1, %in_2 : f32
      %2 = arith.mulf %arg3, %1 : f32
      linalg.yield %2 : f32
    } -> tensor<?xf32>
    return %0 : tensor<?xf32>
  }
}

In this example we are using tensors, but the whole discussion would apply to memrefs as well. We will use dynamic shapes for inputs/outputs as the most suitable representation for what will later be lowered to scalable vectors.

1. Scalable Tiling

While Linalg’s vectoriser can be used without tiling, in practice one would apply tiling so that the resulting Linalg Op operates on tensors/buffers closely matching native vector sizes.

MLIR command line:

$ mlir-opt --test-transform-dialect-interpreter='transform-file-name=tile.mlir' -cse example.mlir -o example_after_tiling.mlir

Transform dialect sequence:

// tile.mlir
transform.sequence failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
    %1, %loop = transform.structured.tile %0 [4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
}

IR after tiling:

// example_after_tiling.mlir
#map = affine_map<(d0)[s0] -> (4, -d0 + s0)>
#map1 = affine_map<(d0) -> (d0)>
module {
  func.func @example(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: f32) -> tensor<?xf32> {
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
    %c4 = arith.constant 4 : index
    %0 = scf.for %arg4 = %c0 to %dim step %c4 iter_args(%arg5 = %arg2) -> (tensor<?xf32>) {
      %1 = affine.min #map(%arg4)[%dim]
      %extracted_slice = tensor.extract_slice %arg0[%arg4] [%1] [1] : tensor<?xf32> to tensor<?xf32>
      %extracted_slice_0 = tensor.extract_slice %arg1[%arg4] [%1] [1] : tensor<?xf32> to tensor<?xf32>
      %extracted_slice_1 = tensor.extract_slice %arg5[%arg4] [%1] [1] : tensor<?xf32> to tensor<?xf32>
      %2 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xf32>, tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) {
      ^bb0(%in: f32, %in_2: f32, %out: f32):
        %3 = arith.addf %in, %in_2 : f32
        %4 = arith.mulf %arg3, %3 : f32
        linalg.yield %4 : f32
      } -> tensor<?xf32>
      %inserted_slice = tensor.insert_slice %2 into %arg5[%arg4] [%1] [1] : tensor<?xf32> into tensor<?xf32>
      scf.yield %inserted_slice : tensor<?xf32>
    }
    return %0 : tensor<?xf32>
  }
}

As expected, the step in the generated loop is fixed-width:

%c4 = arith.constant 4 : index
%c0_4 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0_4 to %dim step %c4 

It is actually possible to generate a scalable loop by tweaking tileSizeComputationFunction. The following output was generated by updating the current implementation in LinalgTransformOps.cpp.

IR after scalable tiling:

#map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
#map1 = affine_map<(d0) -> (d0)>
module {
  func.func @example(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: f32) -> tensor<?xf32> {
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
    %c4 = arith.constant 4 : index
    %0 = vector.vscale
    %1 = arith.muli %c4, %0 : index
    %2 = scf.for %arg4 = %c0 to %dim step %1 iter_args(%arg5 = %arg2) -> (tensor<?xf32>) {
      %3 = affine.min #map(%arg4)[%1, %dim]
      %extracted_slice = tensor.extract_slice %arg0[%arg4] [%3] [1] : tensor<?xf32> to tensor<?xf32>
      %extracted_slice_0 = tensor.extract_slice %arg1[%arg4] [%3] [1] : tensor<?xf32> to tensor<?xf32>
      %extracted_slice_1 = tensor.extract_slice %arg5[%arg4] [%3] [1] : tensor<?xf32> to tensor<?xf32>
      %4 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xf32>, tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) {
      ^bb0(%in: f32, %in_2: f32, %out: f32):
        %5 = arith.addf %in, %in_2 : f32
        %6 = arith.mulf %arg3, %5 : f32
        linalg.yield %6 : f32
      } -> tensor<?xf32>
      %inserted_slice = tensor.insert_slice %4 into %arg5[%arg4] [%3] [1] : tensor<?xf32> into tensor<?xf32>
      scf.yield %inserted_slice : tensor<?xf32>
    }
    return %2 : tensor<?xf32>
  }
}

This time the step in the generated loop is scalable:

    %0 = vector.vscale
    %1 = arith.muli %c4, %0 : index
    %2 = scf.for %arg4 = %c0 to %dim step %1

This is pretty much what we need in order to proceed. Since we were able to generate that with minimal changes, we believe that most of the support is already there, modulo some plumbing (yet to be implemented upstream).

The experiment above required a few edits in “LinalgTransformOps.cpp”. Instead, we would like to drive “scalable tiling” through the Transform dialect. At the moment there is no syntax for that, so we propose the following notation:

  • transform.structured.tile %0 [[4]] to tile one dimension using a scalable vector size with base four.
  • transform.structured.tile %0 [2, [4]] to tile two dimensions using a fixed-length vector size of two for the first dimension and a scalable vector size with base four for the second dimension (see the sketch below).
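
For example, a transform sequence using the proposed notation would mirror the earlier tile.mlir, with [4] replaced by [[4]] (a sketch only; neither the syntax nor the underlying support is upstream yet):

// tile_scalable.mlir (hypothetical)
transform.sequence failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
    %1, %loop = transform.structured.tile %0 [[4]] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
}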

Proposed Changes

  1. Update the logic in LinalgTransformOps.cpp so that it will generate scalable tiles when requested.
  2. Add syntax and support to the Transform dialect to allow specifying scalable tile sizes.
  3. Stress test this functionality (e.g. by contributing tests upstream).

2. Scalable Vector Representation

Scalable vectorisation relies heavily on vector masking to iterate over the iteration space with a vector length that is unknown at compile time. Although vector masking is fully supported in the Vector dialect, some of the main masking constructs still do not support scalable vectors. The following table summarises the existing support:

Operation             | Scalable Vector Support
----------------------+--------------------------------------------------
vector.create_mask    | Yes
vector.constant_mask  | Yes
vector.transfer_read  | No
vector.transfer_write | No
vector.gather         | Yes (lowered directly to llvm.intr.masked.gather)
vector.scatter        | Yes (lowered directly to llvm.intr.masked.scatter)
vector.mask           | No
vector.maskedload     | Yes
vector.maskedstore    | Yes

Most of the existing support in these operations comes from the extraordinary work that the community has done on the sparse compiler vectoriser. However, the sparse vectoriser operates at a lower level of abstraction, where operations like vector.mask and the vector.transfer_* ops are not used. These operations are fundamental to masked vectorisation in the Linalg vectoriser and will require proper scalable vector support.
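
To illustrate the gap, here is a small sketch (a hypothetical function, written by analogy with the vectoriser output shown in Section 3) combining an op that already supports scalable vectors (vector.create_mask) with two that do not yet (vector.mask and vector.transfer_read):

func.func @masked_read(%t : tensor<?xf32>, %n : index) -> vector<[4]xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  // Already works for scalable vectors today.
  %mask = vector.create_mask %n : vector<[4]xi1>
  // Requires the scalable vector support proposed in this section.
  %v = vector.mask %mask { vector.transfer_read %t[%c0], %pad {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
  return %v : vector<[4]xf32>
}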

Proposed Changes

  • Add support for scalable vectors to vector.mask and vector.transfer_* ops.

3. Linalg Vectoriser

After scalable tiling and with the proper scalable masking support in place, we will be ready to apply scalable vectorisation using the Linalg vectoriser. The following snippets show the proposed process to apply scalable vectorisation to our running example using the Transform dialect.

MLIR command line:

$ mlir-opt --test-transform-dialect-interpreter='transform-file-name=vectorise.mlir' -cse example_after_tiling.mlir -o example_after_vectorisation.mlir

Transform dialect sequence (note the scalable [4] vector size):

// vectorise.mlir
transform.sequence  failures(propagate) {
^bb0(%arg0: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!pdl.operation) -> !pdl.operation
    transform.structured.masked_vectorize %0 vector_sizes [[4]]
}

IR after scalable vectorisation:

#map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
module {
  func.func @example(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: f32) -> tensor<?xf32> {
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
    %c4 = arith.constant 4 : index
    %0 = vector.vscale
    %1 = arith.muli %c4, %0 : index
    %2 = scf.for %arg4 = %c0 to %dim step %1 iter_args(%arg5 = %arg2) -> (tensor<?xf32>) {
      %3 = affine.min #map(%arg4)[%1, %dim]
      %extracted_slice = tensor.extract_slice %arg0[%arg4] [%3] [1] : tensor<?xf32> to tensor<?xf32>
      %extracted_slice_0 = tensor.extract_slice %arg1[%arg4] [%3] [1] : tensor<?xf32> to tensor<?xf32>
      %extracted_slice_1 = tensor.extract_slice %arg5[%arg4] [%3] [1] : tensor<?xf32> to tensor<?xf32>
      %dim_2 = tensor.dim %extracted_slice, %c0 : tensor<?xf32>
      %cst = arith.constant 0.000000e+00 : f32
      %4 = vector.create_mask %dim_2 : vector<[4]xi1>
      %5 = vector.mask %4 { vector.transfer_read %extracted_slice[%c0], %cst {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
      %6 = vector.mask %4 { vector.transfer_read %extracted_slice_0[%c0], %cst {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
      %7 = arith.addf %5, %6 : vector<[4]xf32>
      %8 = vector.broadcast %arg3 : f32 to vector<[4]xf32>
      %9 = arith.mulf %8, %7 : vector<[4]xf32>
      %10 = vector.mask %4 { vector.transfer_write %9, %extracted_slice_1[%c0] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
      %inserted_slice = tensor.insert_slice %10 into %arg5[%arg4] [%3] [1] : tensor<?xf32> into tensor<?xf32>
      scf.yield %inserted_slice : tensor<?xf32>
    }
    return %2 : tensor<?xf32>
  }
}

As discussed for scalable tiling, we have no mechanism to express scalable vectors in the Transform dialect vectorisation sequence. We propose to use the same notation as for scalable vector types:

  • transform.structured.masked_vectorize %0 vector_sizes [[4]] will vectorise one dimension using a scalable vector size with base four.
  • transform.structured.masked_vectorize %0 vector_sizes [2, [4]] will vectorise two dimensions using a fixed-length vector size of two for the first dimension and a scalable vector size with base four for the second dimension.

The Linalg vectoriser would also have to be extended to vectorise operations using scalable vectors and to generate scalable vector masks. This includes changes ranging from the vectoriser API to the core vectorisation algorithm.

Scalable Vectorisation Strategies: Predicated vs Unpredicated Main Vector Loop

As with fixed-length vectorisation, scalable vectorisation may benefit from different strategies when vectorising a number of iterations that is not a multiple of the physical vector length. We plan to focus on the two most widely applied ones: a predicated/masked main vector loop without a remainder loop, and an unpredicated/unmasked main vector loop with a remainder loop.

Predicating/masking the main vector loop seems like the most canonical approach for the existing scalable vector architectures, and that is what we would like to focus on first. Supporting this approach requires only basic vector masking functionality using scalable masks, which is already part of our plan. However, generating an unpredicated/unmasked main vector loop with a remainder loop may perform better in practice on some architectures, so we would also like to explore this strategy. For example, on Arm’s Neoverse V1 [1], whilelt (the “predicated loop” instruction) has lower throughput than cmp. In order to generate a remainder loop for scalable vectorisation, we only have to apply peeling to the input iteration space using the target scalable vector sizes (see the sketch below).

[1] Arm® Neoverse™ V1 Software Optimization Guide
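
To illustrate the unpredicated strategy, here is a sketch (our assumption about the eventual shape, based on the memcopy example from the primer): the iteration space is peeled so that the main loop processes whole scalable vectors, while a single masked step, built from ops that already support scalable vectors, handles the remainder.

func.func @memcopy_peeled(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %vs = vector.vscale
  %step = arith.muli %c4, %vs : index
  // Split point: the largest multiple of the runtime step not exceeding %size.
  %rem = arith.remui %size, %step : index
  %split = arith.subi %size, %rem : index
  // Unpredicated/unmasked main loop over whole scalable vectors.
  scf.for %i = %c0 to %split step %step {
    %v = vector.load %src[%i] : memref<?xf32>, vector<[4]xf32>
    vector.store %v, %dst[%i] : memref<?xf32>, vector<[4]xf32>
  }
  // Masked remainder: a single partial step.
  %mask = vector.create_mask %rem : vector<[4]xi1>
  %pad = arith.constant dense<0.0> : vector<[4]xf32>
  %rv = vector.maskedload %src[%split], %mask, %pad : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
  vector.maskedstore %dst[%split], %mask, %rv : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32>
  return
}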

Proposed Changes

  1. Extend Linalg vectoriser API to support input scalable vector sizes.
  2. Compute and propagate vector sizes to scalar operations from the input vector sizes.
  3. Extend vector masking to support scalable masks.
  4. Make sure that the peeling transformation works with scalable vectors (to support unpredicated vectorisation).
  5. Add syntax and support to the Transform dialect to allow specifying scalable vector sizes.

4. Missing Vector Lowerings/Canonicalisations

We are a bit unsure about the specifics here, but we generally expect to discover many missing lowerings and/or canonicalisations for scalable vectors. There are also bound to be cases where we need to introduce (or at least test) missing lowerings to LLVM IR.

For example, there are already a number of end-to-end integration tests in MLIR that demonstrate how to generate valid Arm SVE/scalable code from MLIR; see e.g. the sparse tensor integration tests. However, we haven’t really tried predicated SVE loops yet.

Proposed Changes

  1. Introduce lowerings for predicated loops and any other missing lowerings.
  2. Extend all relevant canonicalisation patterns (e.g. for vector transfer/masking Ops) to support scalable vectors.
  3. Stress test new and existing lowerings/canonicalisations for scalable vectors.

Next Steps

We plan to contribute the changes proposed above to upstream MLIR. They will be implemented and tested in isolation. Following that, we will use the IREE Arm CPU backend to integrate the different components and perform end-to-end testing. Once that’s available, we plan to focus on:

  • Evaluating SVE vectorisation on real hardware with wide vectors (e.g. 256 bits),
  • Fine-tuning performance based on the results of those experiments.

Thank you for taking a look! Any feedback is more than welcome!

Andrzej and Diego

Appendix

For folks less familiar with scalable vectors, here are two major “scalable” architecture extensions for the Arm architecture:

  • the Scalable Vector Extension (SVE),
  • the Scalable Matrix Extension (SME).

@javiersetoain, @nicolasvasilache, @ThomasRaoux, @MaheshRavishankar, @aartbik, @zhanghb97

Hi Andrzej and Diego,

Thanks for this RFC! This will help Linalg operations exploit the VLA feature of the scalable vector backends.

Both the Arm SVE and RISC-V Vector (RVV) extensions use the VLA strategy.
I conducted some simple experiments and found that, although the vscale values on the two sides differ for the same VLEN, the generality of the proposed method is unaffected. (RVV will use vector groups to handle longer steps.)
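
To make the difference concrete, here is the arithmetic as I understand the LLVM conventions (an assumption worth double-checking: vscale = VLEN/128 on SVE and vscale = VLEN/64 on RVV), taking VLEN = 256 as an example:

  SVE: vscale = 256/128 = 2, so vector<[4]xf32> covers 2 * 4 * 32 = 256 bits (a single register)
  RVV: vscale = 256/64 = 4, so vector<[4]xf32> covers 4 * 4 * 32 = 512 bits (an LMUL=2 vector group)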

Considering the VLA strategy, I think tail processing would be an interesting topic. The currently proposed method seems to utilize the “predicated/masked main vector loop without remainder loop,” which will work on both Arm SVE and RVV backends. Apart from that, there can be another predication (or strip-mining) method on RVV side, and the model will be like this:

// While loop for strip-mining.
%tmpAVL, %tmpIdx = scf.while (%avl = %dim, %idx = %c0) : (index, index) -> (index, index) {
  // If avl greater than zero.
  %cond = arith.cmpi sgt, %avl, %c0 : index
  // Pass avl, idx to the after region.
  scf.condition(%cond) %avl, %idx : index, index
} do {
^bb0(%avl : index, %idx : index):
  %vl = rvv.setvl ... ... // A set vl operation
  // Perform the calculation according to the vl.
  ... ...
  // Update idx and avl.
  %new_idx = arith.addi %idx, %vl : index
  %new_avl = arith.subi %avl, %vl : index
  scf.yield %new_avl, %new_idx : index, index
}

I’m working on a proposal to add these representations upstream to make this work. I also want to explore which vectorization and tail processing strategies will work better. I would love to have a more detailed discussion!

Thanks,
Hongbin

Nice proposal, this is perfectly in line with the design principles of Linalg. Thanks for pushing in this direction!

Great choice!

We have tried another design choice for similar injection of dynamic SSA values that we have found to be more general and may be worth considering here.

Most Linalg transform operations are parametric and support SSA values (or should be extended accordingly); you could create new transform operation(s) that return a ValueHandle:

%0 = transform.vector.vscale %top_level_op_where_to_create_the_value
%1 = transform.compute_scalable_tile_size %0
...

and then just pass the resulting SSA value to the existing transform.structured.tile.

This should mostly work out of the box and generalize to other use cases.
In particular, one wouldn’t need to update the various tiling transform ops (to parallel loops, to reduction with/without padding etc).

Open questions we had touched on with @ftynse in the past (I forget where we are these days) are:

  1. do we want a generic set of arithmetic / affine operations to manipulate and construct such quantities, or do we want one-off named transform ops? I am thinking in particular about tile size selection logic that could use transform macros and avoid a proliferation of C++.
  2. add effects to avoid DCE while such transient SSA values are in flight but not yet used.

+1

The solution I outlined above should almost work out of the box here too; the vector_sizes operand will need to be extended to support SSA values and the verifier adapted accordingly, but by now the base infra should exist to make this transition easy (lists of mixed attributes / SSA values are well supported).

+1 this is often where a lot of missing components appear and often represents quite some work.
On the positive side, this is generally easy to split up and distribute amongst contributors.

Thanks for the feedback, Hongbin and Nicolas!

Yes, that’s the next step in the vector-length progression: fixed-length → scalable-length → dynamic-length. Really excited to see all of this moving forward!

The problem I see is that a scalable vector (e.g., vector<[4]xf32>, where [4] actually denotes vscale * 4) has two parts: the scalable factor (i.e., vscale, dynamic) and the base factor (i.e., 4, static). We need to pass the static part as a static value to the Linalg vectorizer, as it will become part of the vector type.

Good point, masked_vectorize is enough of a special flower that it makes sense to evolve the transform syntax. As long as the various tilings use SSA values this sgtm.

I would still push back on this. What Nicolas described composes better with existing ops (we already have transform ops that emit custom tile size computations, e.g. structured.multitile_sizes) and better separates concerns (one doesn’t need to understand or care about the special syntax for scalable vectors while tiling in cases where scalable vectors are not a thing).

We don’t necessarily want a transform op that emits vector.vscale in the payload IR and returns an opaque handle to it, but we certainly can have a transform op that takes the fixed base factor as an attribute and produces derived quantities accounting for it.

%tile_size = transform.vector.compute_scalable_tile_sizes { base = 4 } (%op)
%tiled = transform.structured.tile %op [%tile_size]
transform.structured.masked_vectorize %tiled [[4]]

The “4” is repeated, but it would also have been repeated if it were explicitly listed in the tile sizes. FWIW, this is better off as a parameter, which is “a handle to an attribute” anyway, so it will have zero problems being used in a type:

%vscale_base = transform.param.const 4 : i64 -> index
%tile_size = transform.vector.compute_scalable_tile_sizes base(%vscale_base) (%op)
%tiled = transform.structured.tile %op[%tile_size]
transform.structured.masked_vectorize %tiled [[%vscale_base]]

which would both remove the footgun of accidentally using different constants in two places and make masked_vectorize less special.

Bonus point: taking this approach will give you the “unmasked main loop + masked remainder loop” approach almost for free:

%vscale_base = transform.param.const 4 : i64 -> index
%tile_size, %split_point = 
  transform.vector.compute_scalable_tile_size_and_split_point
  base(%vscale_base) (%op)
%main, %remainder = transform.structured.split %op after %split_point
%main_t = transform.structured.tile %main [%tile_size]
transform.structured.vectorize %main_t
transform.structured.masked_vectorize %remainder [[%vscale_base]]

all operations except for size computation already exist, though vectorizations should become more targetable than they currently are.

I’d prefer not writing a generic interpreter, but that’s a separate discussion.

This is no longer a problem since even IREE doesn’t indiscriminately run DCE after every transform anymore.

Thanks, Alex! That sounds great to me! No strong opinion about the approach to take here. My main concern was about passing a random SSA value to masked_vectorize since that would require some kind of validation and folding to a scalable vector type. What you suggest addresses that concern so it works for me! WDYT, @banach-space?

Thank you all for your feedback, that’s much appreciated!

Apologies for going a bit radio silent. @zhanghb97 and I (as well as other folks) chatted about this proposal last week at EuroLLVM. Most of the discussion points have already been captured here by others, so I won’t be adding much.

The approach proposed by Alex makes a lot of sense to me and is very neat. I really like the example with “unmasked main loop + masked remainder loop”. It might take a while before we get there, but we can work towards that incrementally.

In general, for the loop remainder we would begin with:

  • predicated/masked main vector loop without remainder,
  • unpredicated/unmasked main vector loop with remainder loop (the latter could technically be either scalar or masked).

The other approach, suggested by Hongbin, would require vectors with “dynamic length”. As that is still work-in-progress, we should probably keep it for a later discussion.

Thank you,
-Andrzej

First patch to enable scalable tiling:

I’m sending it for some early feedback, just to make sure that I’m taking the right approach.

Please be mindful that there might be a few rough edges to polish :wink: (all feedback is very welcome!)

Thank you,
-Andrzej

Just a small status update. The following patches have been merged:

Those two enable the following syntax for expressing:

  • scalable tile sizes:
transform.structured.tile %0 [[4]]
  • scalable vector sizes:
transform.structured.masked_vectorize %0 vector_sizes [[4]]

ATM, only the trailing size can be scalable. I will be extending that in the coming weeks.
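
For example (hypothetical sizes, just to illustrate the current constraint):

// Supported: the trailing size is scalable.
transform.structured.masked_vectorize %0 vector_sizes [4, [4]]
// Not yet supported: a scalable size in a non-trailing position.
// transform.structured.masked_vectorize %0 vector_sizes [[4], 4]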

-Andrzej

Initial scalable vectorization support in Linalg: ⚙ D152599 [mlir][Vector] Add basic scalable vectorization support to Linalg vectorizer
Only elementwise ops are supported for now.
