RFC: Supporting Sub-Channel (or Blockwise) Quantization in MLIR
Introduction & Motivation
MLIR currently supports quantization on a per-tensor and per-channel basis, but sub-channel quantization is increasingly popular as a more efficient and accurate technique. This RFC proposes to introduce sub-channel quantization support in MLIR. Sub-channel quantization involves dividing channels into smaller blocks and quantizing each block independently, leading to potential improvements in model accuracy and performance, particularly for models with varying sensitivity to quantization error across different tensor regions.
The inclusion of sub-channel quantization support in MLIR is motivated by the following current practices and future needs:
- The AI Edge Quantizer team aims to express and emit this type for use in the TFL and StableHLO dialects. While the completion timeline is unclear, there's a need to start experimenting with this as soon as possible on models like Gemma.
- The best we can do without a dedicated sub-channel type is to emulate the sub-channel quantization math and pattern-match to recover the associated sub-channel information, but this approach does not scale well.
High-Level Proposal
The proposal introduces support for sub-channel quantized types in the MLIR ecosystem, ensuring backward compatibility with existing types. We propose to represent sub-channel quantization by building on the existing per-channel representation:
Consider a [6,4]-shaped tensor undergoing quantization along axis-0 and axis-1 with block sizes 1 and 2, respectively. As a result, the shape of the scales (or zero-points) will be [6,4]/[1,2] = [6,2], which essentially represents the number of blocks in each axis.
The proposed type expression for this quantization scheme is:
tensor<6x4x!quant.uniform<i8:f32:{0:1, 1:2}, {{s00:z00, s01:z01}, {s10:z10,s11:z11}, .., {s50:z50,s51:z51}}>>
With the following in-memory representation:
// class definition updates (only relevant fields are shown)
struct UniformQuantizedSubChannelType {
...
ArrayRef<int32_t> quantizationDimensions, blockSizes;
DenseElementsAttr scales, zeroPoints;
...
}
// Per the running example,
// quantizationDimensions : [0,1]
// blockSizes: [1,2]
// scales: [[s00, s01], [s10,s11], .., [s50,s51]] : tensor<6x2xf32>
// zeroPoints: [[z00, z01], [z10,z11], .., [z50,z51]] : tensor<6x2xi8>
Why use a shaped DenseElementsAttr for sub-channel quantization?
We prefer to use DenseElementsAttr for storing scales and zero-points as tensors with a defined shape, rather than simple 1-D vectors, as it simplifies the process of lowering quantize/dequantize operations during compilation. This is especially important when handling tensors with dynamic shapes or unranked tensors.
The Challenge with 1-D Vectors
Let's consider a tensor of shape [6, 8] that undergoes sub-channel quantization along axes 0 and 1 with block sizes of 2 and 4, respectively. To quantize or dequantize this tensor, we need to determine the appropriate scale and zero-point for each element.
- Tensor Approach: If the scales and zero-points are stored as a tensor, their shape directly corresponds to the blocks in the quantized tensor. We can easily find the correct scale for an element by dividing its index by the block sizes ([i,j]/[2,4]) and using the result to index into the scale tensor.
- 1-D Vector Approach: If scales are stored as a 1-D vector, we first need to derive the shape of the scale tensor using shape(tensor)/block_sizes. This works for statically shaped tensors but becomes problematic when dealing with dynamically shaped or unranked tensors, where the shape isn't known at compile time.

Consider a dynamically shaped tensor [?, ?] with scales {s0, s1, s2, s3, s4, s5}: the same [i,j]/block_sizes computation leaves us with an index that requires a dynamic reshape of the scales in order to be valid. The sketch below illustrates the index arithmetic.
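A minimal sketch of this index arithmetic in plain Python (helper and variable names are illustrative only, not part of the proposal):

block_sizes = [2, 4]

def scale_index(element_index, block_sizes):
    # With shaped scales, the block index is an element-wise floor division
    # of the element index by the block sizes.
    return [i // b for i, b in zip(element_index, block_sizes)]

# Element [5, 6] of the [6, 8] tensor falls in block [2, 1], i.e. scales[2][1].
print(scale_index([5, 6], block_sizes))  # [2, 1]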
"Appendix > Case 2" shows the quant.qcast op lowering for an unranked tensor and further highlights the importance of a known scale shape.
Type Constraints
We can formally enumerate all the constraints on this new type (a small verification sketch follows the notes below):

1. 1 <= block_sizes <= shape(tensor)
2. size(block_sizes) = rank(tensor)
3. shape(tensor) % block_sizes = 0
4. shape(scales) = shape(zero_points) = shape(tensor) / block_sizes
5. element_type(zero_points) = storage_type
6. element_type(scales) = expressed_type
Notes:
- In the absence of a block size for a specific axis i, we assume its value to be equal to dim(tensor, i). This implies that quantization is performed on a per-tensor basis along axis i, and the corresponding dimension of the scale (or zero-point) tensor has size 1, as given by dim(tensor, i)/block_sizes[i].
- Constraint (3) is not a fundamental limitation, but a design choice. In the event shape(tensor) % block_sizes != 0, the shape of the scales (or zero-points) could instead be ⌈shape(tensor)/block_sizes⌉. (Feedback is welcome!)
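To make the shape-related constraints concrete, here is a small verification sketch in plain Python (illustrative only; not the actual type verifier):

def verify_sub_channel_type(tensor_shape, block_sizes, scales_shape, zero_points_shape):
    # (2) One block size per tensor dimension.
    assert len(block_sizes) == len(tensor_shape)
    # (1) Every block size lies in [1, dim].
    assert all(1 <= b <= d for b, d in zip(block_sizes, tensor_shape))
    # (3) Every dimension is evenly divisible by its block size.
    assert all(d % b == 0 for d, b in zip(tensor_shape, block_sizes))
    # (4) Scales and zero points carry one entry per block.
    expected = [d // b for d, b in zip(tensor_shape, block_sizes)]
    assert list(scales_shape) == expected and list(zero_points_shape) == expected
    # (5) and (6), the element-type constraints, are omitted from this sketch.

# Running example: a [6, 4] tensor with block sizes [1, 2] needs [6, 2] scales.
verify_sub_channel_type([6, 4], [1, 2], [6, 2], [6, 2])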
Subchannel quantized types on unranked tensors
The proposed design scales well for unranked tensors too. This section explores how an unranked tensor uses the scale tensor at runtime and how that scale tensor can be represented at compile time using this proposal.
Runtime scale tensor for unranked tensors
Imagine an unranked tensor where we apply sub-channel quantization along dimensions m and n with block sizes of 2 and 3, respectively. At runtime, when the complete shape of the tensor is revealed, the scale tensor, let's call it V, would adhere to the following rules (a small sketch follows the list):
- V[m] = p: here, p is calculated by dividing the actual size of dimension m by its block size (which is 2).
- V[n] = q: similarly, q is derived by dividing the runtime size of dimension n by its block size (which is 3).
- V[i] = 1: for any other dimension i that is not m or n, the corresponding size in the scale tensor V is simply 1.
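A small sketch of these rules in plain Python (names are illustrative; the computation only becomes possible once the runtime shape is known):

def runtime_scale_shape(tensor_shape, quant_dims, block_sizes):
    # tensor_shape only becomes known at runtime for an unranked tensor.
    V = [1] * len(tensor_shape)
    for dim, block in zip(quant_dims, block_sizes):
        V[dim] = tensor_shape[dim] // block
    return V

# Quantizing dims m = 1 and n = 3 with block sizes 2 and 3:
print(runtime_scale_shape([5, 4, 7, 6, 2], quant_dims=[1, 3], block_sizes=[2, 3]))
# [1, 2, 1, 2, 1]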
Compile-time scale representation for unranked tensors
Now let us explain how we leverage these runtime properties of V to propose a compile-time representation without knowing the full shape.
In the current proposal, we express it as a tensor with a shape of [P x Q]. It's essentially a compact version of V where all the dimensions with size 1 are removed or "collapsed". P and Q are constants known at compile time. They act as constraints or upper bounds on the potential sizes of dimensions m and n, respectively, in the actual unranked tensor.
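In other words, the compile-time shape is V with the size-1 dimensions dropped. Continuing the sketch above (again, illustrative Python only):

def collapse_scale_shape(V, quant_dims):
    # Keep only the dimensions that actually carry per-block parameters.
    return [V[d] for d in quant_dims]

# V = [1, 2, 1, 2, 1] from the previous sketch collapses to [P, Q] = [2, 2].
print(collapse_scale_shape([1, 2, 1, 2, 1], quant_dims=[1, 3]))  # [2, 2]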
Alternate Design: Infer block size and quantization dimension
In an alternate design, we initially considered using the shape of the scales tensor to determine the block size, shape(data)/shape(scale) = block_size, with the quantization dimensions being the dimensions of the scale tensor whose size is not equal to 1. However, this proposal didn't scale well with dynamism, since the block size could not be statically known without enforcing rank constraints (see the sketch below). Therefore we opted to explicitly declare blockSizes and quantizationDimensions in the type.
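A sketch of why this inference breaks down with dynamic shapes (plain Python, with None marking a dynamic dimension):

def infer_block_sizes(data_shape, scale_shape):
    # Alternate design: block_size[i] = dim(data, i) / dim(scale, i).
    # None marks a dynamic dimension; its block size cannot be inferred statically.
    return [d // s if d is not None else None
            for d, s in zip(data_shape, scale_shape)]

print(infer_block_sizes([6, 4], [6, 2]))     # [1, 2]    -- statically known
print(infer_block_sizes([None, 4], [6, 2]))  # [None, 2] -- block size unknown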
Integration with Quant dialect ops
The introduction of the type extends the operational semantics of the quant.qcast and quant.dcast operations. The new type will be fully supported by quant.qcast and quant.dcast, as well as by the --lower-quant-ops pass.
The exact decompositions are covered in the "Appendix > Lowering qcast and dcast with subchannel quantization" section of this doc.
Comparison with other popular quantization schemes
This RFC proposes encoding sub-channel quantization parameters within the type system of the MLIR Quant dialect. This approach contrasts with other quantization schemes, such as the one employed by ONNX, where quantization parameters are explicitly defined within operations as operands. To better understand the design choices in this proposal, it's helpful to compare and contrast these two strategies:
ONNX: Op-level Quantization Parameters
ONNX represents quantization parameters directly within the operations themselves, making them an integral part of the model's graph structure. This offers flexibility in expressing diverse quantization techniques, as the parameters are not bound to the static type of the tensor. However, this can lead to more complex model representations due to the increased number of explicit parameter definitions.
MLIR: Type-level Quantization Parameters
In contrast, the MLIR Quant dialect, with this proposed sub-channel quantization type, aims to encode the quantization scheme directly within the type system. This method provides a more structured and concise way to represent quantization, potentially leading to simpler model representations and easier analysis. However, it's important to acknowledge that the statically typed approach of the MLIR Quant dialect may not be suitable for all quantization scenarios. Expressing highly customized quantization strategies, such as those requiring dynamic parameters or non-uniform quantization, could pose challenges within the current type system.
Despite these potential limitations, the proposed sub-channel quantization type is a valuable addition to the MLIR Quant dialect. It directly addresses the needs of its existing user base by providing a more efficient and accurate quantization method for many common use cases.
Maintenance
Building upon the previous quant RFC, alongside MathWorks and others, the StableHLO and ODML teams express a strong interest in the ongoing development and enhancement of the quantization dialect. We have actively participated in the review of that RFC's PR and are committed to:
- Proposing further enhancements to the verification and expressivity of the quant dialect.
- Engaging in ongoing discussions and maintenance efforts related to the quant dialect.
We believe that collaborative development and maintenance will ensure the continued robustness and relevance of the quant dialect within the MLIR ecosystem.
Looking Ahead: Optimizing Quantized Model Memory Footprint
A few other changes we will likely be proposing to the quant dialect in the future include:
- Allowing zero_point and scale bitwidths to be smaller than storage_type and expressed_type, respectively, which will further reduce model footprint. This would have implications for the bytecode format if the types differ, as this introduces new semantics, and the printer would need to be enhanced to explicitly state the element type used.
- Super-block quantization, allowing scales to be quantized tensors. This interoperates well with the proposed DenseElementsAttr format, but we are likely several months away from using this in practice, so we will not propose its addition yet.
Appendix
Lowering qcast and dcast with subchannel quantization
These operations will adhere to the pseudocode proposed in [RFC] Improvements in the "quant" dialect. More formally, the scales and zero points can be computed using the following pseudocode:
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  # Either per-axis or sub-channel: look up the zero point of the block each element falls into.
  q_block_sizes = block_sizes(quantized_type)
  for i in index_space(result_type):
    # i / q_block_sizes is an element-wise floor division mapping an element index to its block index.
    zero_points[i] = zero_points(quantized_type)[i / q_block_sizes]
  return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [], result_type)
  # Either per-axis or sub-channel: look up the scale of the block each element falls into.
  q_block_sizes = block_sizes(quantized_type)
  for i in index_space(result_type):
    scales[i] = scales(quantized_type)[i / q_block_sizes]
  return scales
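For a statically shaped tensor, the effect of this pseudocode can be reproduced with a short NumPy sketch (an illustration for the [6, 4] running example, not the actual lowering):

import numpy as np

def expand_scales(scales, block_sizes):
    # Repeat each per-block scale along every axis by its block size, so the
    # result holds one scale per tensor element (the effect of compute_scales).
    for axis, block in enumerate(block_sizes):
        scales = np.repeat(scales, block, axis=axis)
    return scales

scales = np.arange(1.0, 13.0).reshape(6, 2)  # [6, 2] scales of the [6, 4] example
per_element = expand_scales(scales, [1, 2])
assert per_element.shape == (6, 4)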
The folding and canonicalizing behaviors of the aforementioned operations will not change.
Operation lowering of quant.qcast, quant.dcast using --lower-quant-ops
The --lower-quant-ops pass treats quant.dcast and quant.qcast operations similarly. To simplify the explanation, the following examples concentrate solely on lowering quant.qcast. The pass distinguishes between two distinct cases, each leading to considerably different code structures. These cases are explored in detail below:
Case 1: Sub-channel quantization, with ranked input
In sub-channel quantization, different scale and zero-point pairs apply to different blocks of the input tensor along the dimensions designated as quantization dimensions. This is accomplished through the use of a linalg.generic operation whose affine map attributes are designed to extract and apply the correct scale and zero point to each element of the input tensor.
- Input mlir
!qalias = !quant.uniform<i8<-128:127>:f32:{1:2, 3:2}, {{{1.0:1, 2.0:2}},{{3.0:3, 4.0:4}}}>
func.func @f(%arg0: tensor<6x4x6x4xf32>) -> tensor<6x4x6x4x!qalias> {
%0 = "quant.qcast"(%arg0) : (tensor<6x4x6x4xf32>) -> tensor<6x4x6x4x!qalias>
return %0 : tensor<6x4x6x4x!qalias>
}
- Output mlir
!qalias = !quant.uniform<i8<-128:127>:f32:{1:2, 3:2}, {1.0:1, 2.0:2, 3.0:3, 4.0:4}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (0, d1 floordiv 2, 0, d3 floordiv 2)>
func.func @f(%arg0: tensor<6x4x6x4xf32>) -> tensor<6x4x6x4x!qalias> {
// Create tensors of scales and zero points
%cst = arith.constant dense<[[[[1.0, 2.0]], [[3.0, 4.0]]]]> : tensor<1x2x1x2xf32>
%cst_0 = arith.constant dense<[[[[1, 2]], [[3, 4]]]]> : tensor<1x2x1x2xi8>
// Traverse input, scales, zero-point, and output tensors
%0 = tensor.empty() : tensor<6x4x6x4xi8>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<6x4x6x4xf32>, tensor<1x2x1x2xf32>, tensor<1x2x1x2xi8>) outs(%0 : tensor<6x4x6x4xi8>) {
^bb0(%in: f32, %in_1: f32, %in_2: i8, %out: i8):
%3 = arith.divf %in, %in_1 : f32
%4 = arith.sitofp %in_2 : i8 to f32
%5 = arith.addf %3, %4 : f32
%6 = arith.fptosi %5 : f32 to i8
%c-128_i8 = arith.constant -128 : i8
%c127_i8 = arith.constant 127 : i8
%7 = arith.maxsi %6, %c-128_i8 : i8
%8 = arith.minsi %7, %c127_i8 : i8
linalg.yield %8 : i8
} -> tensor<6x4x6x4xi8>
%2 = quant.scast %1 : tensor<6x4x6x4xi8> to tensor<6x4x6x4x!qalias>
return %2 : tensor<6x4x6x4x!qalias>
}
Case 2: Sub-channel quantization, with unranked input
When dealing with unranked tensors in quantization, a common strategy, as presented in RFC: improvements in the quant dialect, involves reshaping the input into a tensor with a known rank allowing for easier alignment with the quantization parameters. In our case, we flatten the input into an N-dimensional tensor, where N is calculated as 2 * size(quantization_dimensions) + 1. This reshaping allows us to align the tensor dimensions with the quantization parameters, enabling efficient element-wise quantization using linalg.generic. Crucially, we construct affine maps that precisely navigate the reshaped tensor and select the correct scale and zero point for each sub-channel. Finally, we reshape the quantized tensor back to its original form.
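A small sketch (plain Python, for the two-quantization-dimension case used in this example) of how the reshaped dimension sizes are derived from the runtime shape; the MLIR below computes the same values with shape-dialect ops:

import math

def reshaped_dims(runtime_shape, quant_dims, block_sizes, scale_dims):
    # Collapse everything to the left of, between, and to the right of the two
    # quantization axes; the quantization axes themselves keep their sizes,
    # which are fixed by block_size * scale_dim.
    q0, q1 = quant_dims
    left = math.prod(runtime_shape[:q0])
    mid = math.prod(runtime_shape[q0 + 1:q1])
    right = math.prod(runtime_shape[q1 + 1:])
    return [left, block_sizes[0] * scale_dims[0], mid,
            block_sizes[1] * scale_dims[1], right]

# A runtime shape [3, 4, 5, 4, 7, 2] with quant dims 1 and 3 (block size 2 each)
# and collapsed scale shape [2, 2] reshapes to the 5-D shape [3, 4, 5, 4, 14].
print(reshaped_dims([3, 4, 5, 4, 7, 2], [1, 3], [2, 2], [2, 2]))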
- Input mlir
// Note that the scales/zero-points shape is defined, which essentially
// puts upper bounds on the potential sizes of tensor dimensions 1 and 3, as
// dim(arg0, 1) = block-size at axis-1 * dim(scales, 0) = 4, and
// dim(arg0, 3) = block-size at axis-3 * dim(scales, 1) = 4
!qalias = !quant.uniform<i8<-128:127>:f32:{1:2, 3:2}, {{1.0:1, 2.0:2}, {3.0:3, 4.0:4}}>
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
%0 = "quant.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!qalias>
return %0 : tensor<*x!qalias>
}
- Output mlir
!qalias = !quant.uniform<i8<-128:127>:f32:{1:2, 3:2}, {1.0:1, 2.0:2, 3.0:3, 4.0:4}>
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (0, d1 floordiv 2, 0, d3 floordiv 2, 0)>
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
// Save shape of original unranked tensor
%0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
// Calculate tensor size on the left of axis-1, mid of axis 1 and 3, and right of axis-3
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%left_of_axis_1, %ignore1 = "shape.split_at"(%0, %c1) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
%ignore2, %right_of_axis1 = "shape.split_at"(%0, %c2) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
%mid_of_axes_1_3, %ignore3 = "shape.split_at"(%right_of_axis1, %c1) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
%ignore4, %right_of_axis3 = "shape.split_at"(%0, %c4) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
%1 = shape.num_elements %left_of_axis_1 : tensor<?xindex> -> index
%2 = shape.num_elements %mid_of_axes_1_3 : tensor<?xindex> -> index
%3 = shape.num_elements %right_of_axis3 : tensor<?xindex> -> index
// Reshape input to 5D tensor
// Because the scale shape is fully defined, we can infer the dimension
// sizes of the tensor at the quantization dimensions, which simplifies the lowering.
%known_dim_size = arith.constant 4 : index
%from_elements = tensor.from_elements %1, %known_dim_size, %2, %known_dim_size, %3 : tensor<5xindex>
%reshape = tensor.reshape %arg0(%from_elements) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x4x?x4x?xf32>
// Scale and zero-point tensors
%c5 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
%scales_shape = arith.constant dense<[1, 2, 1, 2, 1]> : tensor<5xindex>
%scales = tensor.reshape %c5(%scales_shape)
: (tensor<2x2xf32>, tensor<5xindex>) -> tensor<1x2x1x2x1xf32>
%c6 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi8>
%zero-points = tensor.reshape %c6(%scales_shape)
: (tensor<2x2xi8>, tensor<5xindex>) -> tensor<1x2x1x2x1xi8>
// Initialize output tensor
%dim_0 = tensor.dim %reshape, %c0 : tensor<?x4x?x4x?xf32>
%dim_2 = tensor.dim %reshape, %c2 : tensor<?x4x?x4x?xf32>
%dim_4 = tensor.dim %reshape, %c4 : tensor<?x4x?x4x?xf32>
%4 = tensor.empty(%dim_0, %dim_2, %dim_4) : tensor<?x4x?x4x?xi8>
// Quantize
%5 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%reshape, %scales, %zero-points : tensor<?x4x?x4x?xf32>, tensor<1x2x1x2x1xf32>, tensor<1x2x1x2x1xi8>) outs(%4 : tensor<?x4x?x4x?xi8>) {
^bb0(%in: f32, %in_7: f32, %in_8: i8, %out: i8):
%6 = arith.divf %in, %in_7 : f32
%7 = arith.sitofp %in_8 : i8 to f32
%8 = arith.addf %6, %7 : f32
// Note: clamping to the storage range, shown in Case 1, is omitted here.
%9 = arith.fptosi %8 : f32 to i8
linalg.yield %9 : i8
} -> tensor<?x4x?x4x?xi8>
// Reshape output to original unranked tensor
%reshape_1 = tensor.reshape %5(%0) : (tensor<?x4x?x4x?xi8>, tensor<?xindex>) -> tensor<*xi8>
%6 = quant.scast %reshape_1 : tensor<*xi8> to tensor<*x!qalias>
return %6 : tensor<*x!qalias>
}
Feedback
Thank you for reading this RFC. All feedback is welcome, and we look forward to the discussion!