[RFC] Hardening the VectorType API

We would like to add some lightweight wrappers for VectorType (accessors and iterators), along with some new builders, to make it easier and safer to work with scalable dimensions.

Motivating example

Let’s say that you are trying to create a new VectorType by “dropping” the leading n dimensions of vType:

VectorType::get(vType.getShape().drop_front(n),
                vType.getElementType(),
                vType.getScalableDims().drop_front(n));

This is problematic for a few reasons:

  • Shape and scalability flags have to be updated separately
  • It exposes an implementation detail that is often irrelevant (as in the example above)
  • It is easy to forget about the scalability flags entirely

Another problematic example is checking the size of a dimension:

  • The common case: check for a fixed-width 1:
    • vType.getDimSize(0) == 1 && !vType.getScalableDims()[0]
    • The code is quite verbose: it checks for “scalability” even though we’re looking for a plain fixed-width “1”
  • The uncommon case: check for either a fixed-width 1 or a scalable [1]:
    • vType.getDimSize(0) == 1
    • Looks normal, but matches both 1 and [1]
    • The semantics of 1 and [1] are different, so this is rarely what’s intended
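To make the ambiguity concrete, here is a minimal, self-contained C++ sketch. It does not use MLIR: ToyVectorType is a hypothetical stand-in that mimics VectorType's storage (a shape plus a parallel list of scalability flags), and the helper names naiveIsUnitDim and isFixedUnitDim are made up for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for VectorType's storage: a shape plus a parallel
// list of scalability flags (e.g. vector<[1]x4xf32> -> {1, 4}, {true, false}).
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;

  int64_t getDimSize(unsigned idx) const { return shape[idx]; }
};

// The "obvious" check: matches both a fixed 1 and a scalable [1].
bool naiveIsUnitDim(const ToyVectorType &type, unsigned idx) {
  return type.getDimSize(idx) == 1;
}

// What is usually intended: a fixed-width 1 only.
bool isFixedUnitDim(const ToyVectorType &type, unsigned idx) {
  return type.getDimSize(idx) == 1 && !type.scalableDims[idx];
}
```

With a fixed {1, 4} and a scalable {[1], 4}, naiveIsUnitDim returns true for dim 0 of both, while isFixedUnitDim only accepts the fixed one.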

Proposal

We would like to add new scalability-safe accessors and iterators to VectorType, along with new builders that can make use of these:

/// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx);

/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims();
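A minimal sketch of what a VectorDim could look like, assuming it stores a size plus a scalability flag. Only getFixed appears in this RFC; getScalable, isFixed, isScalable and getMinSize are hypothetical names chosen for illustration, and this is not the RFC's actual implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical sketch of VectorDim: a size plus a scalability flag.
// For a scalable dim, `size` is the known minimum (runtime size is
// size * vscale). Not the RFC's actual implementation.
class VectorDim {
public:
  static VectorDim getFixed(int64_t size) { return VectorDim(size, false); }
  static VectorDim getScalable(int64_t minSize) {
    return VectorDim(minSize, true);
  }

  bool isFixed() const { return !scalable; }
  bool isScalable() const { return scalable; }
  int64_t getMinSize() const { return size; }

  // Equality compares both the size and the kind of quantity, so a fixed 1
  // never equals a scalable [1].
  friend bool operator==(VectorDim lhs, VectorDim rhs) {
    return lhs.size == rhs.size && lhs.scalable == rhs.scalable;
  }
  friend bool operator!=(VectorDim lhs, VectorDim rhs) { return !(lhs == rhs); }

private:
  // Private two-argument constructor: there is no implicit conversion from
  // an integer, so `dim == 1` does not compile.
  VectorDim(int64_t size, bool scalable) : size(size), scalable(scalable) {}

  int64_t size;
  bool scalable;
};

static_assert(!std::is_convertible<int64_t, VectorDim>::value,
              "VectorDim must not be implicitly constructible from an integer");
```

With this design, a comparison must name the kind of quantity it expects, e.g. dim == VectorDim::getFixed(1).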

With the proposed changes, the examples above would look like this:

// Drop leading n dims:
VectorType::get(vType.getElementType(), vType.getDims().dropFront(n));

// Check if the leading dim is a _fixed_ unit dim:
vType.getDims()[0] == VectorDim::getFixed(1);

In both cases:

  • The intent is clear
  • The code is more concise
  • Scalability is checked and preserved (enforced by the new APIs)

The new accessors are backed by two new classes:

  • VectorDim represents a single dimension of VectorType. It can be a fixed or scalable quantity. It cannot be implicitly converted to/from an integer, so you must specify the kind of quantity you expect in comparisons

  • VectorDims represents a non-owning list of vector dimensions, backed by separate size and scalability lists (matching the storage of VectorType). This class has an iterator and a few common helper methods (similar to those of ArrayRef)
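A minimal sketch of a VectorDims-style view, assuming it wraps two parallel arrays (sizes and scalability flags) the way ArrayRef wraps one. The class layout, the pointer-based constructor and the dropBack helper are hypothetical; dropFront mirrors the name used in the examples above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>

// Hypothetical compact VectorDim: a size plus a scalability flag.
class VectorDim {
public:
  static VectorDim get(int64_t size, bool scalable) {
    return VectorDim(size, scalable);
  }
  static VectorDim getFixed(int64_t size) { return VectorDim(size, false); }
  static VectorDim getScalable(int64_t size) { return VectorDim(size, true); }
  bool isScalable() const { return scalable; }
  friend bool operator==(VectorDim lhs, VectorDim rhs) {
    return lhs.size == rhs.size && lhs.scalable == rhs.scalable;
  }
  friend bool operator!=(VectorDim lhs, VectorDim rhs) { return !(lhs == rhs); }

private:
  VectorDim(int64_t size, bool scalable) : size(size), scalable(scalable) {}
  int64_t size;
  bool scalable;
};

// Hypothetical non-owning view over two parallel arrays (sizes and
// scalability flags), mirroring how VectorType stores its shape.
class VectorDims {
public:
  class iterator {
  public:
    using iterator_category = std::input_iterator_tag;
    using value_type = VectorDim;
    using difference_type = std::ptrdiff_t;
    using pointer = void;
    using reference = VectorDim;

    iterator(const int64_t *sizes, const bool *flags)
        : sizes(sizes), flags(flags) {}
    // Dereferencing builds a VectorDim by value from the two arrays.
    VectorDim operator*() const { return VectorDim::get(*sizes, *flags); }
    iterator &operator++() {
      ++sizes;
      ++flags;
      return *this;
    }
    bool operator==(const iterator &other) const { return sizes == other.sizes; }
    bool operator!=(const iterator &other) const { return !(*this == other); }

  private:
    const int64_t *sizes;
    const bool *flags;
  };

  VectorDims() = default;
  VectorDims(const int64_t *sizes, const bool *flags, size_t count)
      : sizes(sizes), flags(flags), count(count) {}

  size_t size() const { return count; }
  bool empty() const { return count == 0; }
  iterator begin() const { return iterator(sizes, flags); }
  iterator end() const { return iterator(sizes + count, flags + count); }

  VectorDim operator[](size_t idx) const {
    assert(idx < count && "index out of bounds");
    return VectorDim::get(sizes[idx], flags[idx]);
  }

  // Sizes and flags are always sliced together, so scalability cannot be lost.
  VectorDims dropFront(size_t n = 1) const {
    assert(n <= count && "dropping more dims than available");
    return VectorDims(sizes + n, flags + n, count - n);
  }
  VectorDims dropBack(size_t n = 1) const {
    assert(n <= count && "dropping more dims than available");
    return VectorDims(sizes, flags, count - n);
  }

private:
  const int64_t *sizes = nullptr;
  const bool *flags = nullptr;
  size_t count = 0;
};
```

As with ArrayRef, such a view would be non-owning, so it must not outlive the arrays it points into.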

Importantly, this proposal does not change the storage of VectorType.

Discussion

In our previous attempt to improve how scalable dimensions are treated, people suggested restricting the changes to VectorType.

One option here was separating VectorType from ShapedType (switching solely to the new APIs), however:

  • That would be a large non-incremental change
  • Likely to cause major issues downstream

So, as a compromise, this proposal only adds new APIs that make:

  • The task of improving support for scalable vectors much easier
  • The codebase more concise

Also, those less interested in “scalability” won’t be exposed to the implementation details anymore. There is one disadvantage of this approach:

  • It creates an alternative/additional mechanism to deal with dimensions of VectorType

While we are happy to refactor the codebase to keep it uniform, there won’t be a mechanism to prevent people from using the accessors that we have today and to mix both approaches.

Feedback

Your feedback is much appreciated :pray: A complete implementation is available here:

To see the potential impact that this will have on the code-base:

Thank you for taking a look :slight_smile:
Andrzej & Ben


cc @c-rhodes, @banach-space, @dcaballe

The issue with exotic additions is that they break the uniformity of the lowest-level APIs, and each type becomes its own special flower.

Why not improve the existing builder in llvm-project/mlir/include/mlir/IR/BuiltinTypes.h (llvm/llvm-project at commit cd138fddf1b52a43108376371ad1c38585aaa4e2) as needed, rather than adding exotic APIs?

VectorType::get(vType.getShape().drop_front(n),
                vType.getElementType(),
                vType.getScalableDims().drop_front(n));

becomes:

VectorType::Builder(vType).dropFront(n);
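The suggested dropFront would need to slice the shape and the scalability flags together. A standalone sketch of that idea, using a hypothetical ToyVectorType and ToyVectorTypeBuilder rather than MLIR's actual VectorType::Builder (dropFront there is the reply's proposal, not an API this sketch assumes exists):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for VectorType's storage: a shape plus a parallel
// list of scalability flags.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
};

// Hypothetical builder in the spirit of VectorType::Builder: every edit
// updates the shape and the scalability flags together.
class ToyVectorTypeBuilder {
public:
  explicit ToyVectorTypeBuilder(ToyVectorType type) : type(std::move(type)) {}

  // Drops the leading n dims from both lists.
  ToyVectorTypeBuilder &dropFront(std::size_t n) {
    assert(n <= type.shape.size() && "dropping more dims than available");
    auto count = static_cast<std::ptrdiff_t>(n);
    type.shape.erase(type.shape.begin(), type.shape.begin() + count);
    type.scalableDims.erase(type.scalableDims.begin(),
                            type.scalableDims.begin() + count);
    return *this;
  }

  ToyVectorType build() const { return type; }

private:
  ToyVectorType type;
};
```

This keeps the two lists in sync for this one operation, but it does not by itself help with comparing or matching individual dims.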

Because it’s not just that specific case that we want to fix. That’s one example, but we’d like something more general:

  • We’d like to manipulate and iterate over vector dimensions
  • We’d like to safely match and compare vector dimensions
  • We’d like to pass around references to vector dimensions
  • Basically, anything you could do if a vector dimension were a first-class concept

Note that this already exists in LLVM proper, where llvm::ElementCount is used to represent fixed or scalable quantities.

The problem with the current APIs is that they pretend VectorType is not a special flower, which, unfortunately, it is, because scalability is not a concept that ShapedType supports.

This is the problem with scalable vectors, and I’ve always been concerned about it (going back to when they were proposed for LLVM) because of what you’re describing here: it breaks uniformity at the lowest level.
But that’s almost intentional (even though, IIRC, scalable vectors initially came with a better promise in this respect). This ship has sailed, and vector is no longer really a “ShapedType”: the “shape of a vector” isn’t something you can reason about in any “normal” way. It is its own flower as soon as you accept scalable vectors; the similarity is only superficial…


Yes, we can do that, but there are other issues with no easy fix. This is one of them:

vType.getDimSize(0) == 1

To me this is ambiguous. Is it meant to match:

  • fixed width “1”, or
  • fixed width and scalable “1”?

Most of the time we mean the former (the semantics of “1” and “[1]” are very different), but that should be written like this instead:

vType.getDimSize(0) == 1 && !vType.getScalableDims()[0]

That’s something that’s easy to miss and is rather counter-intuitive - why would one be checking the scalability flags when searching for fixed width “1”?

I think that the wrappers proposed here would also help with discussions like this one (CC @qed):

Indeed, I think that that’s key here. The point of this RFC is to make the APIs more explicit about this and to minimise any potential surprises.


Thanks, it seems I over-indexed on the first motivating example.
The proposal makes sense to me.


For me, if mapped to ShapedType (which is just a shape and an element type, where the shape is “simple”), this isn’t ambiguous. For [1], I’d argue that querying the dim size should return ? (unknown/dynamic) rather than 1. A compile-time query should not return a static value unless the dimension is static, and here it isn’t.

Now you’ve overlaid other information on top, so for this derived type there is a different question you can ask: whether it has a static base index, which would be 1 here.
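A sketch of the semantics suggested above, on a toy type rather than MLIR's ShapedType. The names kDynamicSize, getStaticDimSize and getBaseDimSize are hypothetical: the compile-time size query reports a dynamic marker for scalable dims, and a separate query exposes the static base value.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for "?" (unknown at compile time), in the spirit of
// ShapedType's dynamic-size marker.
constexpr int64_t kDynamicSize = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for VectorType: shape plus scalability flags.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;

  // Compile-time size query: only fixed dims report a static value, so a
  // scalable [1] reports kDynamicSize rather than 1.
  int64_t getStaticDimSize(unsigned idx) const {
    return scalableDims[idx] ? kDynamicSize : shape[idx];
  }

  // Separate query for the static base of a dim (1 for both 1 and [1]).
  int64_t getBaseDimSize(unsigned idx) const { return shape[idx]; }
};
```

Note that with this scheme getStaticDimSize alone cannot tell [1] apart from a genuinely dynamic ?, which is the distinction the replies below argue matters.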

In general, I agree with you, but vscale (i.e. “scalability”) does carry some extra meanings (and puts constraints on the semantics of vector dims) and we’d like to make that explicit.

Ha! But we are in “Vector” land, where we don’t really have “unknown/dynamic” sizes, and that’s one big difference here: we are using the ShapedType API for something that carries additional meaning unavailable at the ShapedType level.

This is a bit more nuanced, though, and brings us back to the meaning of vscale. Yes, it’s not known at compile time, but we do have an SSA value that we can use to represent it. So it’s not really “unknown”: there are multiple assumptions that we can make about it, and we do.
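A minimal sketch of why a scalable dim is not simply “unknown”, assuming the usual definition: a scalable dim [n] holds n × vscale elements at runtime, where vscale is a positive integer that stays constant during execution (in MLIR it can be materialised as an SSA value, e.g. with vector.vscale). Dim and runtimeSize are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical compact dim: a base size plus a scalability flag.
struct Dim {
  int64_t base;
  bool scalable;
};

// Runtime number of elements along a dim, given the runtime value of vscale.
// A fixed dim ignores vscale; a scalable dim is always a multiple of it.
int64_t runtimeSize(Dim dim, int64_t vscale) {
  assert(vscale >= 1 && "vscale is a positive integer");
  return dim.scalable ? dim.base * vscale : dim.base;
}
```

So [1] is not an arbitrary ?: it is exactly vscale elements, which is the kind of assumption the reply refers to.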

And just to clarify - we do have all the APIs that we need for supporting scalable vectors available today. This RFC is an attempt to make things more explicit and to avoid confusion. In particular, ? != [1].

-Andrzej

I still think this RFC makes correctly handling scalability, and updating existing code to support it, much easier.

For example, I recently needed to get PolynomialApproximation working for scalable vectors. Currently, all the rewrites work by passing around the vector shape as an ArrayRef<int64_t>, which drops the scalability of every dimension. Getting things working for scalable vectors with the VectorDims from the RFC was pretty much just a find-and-replace of ArrayRef<int64_t> → VectorDims, which then transparently preserves scalability.

Current code:

// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
static ArrayRef<int64_t> vectorShape(Type type) {
  auto vectorType = dyn_cast<VectorType>(type);
  return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
}

static ArrayRef<int64_t> vectorShape(Value value) {
  return vectorShape(value.getType());
}

// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, ArrayRef<int64_t> shape) {
  assert(!isa<VectorType>(type) && "must be scalar type");
  return !shape.empty() ? VectorType::get(shape, type) : type;
}

// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                       ArrayRef<int64_t> shape) {
  assert(!isa<VectorType>(value.getType()) && "must be scalar value");
  auto type = broadcast(value.getType(), shape);
  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
}

Updated code:

// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
static VectorDims vectorShape(Type type) {
  auto vectorType = dyn_cast<VectorType>(type);
  return vectorType ? vectorType.getDims() : VectorDims();
}

static VectorDims vectorShape(Value value) {
  return vectorShape(value.getType());
}

// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, VectorDims shape) {
  assert(!isa<VectorType>(type) && "must be scalar type");
  return !shape.empty() ? VectorType::get(type, shape) : type;
}

// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                       VectorDims shape) {
  assert(!isa<VectorType>(value.getType()) && "must be scalar value");
  auto type = broadcast(value.getType(), shape);
  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
}
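A toy illustration of the difference the find-and-replace makes. Rebuilding a type from sizes alone yields fixed-width dims (as in the current broadcast helper, where VectorType::get(shape, type) receives no scalability flags), whereas passing sizes and flags together keeps them. All types and functions here are hypothetical stand-ins, not MLIR code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for VectorType: shape plus scalability flags.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
};

// Before: only the sizes are passed around (like ArrayRef<int64_t>), so a
// rebuilt type silently gets all-fixed dims.
ToyVectorType rebuildFromShape(const std::vector<int64_t> &shape) {
  return ToyVectorType{shape, std::vector<bool>(shape.size(), false)};
}

// After: sizes and flags travel together (like the proposed VectorDims),
// so the rebuilt type keeps its scalability.
struct ToyDims {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

ToyDims dimsOf(const ToyVectorType &type) {
  return ToyDims{type.shape, type.scalableDims};
}

ToyVectorType rebuildFromDims(const ToyDims &dims) {
  return ToyVectorType{dims.sizes, dims.scalable};
}
```

For a scalable vector<[4]xf32>-like input, the shape-only round trip loses the flag, while the dims round trip preserves it.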