[RFC] Packing for sub-byte types

Summary

We propose to add support for sub-byte types (e.g., i4, i3, i2) packed into byte-aligned values. This will be implemented through a new dialect, packed, that defines (parametric) packed types, e.g., packed.bits<8:2xi4>, and a small number of operations over these types, initially focusing on packing, unpacking, casting, and insertion/extraction for memory interaction.

Background

Research has been pushing on sub-byte quantization for quite some time. Now, with the recent advances in large language models (LLMs), we see an even greater need to support quantization/weight compression through small integer types, e.g., int4 and int3 in GPTQ. These can help reduce the size of LLMs and make them more efficient to run. Especially on memory-constrained platforms like mobile/edge devices, supporting i4/i3/i2 types determines whether we can run these models at all.

The difficulty that comes with these sub-byte types is that memory is typically byte-addressable and that existing hardware does not expose instructions to manipulate these types directly. In practice, this means that we have to pack such integer types into larger types supported by the platform of choice. For example, we can densely pack two i4 values into an i8 for storage, and perform calculations by first unpacking each sub-value into an i8 and later packing both results back into a single i8 with some masking/shifting. While this solves the problem of storage, it does not lead to the best arithmetic performance without a dedicated hardware compute unit. There exist other packing schemes that give better arithmetic performance, e.g., ULPPack, but they come at the cost of not being densely packed, e.g., three i3s in a single i16.

Proposal

We propose to solve the problem of sub-byte types by introducing a new parametric type containing multiple sub-values: packed.bits<N-Bits:KxIntOrFloat:[PackingScheme]>. The key property is that such packed types are byte-addressable; they carry the necessary packing information in the type itself and do not require premature unpacking or sub-byte calculations until deemed suitable. packed.bits has elementwise semantics. Unlike shaped types such as vector, all sub-values together form a single ‘logical’ value / register; for example, inserting a poisoned sub-value would poison the whole packed value.

Ops:

  • packed.pack – creates a packed value out of legal/wide scalar components
  • packed.unpack – extracts all sub-values as legal/wide scalar components
  • packed.cast – converts between packed values and their storage representation
  • packed.insert – returns a new packed value with a single sub-value inserted
  • packed.extract – extracts a single sub-value as a legal/wide scalar

Examples:

// Types:

// packed.bits<8:2xi4:[0,4]>      // Explicit form, two sub-values at bit offsets 0 and 4.
// packed.bits<8:2xi4>            // Same as above; dense packing is the default.
// packed.bits<16:3xi4:[0,8,12]>  // Sub-values start at bit offsets 0, 8, and 12.
// packed.bits<32:2xf16>          // Floats are also supported, but index is not.

// vector<4xpacked.bits<8:2xi4>>
// tensor<1920x1080x3xpacked.bits<8:4xi2>>
// memref<?x?xpacked.bits<8:2xi4>>

// Ops:
func.func @scalar(%arg: packed.bits<8:2xi4>) -> () {
  %a   = packed.unpack %arg packed.bits<8:2xi4> to vector<2xi8>
  %b:2 = packed.unpack %arg packed.bits<8:2xi4> to (i8, i8)

  %c = packed.pack %a : vector<2xi8> to packed.bits<8:2xi4>
  %d = packed.pack %b#0, %b#1 : i8 to packed.bits<8:2xi4>

  %e = packed.cast %c : packed.bits<8:2xi4> to i8
  %f = packed.cast %e : i8 to packed.bits<8:2xi4>

  %x = packed.insert %b#0 into %d[0] : i8, packed.bits<8:2xi4>
  %y = packed.extract %x[1] : packed.bits<8:2xi4>, i8
  return
}

func.func @elementtype(%arg: memref<?x?xpacked.bits<8:2xi4>>) -> () {
  %0 = vector.transfer_read %arg ... : memref<?x?xpacked.bits<8:2xi4>>,
                                       vector<4xpacked.bits<8:2xi4>>
  %1 = vector.transfer_read %arg ... : memref<?x?xpacked.bits<8:2xi4>>,
                                       vector<4xpacked.bits<8:2xi4>>

  %a   = packed.unpack %0 vector<4xpacked.bits<8:2xi4>> to vector<4x2xi8>
  %b:2 = packed.unpack %1 vector<4xpacked.bits<8:2xi4>> to (vector<4xi8>, vector<4xi8>)

  %c = packed.pack %b#0, %b#1 : vector<4xi8> to vector<4xpacked.bits<8:2xi4>>

  %d = packed.cast %c : vector<4xpacked.bits<8:2xi4>> to vector<4xi8>
  %e = packed.cast %d : vector<4xi8> to vector<4xpacked.bits<8:2xi4>>

  %x = packed.insert %b#0 into %c[0] : vector<4xi8>, vector<4xpacked.bits<8:2xi4>>
  %y = packed.extract %x[1] : vector<4xpacked.bits<8:2xi4>>, vector<4xi8>
  return
}

These packed operations can be lowered to arithmetic over the storage type through masking/shifting. We do not propose packed constants, as these can be constructed through packed.pack or packed.cast of integer constants.
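For example, a constant packed value can be formed by casting an integer constant of the storage type (a minimal sketch in the proposed syntax; the constant is illustrative):

  // 0x21 = 0b0010'0001: with the default dense packing, sub-value 0 is 1
  // (bits 0-3) and sub-value 1 is 2 (bits 4-7).
  %cst = arith.constant 33 : i8
  %p   = packed.cast %cst : i8 to packed.bits<8:2xi4>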

In the future, we may envision more operations like packed.repack to convert between different packed representations, and dedicated ops for efficient arithmetic over packed formats, e.g., packed.dotprod, packed.reduce. Those can be added when the need arises.

Interactions with Other Dialects

This dialect is meant to be at a root level of the dialect tree – other higher level dialects can depend on it. We anticipate that packed.bits types compose nicely with vector, tensor, and memref types by being an allowed element type there. As a by-product, this gives us ways to improve memref’s memory storage story regarding sub-byte types.

Initially, we will support computation with regular arith/math operations over the storage type by “entering” and “exiting” the packed.bits world via dedicated construction/destruction ops: packed.pack and packed.unpack. When chained, the resulting pack/unpack operations should cancel out. This enables tightly controlled lowering procedures.
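To make the intended folding concrete, here is a minimal sketch in the proposed syntax (value names are illustrative): an unpack immediately fed into a matching pack, with the same packed type and sub-value order, cancels out.

  %u:2 = packed.unpack %p packed.bits<8:2xi4> to (i8, i8)
  %q   = packed.pack %u#0, %u#1 : i8 to packed.bits<8:2xi4>
  // Folds: %q is equivalent to %p, so no masking/shifting is ever materialized.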

In the future, we could define the semantics for elementwise arith/math operations over packed types via elementwise ops over the sub-value types.

Alternatives Considered

The two other most natural representations are using the storage type directly (e.g., i8) or vectors of sub-byte types (e.g., vector<...x2xi4> or vector<...xvector<2xi4>>). These two approaches come with the following drawbacks:

  • Performing arithmetic over plain integer types leads to code that is difficult to read and optimize. We lose the high-level information about the actual element types and their value bounds. In addition, the experience with the wide integer emulation pass (the opposite direction) suggests that there will be a lot of early bloat in the constants and arithmetic operations, both for packing/unpacking and for controlling overflows.
  • Neither native integers nor vectors offer a solution for non-dense packing schemes as the memory accesses would not be byte-aligned.
  • Even for densely-packed vectors, we would be adding to the already overloaded vector semantics: there would be no way to differentiate normal dimensions from dimensions used for bit packing. At the lowest level of lowering, vectors are typically one-dimensional and should correspond to individual hardware registers / vector lanes. Adding an extra dimension or level of nesting breaks such lowering paths, specifically, dimension mapping during tiling and vectorization.

Looking forward to your feedback.

-Jakub (@kuhar) and Lei (@antiagainst)

This seems quite unclear to me in terms of semantics: are the arith operations processing this as a single i8 or as a vector of 2xi4?

I’d like to understand this a bit better. For the first argument here, it seems to me that these types shouldn’t directly allow any arithmetic (this comes back to your intent with your question above).
That is, I would rather see an i8 converted to a vector<2xi4> before being processed with arithmetic operations.

I view this situation in a similar way as the LLVM move towards untyped pointers: we can deal with the memory storage using i8 and manage types in SSA values / registers. That is, your proposal does not make a clear case to me against using memref<?xi8> and something like memref.vload %ptr[%idx] : memref<?xi8> -> vector<2xi4>.

Your proposal being about “sub-byte types (e.g., i4, i3, i2) packed into byte-aligned values”, it’s not clear to me how you address “non-dense packing schemes” or “non byte-aligned” memory: can you elaborate here?

I don’t understand what you mean, can you elaborate with examples? I feel I’d need to see quite concrete examples to understand why using i8 and vectors wouldn’t do the job just as well here.

Ah, I forgot to remove this from the examples. Initially this won’t be supported – you will have to first unpack, do the arithmetic, and then pack again.

In a future iteration, I think it should be possible to give packed.bits<...:KxiN> elementwise semantics similar to vector<KxiN> (modulo poison, once that’s also fleshed out). But for now that’s outside of this proposal.

The main improvement over arithmetic over byte-aligned types is that if we want to perform arithmetic over unpacked values, we can do that without polluting the IR with shifts/masks/constants that create a lot of noise. For some operations, we may not want to unpack the values at all if we can perform arithmetic efficiently over a packed representation.

Eventually these packed values will be lowered down to integer/vector arithmetic over supported types, but this can be deferred until much later in the pipeline, when we know the target, instead of attempting abstraction raising to recover the intention behind the math. For example, in SPIR-V, i32 and vectors of up to 4xi32 are the only given, so we would prefer sparse arithmetic over packed values. Other targets support much more efficient dot product instructions, and we would want to lower to those much later in the codegen pipeline.

In this specific case of i4 the situation is not too bad with memref<?xi8>, as you can remember to align indices and load both halves at the same time. But for i3 it’s much less clear how to allow for different packing formats, e.g., packed.bits<64:21xi3:...>, packed.bits<32:10xi3:...>, and keep track of indexing, sub-value offsets, and buffer lengths.

I think @antiagainst had some more refined thoughts about this.

The type allows defining non-dense packing through the third parameter, e.g., packed.bits<16:3xi4:[0,8,12]>. The basic guarantee is that the type has the same alignment as the storage type (here 16 bits, i.e., 2 bytes) and the sub-value bit offsets are encoded in the type itself, so that we can pass them as memrefs without going into bit-addressable memory, fat pointers, etc.
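For example, such a non-dense layout remains a plain, byte-aligned element type for memref purposes (a minimal sketch in the proposed syntax; %m and %i are illustrative):

  // 3 x i4 spread over 16 bits of storage at bit offsets 0, 8, and 12.
  %p = memref.load %m[%i] : memref<?xpacked.bits<16:3xi4:[0,8,12]>>
  %v = packed.unpack %p packed.bits<16:3xi4:[0,8,12]> to vector<3xi8>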

One difficulty would be that by going to vectors of i8 you lose track of the logical number of elements. Suppose that at some point you end up with a value vector<16x16xi8> that you have to transpose; I’m not sure how to tell apart the case when the values are truly i8 and when it’s really vector<16x16xvector<2xi4>>.
When we lower to SPIR-V, we have to ensure that we end up with vectors of size at most 4 – this would force us to either lift this restriction and temporarily allow for other lengths until we make sure that we took advantage of all efficient ops that accept small bitwidths, or to give up arithmetic performance and expand each sub-value to i8/i16/i32 (as allowed by the environment). The alternative that packed.bits gives us is that we can consider it to be a scalar type and defer this decision.

I don’t follow: right now your proposal does not contain arithmetic over packed values, right? So you’re not addressing this.
And if you add this to your proposal, I would ask again how it compares to the vector dialect. That is, what is the semantic difference between packed.addi %a, %b : packed.bits<2xi4> and vector.addi %a, %b : vector<2xi4>?

It’s not clear to me why this is important at a high level: this is similar to the lack of precise memory layouts for tensor types, in that it can be deferred to the lowering.
Alternatively, it can also be stored in the load itself: packed.vload %ptr[%idx] : memref<?xi8> -> vector<2xi3> { packing = <0, 5> } (where “packing” stores the bit indices at which the i3 sub-values start).

This isn’t what I’m talking about: you always have vectors of the right element types; see the example I mentioned in my previous message: memref.vload %ptr[%idx] : memref<?xi8> -> vector<2xi4>. The memory does not carry the packing information, but the loaded value does.

I’ll try to illustrate the difference with an example:

  1. Manual packing/unpacking:
  %x = ... : i8
  %e0 = arith.andi %x, %mask : i8
  %e1 = arith.shrui %x, %cst4 : i8
  %v0 = vector.insert %e0, %cst0 [0] : i8 into vector<2xi8>
  %v1 = vector.insert %e1, %v0 [1] : i8 into vector<2xi8>
  // Do some arithmetic over %v1. The new value is %vv.
  %m = arith.andi %vv, %vmask : vector<2xi8> // limit to i4
  %r0 = vector.extract %m[0] : vector<2xi8>
  %r1 = vector.extract %m[1] : vector<2xi8>
  %r2 = arith.shli %r1, %cst4 : i8
  %r = arith.ori %r0, %r2 : i8
  2. With packed.bits:
  %x = ... : packed.bits<8:2xi4>
  %v1 = packed.unpack %x : packed.bits<8:2xi4> to vector<2xi8>
  // Do some arithmetic over %v1. The new value is %vv.
  %r = packed.pack %vv : vector<2xi8> to packed.bits<8:2xi4>

My point is that this form is much more concise because it does not involve expansion into this shifting/masking until we need it, and it makes it trivial to elide redundant pack/unpack pairs.

This would be much more verbose if the format were more complicated and contained more sub-values.

Arithmetic support is provided over unpacked values.

Let me adjust the example with something that looks like a fairer comparison to me; that’s how I would use vectors here:

  %ptr = ... : memref<?xi8>
  %v = memref.vload %ptr[%idx] : memref<?xi8> -> vector<2xi4>
  %v1 = vector.sext %v : vector<2xi4> -> vector<2xi8>
  // Do some arithmetic over %v1. The result is %r

vs

  %ptr = ... : memref<?x packed.bits<8:2xi4>>
  %v = memref.load %ptr[%idx] : memref<?x packed.bits<8:2xi4>> -> packed.bits<8:2xi4>
  // You cannot do arithmetic over %v, you need to convert to a vector somehow...
  %v1 = packed.unpack %v : packed.bits<8:2xi4> to vector<2xi8>
  // Do some arithmetic over %v1. The new value is %vv; we need a new conversion before storing.
  %r = packed.pack %vv : vector<2xi8> to packed.bits<8:2xi4>

WDYT?

I hope you don’t mind me butting in, but I also had this problem a while ago (I don’t have a proposal though).

First, I also support the notion that pack and unpack are implicitly performed by properly characterized store/load operations, as in Mehdi’s last example. This comes from my perspective of reconfigurable computing, where this is usually a responsibility pawned off onto the memory interface. And I don’t want to overconstrain that by requiring a wide aligned load followed by some bit-twiddling, for example.

So, for me, packed is actually an answer to a separate question, which is: “How to lower to a concrete byte-aligned memory layout for sub-byte types.” And that’d be only one way to achieve the results you were talking about. I’m not sure whether packed would also serve as the “abstract interface” in the “non-aligned store/load” scenario, if that makes any sense to you.

However, the above also includes the question of padding, e.g., i7 -zext-> i8 instead of densely packed i56 bits. I was pressed more by this issue ATM, so that’s what I mostly thought about. In that scenario, explicit add/remove of the padding bits and a padded wrapper type made the most sense. Partly because this allows a lowering pass to rewrite this decision transparently later, which would be very hard to do correctly with the store/loads.

Thanks for suggesting alternatives, @mehdi_amini and @KFAF.

I know that @antiagainst has some more nuanced arguments here, but for me, I don’t see how support at the level of loads solves all the issues I outlined. In a codegen pipeline, I imagine that memrefs would appear late, and at the very least we would want to be able to pack/unpack from tensors, vectors, and scalar function arguments or loop variables. Following your suggestions, I think we would have to stay at the level of sub-byte types in tensor/vectors until bufferization.

I wonder if we could have something in between these two extremes, e.g., leave the memory loosely typed at the level of byte-aligned element types and have pack/unpack as vector ops, similar to vector.gather/vector.scatter, that would support accessing vectors, tensors, memrefs, and scalars. In a sense, this could be seen as a more refined and slightly higher-level form of vector bitcasting: [RFC][arith] Should we support scalar <--> vector `arith.bitcast`s?

Apologies for responding after some time has passed. I was waiting to set aside some time to properly read through the RFC and then formulate a response. Thanks for the RFC; it does open up some interesting directions, but overall I think it might be doing too much at once, and a more gradual approach might be more practical here.

I have been discussing with @antiagainst and @kuhar offline. One thing that seems fairly straightforward is to have proper support for reading/writing packed i4 types. Basically, this means if there is a

%0 = memref.load [..] : memref<...xi4>

have a pass that converts this to

%1 = memref.load [..] : memref<...xi8>
%2 = .... // extract the relevant bits from `%1` 
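For illustration, the elided extraction could expand to something like the following (a hedged sketch; the index arithmetic for picking the right byte, and %isHigh, the i1 that says which nibble was requested, are assumed):

  %c4   = arith.constant 4 : i8
  %mask = arith.constant 15 : i8        // 0x0F
  %lo   = arith.andi %1, %mask : i8     // low nibble
  %shr  = arith.shrui %1, %c4 : i8
  %hi   = arith.andi %shr, %mask : i8   // high nibble
  %sel  = arith.select %isHigh, %hi, %lo : i8
  %2    = arith.trunci %sel : i8 to i4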

The SPIR-V backend already does this, since by default SPIR-V only supports i32 (stores are a bit more involved, but that’s a separate issue).

The memref.vload @mehdi_amini suggested is also a possibility to make this happen, but that seems more related to readability, because in the end it will bottom out to the same code.

So this supports all power-of-2 sub-byte widths.

I think what is in this RFC is something more general that goes past this simple step (includes padding, etc.) and provides a way to represent the packing from the user input… So a couple of clarifications:

  1. Where do these packed types get introduced?
  2. That will inform the next question of whether we need the packed.pack and packed.unpack operations. Specifically, the vector of packed type confuses me.

I am still trying to get all the moving pieces here, but the RFC seems pretty heavyweight, and I was trying to see if there is a way to make things less intrusive… For example, I am not sure the vector type needs to learn to support packed types.


Sorry for the late reply! I was traveling and then got sick. Thanks @mehdi_amini and @KFAF for the great thoughts! They help drive into the core of the issue and derive better solutions!

The proposal is written from the perspective of the packed sub-byte type, so it’s inevitably focusing more on the detailed semantics there. I think it’s useful to step back a bit and look at the global picture and overall compilation flow to see more of the motivation and how it all fits together.

One of the major benefits of having such explicit packed sub-byte types is that they allow a consistent view and treatment across different compiler components and pipelines, because types are able to persist through region boundaries.

To expand on this: the input sources are models authored in Python, quantized, and imported into MLIR. At this level we are not concerned with memory and storage; we see high-level ops on tensor<*xi4/i3/etc.>. Then, down the compilation pipeline, we need to decide how to organize the data in storage and whether/how to pack/pad those sub-byte elements.

For a holistic compiler like IREE, we need to partition the input model to separate the dense computation, which eventually becomes device kernels, from the rest, which eventually becomes host runtime logic. Such partitioning happens on high-level tensor ops, before bufferization, memrefs, and loops. Once partitioned, the host and device tensor ops go down different compiler pipelines, without knowing about each other. We need them to be consistent regarding how the data is packed to avoid subtle correctness issues. Actually, we may also want to generate code for different devices, like CPU/GPU, etc. For these different targets, we also want consistency, because the host passes data among them to schedule work on them together and harness the heterogeneity of SoCs, etc.

The above flow and architecture requires us to plan the sub-byte packing/padding ahead of graph partitioning, at the whole-tensor level. At this level, we don’t see memref or load/store operations yet; they appear down the dedicated CodeGen pipeline, where we only see a portion of the original graph. So we think we need a new packed.bits type, given that types are able to persist through region boundaries and interfaces.

If we instead turn all tensor<*xi4/i3/etc.> into tensor<*xi8/i16/etc.> before graph partitioning and rely on memref.vload to interpret them after bufferization and loop generation, we have a layering issue of being too late. We run the risk of different CodeGen pipelines treating them differently. We would need to duplicate the various packing/padding schemes (e.g., 2xi4 -> i8, 2xi4 -> 2xi8, etc.) for all host/device targets we care about. One also needs to be extra careful, e.g., to provide the same command-line options to all the targets we generate code for. To me, this is fragile and error-prone.

Keeping things in sync “without knowing each other” is usually solved by having ABI conventions for handling the interface between the split components.
But stepping back from the sub-byte packing: isn’t everything you’re describing applicable beyond this? That is, even with simple floats, it seems to me that separate compilation of tensor<231x433xf32> requires agreement on the layout/padding/strides to use, and this does not seem different from sub-byte to me.

I don’t think that is the suggestion… We can keep the tensor<*xi4/i3> as-is for both host and device… at some point both have to have a consistent lowering for these types… So the same packing/padding scheme will be used on the host, and all devices will see the same packing. We then need to handle the packed data type loads/stores in the backend…

Still unclear about where the packed types get introduced…
