Summary
We propose to add support for sub-byte types (e.g., i4, i3, i2) packed into byte-aligned values. This will be implemented through a new dialect, `packed`, that defines parametric packed types, e.g., `packed.bits<8:2xi4>`, and a small number of operations over these types, initially focusing on packing, unpacking, casting, and insertion/extraction for memory interaction.
Background
Research has been pushing on sub-byte quantization for quite some time. Now, with the recent advances in large language models (LLMs), there is an even greater need to support quantization/weight compression through small integer types, e.g., int4 and int3 in GPTQ. These can reduce the size of LLMs and make them more efficient to run. Especially on memory-constrained platforms like mobile/edge devices, supporting i4/i3/i2 types determines whether we can run these models at all.
The difficulty with these sub-byte types is that memory is typically byte-addressable and existing hardware does not expose instructions to manipulate such types directly. In practice, this means we have to pack sub-byte integers into larger types supported by the platform of choice. For example, we can densely pack two i4 values into an i8 for storage, and perform calculations by first unpacking each sub-value into an i8 and later packing both results back into a single i8 with some masking/shifting. While this solves the storage problem, it does not yield the best arithmetic performance without a dedicated hardware compute unit. Other packing schemes, e.g., ULPPack, give better arithmetic performance, but at the cost of added complexity: they are not densely packed, e.g., three i3 values in a single i16.
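As a concrete illustration of the dense scheme above, here is a minimal Python sketch of packing two i4 values into a byte and unpacking them back into wider integers via masking, shifting, and sign extension. The helper names are ours, for illustration only; they are not part of the proposal.

```python
def pack_i4x2(lo: int, hi: int) -> int:
    """Densely pack two signed 4-bit values into one byte:
    lo at bits 0-3, hi at bits 4-7."""
    assert -8 <= lo <= 7 and -8 <= hi <= 7
    return ((hi & 0xF) << 4) | (lo & 0xF)

def unpack_i4x2(byte: int) -> tuple[int, int]:
    """Unpack a byte into two signed 4-bit values,
    sign-extending each nibble to a full Python int (i.e., i8-like width)."""
    def sext4(v: int) -> int:
        # If bit 3 is set, the nibble is negative in two's complement.
        return v - 16 if v & 0x8 else v
    return sext4(byte & 0xF), sext4((byte >> 4) & 0xF)

packed = pack_i4x2(-3, 5)     # one byte holding both sub-values
lo, hi = unpack_i4x2(packed)  # widen for ordinary arithmetic
assert (lo, hi) == (-3, 5)
```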
Proposal
We propose to solve the problem of sub-byte types by introducing a new parametric type containing multiple sub-values: `packed.bits<N-Bits:KxIntOrFloat:[PackingScheme]>`. The key property is that such packed types are byte-addressable; they hold the necessary information in an integrated way and do not require premature unpacking or sub-byte calculations until deemed suitable. `packed.bits` has elementwise semantics. Unlike shaped types, all sub-values together form a single "logical" value / register; for example, inserting a poisoned sub-value would poison the whole packed value.
Ops:
- `packed.pack`: creates a packed value out of legal/wide scalar components
- `packed.unpack`: extracts all sub-values as legal/wide scalar components
- `packed.cast`: converts between packed values and their storage representation
- `packed.insert`: returns a new packed value with a single sub-value inserted
- `packed.extract`: extracts a single sub-value as a legal/wide scalar
Examples:
// Types:
// packed.bits<8:2xi4:[0,4]> // Explicit form, two sub-values at bit-offsets 0 and 4.
// packed.bits<8:2xi4> // Same as above, dense packing is the default.
// packed.bits<16:3xi4:[0,8,12]> // Sub-values start at bit-offsets 0, 8, and 12.
// packed.bits<32:2xf16> // Also floats but not index.
// vector<4xpacked.bits<8:2xi4>>
// tensor<1920x1080x3xpacked.bits<8:4xi2>>
// memref<?x?xpacked.bits<8:2xi4>>
// Ops:
func.func @scalar(%arg: packed.bits<8:2xi4>) -> () {
  %a = packed.unpack %arg : packed.bits<8:2xi4> to vector<2xi8>
  %b:2 = packed.unpack %arg : packed.bits<8:2xi4> to (i8, i8)
  %c = packed.pack %a : vector<2xi8> to packed.bits<8:2xi4>
  %d = packed.pack %b#0, %b#1 : i8 to packed.bits<8:2xi4>
  %e = packed.cast %c : packed.bits<8:2xi4> to i8
  %f = packed.cast %e : i8 to packed.bits<8:2xi4>
  %y = packed.extract %d[1] : packed.bits<8:2xi4>, i8
  %x = packed.insert %y into %d[0] : i8, packed.bits<8:2xi4>
  return
}
func.func @elementtype(%arg: memref<?x?xpacked.bits<8:2xi4>>) -> () {
  %0 = vector.transfer_read %arg ... : memref<?x?xpacked.bits<8:2xi4>>,
         vector<4xpacked.bits<8:2xi4>>
  %1 = vector.transfer_read %arg ... : memref<?x?xpacked.bits<8:2xi4>>,
         vector<4xpacked.bits<8:2xi4>>
  %a = packed.unpack %0 : vector<4xpacked.bits<8:2xi4>> to vector<4x2xi8>
  %b:2 = packed.unpack %1 : vector<4xpacked.bits<8:2xi4>> to (vector<4xi8>, vector<4xi8>)
  %c = packed.pack %b#0, %b#1 : vector<4xi8> to vector<4xpacked.bits<8:2xi4>>
  %d = packed.cast %c : vector<4xpacked.bits<8:2xi4>> to vector<4xi8>
  %e = packed.cast %d : vector<4xi8> to vector<4xpacked.bits<8:2xi4>>
  %x = packed.insert %b#0 into %c[0] : vector<4xi8>, vector<4xpacked.bits<8:2xi4>>
  %y = packed.extract %x[1] : vector<4xpacked.bits<8:2xi4>>, vector<4xi8>
  return
}
These packed operations can be lowered to arithmetic over the storage type through masking/shifting. We do not propose packed constants, as these can be constructed through `packed.pack` or `packed.cast` of integer constants.
In the future, we may envision more operations like `packed.repack` to convert between different packed representations, and dedicated ops for efficient arithmetic over packed formats, e.g., `packed.dotprod` and `packed.reduce`. Those can be added when the need arises.
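To make the lowering concrete, here is a hedged Python sketch of how `packed.extract` and `packed.insert` could reduce to shifts and masks over the storage integer. The bit offsets follow the `packed.bits<16:3xi4:[0,8,12]>` example above; the helper names and signatures are illustrative only, not a proposed API.

```python
def extract(storage: int, offset: int, width: int) -> int:
    """Sketch of packed.extract lowering: shift the slot down,
    mask it off, and sign-extend the sub-value."""
    mask = (1 << width) - 1
    v = (storage >> offset) & mask
    # Sign-extend if the top bit of the sub-value is set.
    return v - (1 << width) if v & (1 << (width - 1)) else v

def insert(storage: int, offset: int, width: int, value: int) -> int:
    """Sketch of packed.insert lowering: clear the target slot,
    then OR in the truncated value."""
    mask = (1 << width) - 1
    cleared = storage & ~(mask << offset)
    return cleared | ((value & mask) << offset)

# Model packed.bits<16:3xi4:[0,8,12]>: three i4 sub-values
# at bit-offsets 0, 8, and 12 within a 16-bit storage integer.
s = 0
for off, val in [(0, -1), (8, 3), (12, -4)]:
    s = insert(s, off, 4, val)
assert [extract(s, off, 4) for off in (0, 8, 12)] == [-1, 3, -4]
```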
Interactions with Other Dialects
This dialect is meant to sit at the root level of the dialect tree: other, higher-level dialects can depend on it. We anticipate that `packed.bits` types compose nicely with `vector`, `tensor`, and `memref` types by being an allowed element type there. As a by-product, this gives us ways to improve `memref`'s memory storage story regarding sub-byte types.
Initially, we will support computation with regular arith/math operations over the storage type by "entering" and "exiting" the `packed.bits` world via dedicated construction/destruction ops: `packed.pack` and `packed.unpack`. When chained, the resulting pack/unpack operations should cancel out. This enables tightly controlled lowering procedures.
In the future, we could define the semantics for elementwise arith/math operations over packed types via elementwise ops over the sub-value types.
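As an illustration of what such elementwise semantics could mean, here is a hypothetical Python sketch of an elementwise i4 addition over a `packed.bits<8:2xi4>`-style byte, implemented as unpack, add, and repack. The function name and the choice of 4-bit wraparound overflow are our assumptions, not part of the proposal.

```python
def packed_add_i4x2(a: int, b: int) -> int:
    """Hypothetical elementwise i4 add over two bytes, each holding
    two densely packed i4 sub-values (at bit-offsets 0 and 4):
    unpack each nibble, add with 4-bit wraparound, repack."""
    def sext4(v: int) -> int:
        return v - 16 if v & 0x8 else v

    out = 0
    for off in (0, 4):
        s = (sext4((a >> off) & 0xF) + sext4((b >> off) & 0xF)) & 0xF
        out |= s << off
    return out

# 0x23 holds (3, 2); 0xF1 holds (1, -1). Elementwise sum is (4, 1) = 0x14.
assert packed_add_i4x2(0x23, 0xF1) == 0x14
```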
Alternatives Considered
The two other most natural representations are using the storage type directly (e.g., i8) or vectors of sub-byte types (e.g., `vector<...x2xi4>` or `vector<...xvector<2xi4>>`). These two approaches come with the following drawbacks:
- Performing arithmetic directly over the storage integer types leads to code that is difficult to read and optimize. We lose the high-level information about the actual element types and their value bounds. In addition, our experience with the wide integer emulation pass (which goes in the opposite direction) suggests that there would be a lot of early bloat in the constants and arithmetic operations, both for packing/unpacking and for controlling overflow.
- Neither native integers nor vectors offer a solution for non-dense packing schemes as the memory accesses would not be byte-aligned.
- Even for densely-packed vectors, we would be adding to the already overloaded vector semantics: there would be no way to differentiate normal dimensions from dimensions used for bit packing. At the lowest level of lowering, vectors are typically one-dimensional and should correspond to individual hardware registers / vector lanes. Adding an extra dimension or level of nesting breaks such lowering paths, specifically, dimension mapping during tiling and vectorization.
Looking forward to your feedback.
-Jakub (@kuhar) and Lei (@antiagainst)