Introduction
This RFC proposes a refinement of the MLIR quant dialect and its tooling. The proposal includes:

- An improved specification for ops quant.dcast, quant.qcast, and quant.scast, with custom assembly format, clear semantics, strict syntactic verification, and canonicalization.
- New pass --lower-quant-ops to lower ops quant.qcast and quant.dcast to other core dialects, with support for statically shaped, dynamically shaped, and unranked tensors, as well as per-layer and per-channel quantization.
- New pass --strip-func-quant-types to eliminate quantization information from function headers.
- A workflow example illustrating how to integrate the proposed passes and canonicalization patterns in a transform pipeline that lowers a generic machine learning dialect with quantization support.
- A specific application in the domain of the TensorFlow Lite (tfl) dialect, based on the design of a fallback strategy to lower currently unsupported quantized types.
- A working implementation with comprehensive documentation and unit tests, available on this pull request.
The quant dialect
This section describes the quant dialect according to the features proposed in this document. The content of this section is used as the TableGen documentation for the dialect in the associated pull request. It includes the following material:

- An introduction to the quant dialect.
- A thorough description of the !quant.uniform type as currently supported by the proposed lowering passes, including per-layer and per-channel quantization variants.
- A section titled Per-axis Quantization Integrity describing the requirements for using the !quant.uniform type with per-axis quantization. This section is referenced in the documentation for each operation to clarify that these syntactic requirements apply to all operations using per-axis quantization.
Introduction
The quant dialect offers a framework for defining and manipulating
quantized values. Central to this framework is the !quant.uniform data
type, used to represent quantized values. This dialect also provides a
suite of operations to handle and convert quantized values between their
original floating-point representations and the optimized, lower bit-width
integer representations. The quant dialect is instrumented with
transformation passes to lower these operations into other core MLIR
dialects, while also flattening all occurrences of quantized types into
their integer counterparts.
The !quant.uniform type
The quantization process establishes a relationship between two types of
values: an expressed value and a stored value. The former refers to the
floating-point representation used in an original machine learning model,
capturing the precise numerical characteristics needed for accurate
calculations. The latter is the simplified integer representation that
resides in memory after quantization. The !quant.uniform data type
encodes the necessary information for (lossy) round-trip conversion between
an expressed and a stored value.
The !quant.uniform type has two variants: per-layer quantization and
per-channel (or per-axis) quantization. In per-layer quantization, the
quantization information affects an entire tensor uniformly. Conversely, in
per-channel quantization, the data type encodes the specific tensor axis
that serves as the channel and includes quantization information for each
individual channel within the tensor. Below are the specific syntactic and
semantic considerations for each modality.
Per-layer quantization
This is the general syntax of the !quant.uniform type representing
per-layer quantization:
`!quant.uniform` `<`
storedType (`<` storageMin `:` storageMax `>`)? `:`
expressedType `,`
scale (`:` zeroPoint)?
`>`
The type contains the following parameters:

- storedType: Integer type of the value stored in memory. This type conveys the bit width and signedness of the quantized stored value. Signed integer types are represented as 'i' bitWidth (e.g., i8), while unsigned integer types are represented as 'u' bitWidth (e.g., u8).
- storageMin, storageMax: Optional bounds for the stored value. If given, they must be within the range of storedType. If omitted, the entire range of storedType is allowed (e.g., -128...127 for i8 or 0...255 for u8).
- expressedType: Floating-point type of the value expressed by this quantized type.
- scale: Floating-point value of type expressedType used in the conversion between stored and expressed values.
- zeroPoint: Optional integer value of type storageType used in the conversion between stored and expressed values. If omitted, the default is 0.
Type conversions, rounding methods, and clamping actions aside, the
relationship between the expressed and stored values as encoded in a
quantized type is denoted by the following formula:
expressedValue = (storedValue - zeroPoint) * scale
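For example, in the type !quant.uniform<i8:f32, 2.0:10>, a stored value of 15 represents the expressed value (15 - 10) * 2.0 = 10.0.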
Operations quant.qcast (quantize cast) and quant.dcast (dequantize
cast) can be used to quantize a floating-point value and dequantize a
stored value, respectively. See the documentation for these operations for
details on how the quantization and dequantization processes are influenced
by the !quant.uniform type parameters.
Here are some examples of the use of !quant.uniform with per-layer
quantization:
// An 8-bit signed integer type is used to represent a 32-bit float. No
// clamping information is provided, so the full [-128, 127] range is
// available. The scale is set to 3.0, and the zero point takes its default
// 0 value.
!quant.uniform<i8:f32, 3.0>
// A 16-bit unsigned integer type is used to represent a 32-bit float. Out
// of the 16 bits, only 10 are used, according to the 0..1023 clamping
// range. The type sets the scale to 1.23 and the zero point to 512.
!quant.uniform<u16<0:1023>:f32, 1.23:512>
Per-channel quantization
The general syntax of the !quant.uniform type representing per-channel
quantization is as follows:
`!quant.uniform` `<`
storedType (`<` storageMin `:` storageMax `>`)? `:`
expressedType `:`
channelAxis `,`
`{`
scale0 (`:` zeroPoint0)? `,`
scale1 (`:` zeroPoint1)? ...
`}`
`>`
In this data type, there are multiple pairs of scale and zeroPoint
values. The channelAxis field represents the dimension of the containing
tensor acting as the channel. The size of the tensor along this dimension
is expected to match the number of provided scale-zeroPoint pairs, and
a given pair i applies to all elements in the tensor whose index along
dimension channelAxis is i. A quantized data type using per-channel
quantization is always expected to be contained within a tensor type.
Here are some examples:
// A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
// floats. Dimension 1 of the tensor acts as the channel dimension. Its
// size 3 matches the number of provided scale values. Tensor elements at
// positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
// 5.0, respectively.
tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
// A 2D dynamically sized tensor contains 16-bit unsigned integers
// representing 32-bit floats. Dimension 0 of the tensor acts as the
// channel dimension. Since 2 scale and zero-point values are provided, the
// size of dimension 0 is expected to be 2 at runtime. Tensor elements
// [0][*] use scale 2.0 and zero point 10, while elements [1][*] use scale
// 3.0 and zero point 20.
tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
Per-axis quantization integrity
When type !quant.uniform contains per-axis quantization information, the
rules below are enforced. These rules guarantee that the quantization
information encoded in the data type is applicable to the context in which
the quantized type is used. For efficiency, these rules are actively
enforced by the verifiers of quant dialect ops, but they must be
respected in any context in which the !quant.uniform data type is used,
such as the header of a func.func op, or the input of an arithmetic
operation.
- A quantized type with per-channel quantization information must be the
element type of a tensor container type, and may not occur directly as
the data type of a scalar value.
// Incorrect. Type !quant.uniform specifies per-channel quantization for a
// scalar type.
%result = quant.qcast %input : f32 to !quant.uniform<i8:f32:0, {1.0, 2.0}>
// Correct. Type `!quant.uniform` with per-channel quantization is wrapped
// in a `tensor` type.
%result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform<i8:f32:0, {1.0, 2.0}>>
- If the tensor containing the !quant.uniform type is ranked, its rank must be greater than the channel axis specified in the quantized type.
// Incorrect. The tensor rank (2) is not greater than the channel axis in
// the quantized type (3).
%result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:3, {1.0, 2.0}>>
// Correct. The tensor rank (2) is now greater than the channel axis (1):
%result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:1, {1.0, 2.0}>>
- If the axis dimension in the containing tensor is static, its size must
be equal to the number of scales present in the quantized type.
// Incorrect. The channel axis is 1, and the size of dimension 1 in the
// containing tensor is 3. However, there are 4 scale values present in the
// quantized type.
%result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>
// Correct. The quantized type now includes 3 scale values, matching the
// size of dimension 1 of the result tensor.
%result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
Operation quant.dcast
This section proposes a thorough specification for operation quant.dcast, with updated syntax, semantics, and canonicalization. This documentation is included in the TableGen operation definition on the associated pull request.
Syntax
The following custom assembly format is introduced:
operation ::= `quant.dcast` $input attr-dict `:` type($input) `to` type($result)
Semantics
Convert an input quantized value into its expressed floating-point value.
The dequantization process consists of the following steps:
def dequantize(quantizedValue: quantizedType) -> expressedType:
storedValue = reinterpretCast(quantizedValue, storageType)
storedValueFloat = convertIntToFloat(storedValue, expressedType)
zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
expressedValue = (storedValueFloat - zeroPointFloat) * scale
return expressedValue
Here, storageType, expressedType, scale, and zeroPoint are obtained
from the corresponding parameters encoded in quantizedType. For
per-channel quantization, the appropriate scale and zeroPoint values
are used for each tensor element computation according to the channel the
element belongs to.
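As a worked example, consider the type tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>> from the examples above. A stored value of 6 located at position [1][2][0] lies at index 2 along the channel axis, so it uses scale 5.0 and the default zero point 0, dequantizing to (6 - 0) * 5.0 = 30.0.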
Verification
The operation must satisfy the following syntactic constraints:
- Operand input must be a scalar or tensor of type !quant.uniform.
- The result type must be a floating-point scalar or tensor.
- The expressedType parameter of the !quant.uniform type of the input must match the floating-point type of the result.
- The operand and result types must be both scalars or both tensors. If tensors, they must be both ranked or both unranked. If ranked, both must have the same shape, including matching static and dynamic dimensions.
- If the operand uses per-channel quantization, its !quant.uniform type must adhere to the Per-axis quantization integrity guidelines.
Examples
- Dequantize a scalar quantized value
%result = quant.dcast %input : !quant.uniform<i8:f32, 2.0> to f32
- Dequantize a dynamically shaped tensor of quantized values
%result = quant.dcast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xf32>
- Dequantize an unranked tensor using per-axis quantization information
%result = quant.dcast %input : tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>> to tensor<*xf32>
Canonicalization
If the operand of a quant.dcast op is produced by a quant.qcast op and the operand type of quant.qcast matches the result type of quant.dcast, the value produced by quant.dcast is replaced with the operand of quant.qcast. In the likely case that quant.dcast is the only consumer of quant.qcast, this transformation renders quant.qcast dead code, allowing for further simplification.
- Input IR
%1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 2.0>
%2 = quant.dcast %1 : !quant.uniform<i8:f32, 2.0> to f32
- Output IR
%1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 2.0> // Possibly dead code
%2 = %0 // I.e., uses of %2 replaced with %0
Notice that the application of this canonicalization pattern may produce code that yields slightly different numerical results due to the lossy nature of the quant.qcast op. While the input IR converts an input floating-point value into a lower-precision integer representation just to immediately reverse this process, the output IR lets the original floating-point value flow through to the following consumers. It is assumed here that, in the context of a quantized model, such a slight numerical mismatch is tolerable, and that avoiding the rounding error incurred by the unnecessary quantization of an intermediate result is, in fact, desirable.
Operation quant.qcast
This section proposes a thorough specification for operation quant.qcast, with updated syntax, semantics, and canonicalization. This documentation is included in the TableGen operation definition on the associated pull request.
Syntax
The following custom assembly format is introduced:
operation ::= `quant.qcast` $input attr-dict `:` type($input) `to` type($result)
Semantics
Convert a floating-point value to a quantized type. The quantization
process consists of the following steps:
def quantize(expressedValue: expressedType) -> quantizedType:
zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
scaledValue = expressedValue / scale
storedValueFloat = scaledValue + zeroPointFloat
storedValue = convertFloatToInt(storedValueFloat, storageType)
storedValueClamped = clamp(storedValue, storageMin, storageMax)
quantizedValue = reinterpretCast(storedValueClamped, quantizedType)
return quantizedValue
Here, storageType, storageMin, storageMax, expressedType, scale,
and zeroPoint are obtained from the corresponding parameters encoded in
quantizedType. For per-channel quantization, the appropriate scale and
zeroPoint values are used for each tensor element computation according
to the channel the element belongs to.
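As a worked example of these steps, consider the type !quant.uniform<i8<-8:7>:f32, 2.0:10> and an expressed value of 4.0. Dividing by the scale yields 2.0, adding the zero point yields 12.0, and converting to an integer yields 12, which the final clamp to the [-8, 7] storage range reduces to a stored value of 7.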
Verification
The operation must satisfy the following syntactic constraints:
- Operand input must be a floating-point scalar or tensor.
- The result type must be a scalar or tensor of type !quant.uniform.
- The expressedType parameter in the !quant.uniform type of the result must match the floating-point type of the input.
- The operand and result types must be both scalars or both tensors. If tensors, they must be both ranked or both unranked. If ranked, both must have the same shape, including matching static and dynamic dimensions.
- If the result uses per-channel quantization, its !quant.uniform type must adhere to the Per-axis quantization integrity guidelines.
Examples
- Quantize a scalar floating-point value
%result = quant.qcast %input : f32 to !quant.uniform<i8:f32, 2.0>
- Quantize a dynamically shaped tensor of floating-point values
%result = quant.qcast %input : tensor<?xf32> to tensor<?x!quant.uniform<i8:f32, 2.0>>
- Quantize an unranked tensor using per-axis quantization information
%result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
Canonicalization
If the operand of a quant.qcast op is produced by a quant.dcast op and the operand type of quant.dcast matches the result type of quant.qcast, the SSA value produced by quant.qcast is replaced with the operand of quant.dcast. In the likely case that quant.qcast was the only consumer of quant.dcast, this transformation renders quant.dcast dead code, allowing for further simplification.
- Input IR
%1 = quant.dcast %0 : !quant.uniform<i8:f32, 2.0> to f32
%2 = quant.qcast %1 : f32 to !quant.uniform<i8:f32, 2.0>
- Output IR
%1 = quant.dcast %0 : !quant.uniform<i8:f32, 2.0> to f32 // Possibly dead code
%2 = %0 // I.e., uses of %2 replaced with %0
Operation quant.scast
This section proposes a thorough specification for operation quant.scast, with updated syntax, semantics, and canonicalization. This documentation is included in the TableGen operation definition on the associated pull request.
Syntax
The following custom assembly format is introduced:
operation ::= `quant.scast` $input attr-dict `:` type($input) `to` type($result)
Semantics
Convert a value from a quantized type to the corresponding signless integer
storage type, or vice versa. This conversion simply involves a
reinterpretation of the input bits and does not involve any data
manipulation.
This operation is semantically equivalent to builtin.unrealized_conversion_cast. It shares the following properties:
- It performs a bitwise cast with no data conversion.
- There is no rewrite pattern that converts this operation into lower-level dialects. It acts as a temporary type conversion of values handled by interacting transform passes (i.e., --lower-quant-ops and --strip-func-quant-types). After these passes complete, all occurrences of this operation should cancel each other out through the application of canonicalization patterns and dead code elimination.
Verification
The following syntactic restrictions must be met:
- Operand input must be a scalar or tensor of a signless integer or !quant.uniform type.
- The result must be a scalar or tensor of a signless integer or !quant.uniform type.
- If the operand is a scalar or tensor of integer type, the result must be a scalar or tensor of type !quant.uniform, and vice versa.
- The operand and result must be both scalars or both tensors. If tensors, they must be both ranked or both unranked. If ranked, both must have the same shape, including matching static and dynamic dimensions.
- The width of the storageType parameter of the quantized type of the operand or result must match the width of the signless integer type of the operand or result.
- If the operand or result uses per-channel quantization, its !quant.uniform type must adhere to the Per-axis quantization integrity guidelines.
Examples
- Cast a scalar quantized value into its storage type
%result = quant.scast %input : !quant.uniform<i8:f32, 2.0> to i8
- Cast a dynamically shaped tensor of quantized values into their storage type
%result = quant.scast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xi8>
- Cast an unranked tensor of signless integers into a quantized type using per-channel quantization
%result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
Canonicalization
In a sequence of 2 quant.scast ops where the second consumes the value produced by the first, the result of the second may be replaced with the operand of the first if their data types match. This conversion may render the first op dead.
- Input IR
%1 = quant.scast %0 : i8 to !quant.uniform<i8:f32, 2.0>
%2 = quant.scast %1 : !quant.uniform<i8:f32, 2.0> to i8
- Output IR
%1 = quant.scast %0 : i8 to !quant.uniform<i8:f32, 2.0> // Possibly dead code
%2 = %0 // I.e., uses of %2 replaced with %0
Pass --lower-quant-ops
This pass lowers all instances of quant.dcast and quant.qcast within a function body. Although the pass marks these operations as illegal, it does not declare the entire quant dialect illegal, as the lowering patterns are expected to emit quant.scast operations.
The pass handles quant.dcast and quant.qcast operations similarly, but for simplicity, the examples below will focus solely on the lowering of quant.qcast. The pass differentiates between four cases, each resulting in significantly different code structures, which are detailed next.
Case 1. Per-layer quantization, scalar or ranked input
The IR pattern shown below is emitted when the quantized type of the lowered op uses per-layer quantization. A similar sequence of ops is generated regardless of whether the input is a scalar value or a ranked tensor.
- Input IR
!qalias = !quant.uniform<i8<-8:7>:f32, 2.0:10>
func.func @f(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> {
%0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias>
return %0 : tensor<3x5x!qalias>
}
- Output IR
func.func @f(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> {
// Create scale tensor.
// NOTE: All 'arith.constant' + 'tensor.splat' ops will be canonicalized into
// a single 'arith.constant' for statically shaped tensors.
%cst = arith.constant 2.000000e+00 : f32
%splat = tensor.splat %cst : tensor<3x5xf32>
// Divide by scale
%0 = arith.divf %arg0, %splat : tensor<3x5xf32>
// Create zero point float tensor
%c10_i8 = arith.constant 10 : i8
%splat_0 = tensor.splat %c10_i8 : tensor<3x5xi8>
%1 = arith.sitofp %splat_0 : tensor<3x5xi8> to tensor<3x5xf32>
// Add zero point
%2 = arith.addf %0, %1 : tensor<3x5xf32>
// Convert stored value to integer
%3 = arith.fptosi %2 : tensor<3x5xf32> to tensor<3x5xi8>
// Clamp stored value
%c-8_i8 = arith.constant -8 : i8
%c7_i8 = arith.constant 7 : i8
%splat_1 = tensor.splat %c-8_i8 : tensor<3x5xi8>
%splat_2 = tensor.splat %c7_i8 : tensor<3x5xi8>
%4 = arith.maxsi %3, %splat_1 : tensor<3x5xi8>
%5 = arith.minsi %4, %splat_2 : tensor<3x5xi8>
// Cast stored value to quantized type
%6 = quant.scast %5 : tensor<3x5xi8> to tensor<3x5x!qalias>
return %6 : tensor<3x5x!qalias>
}
Case 2. Per-layer quantization, unranked input
Compared to the previous case, the main challenge posed by unranked tensors is generating scale and zero-point tensor splats whose total size and dimensions match the operation input, as required for type compatibility in the arithmetic computations. The strategy used here is to first flatten the unranked tensor into a 1D dynamically sized tensor, query its total size, generate 1D constant tensor splats for the scale and zero point, run the arithmetic computations on the 1D tensors, and convert the result back into an unranked tensor of the original shape.
Shape-related computations leverage the shape dialect with the goal of maximizing the opportunities for code simplification across IR patterns emerging from consecutive quantization ops. Additional canonicalization patterns have recently been added to the shape and tensor dialects to facilitate such simplifications (see New canonicalization patterns for shape.shape_of and tensor.reshape).
- Input IR
!qalias = !quant.uniform<i8:f32, 2.0:10>
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
%0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
return %0 : tensor<*x!qalias>
}
- Output IR
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
// Compute shape and size of input tensor
%0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
%1 = shape.num_elements %0 : tensor<?xindex> -> index
// Reshape input to 1D dynamically sized tensor
%from_elements = tensor.from_elements %1 : tensor<1xindex>
%reshape = tensor.reshape %arg0(%from_elements) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// Now we know how to lower a ranked tensor
%cst = arith.constant 2.000000e+00 : f32
%c10_i8 = arith.constant 10 : i8
%c0 = arith.constant 0 : index
%dim = tensor.dim %reshape, %c0 : tensor<?xf32>
%splat = tensor.splat %cst[%dim] : tensor<?xf32>
%2 = arith.divf %reshape, %splat : tensor<?xf32>
%splat_0 = tensor.splat %c10_i8[%dim] : tensor<?xi8>
%3 = arith.sitofp %splat_0 : tensor<?xi8> to tensor<?xf32>
%4 = arith.addf %2, %3 : tensor<?xf32>
%5 = arith.fptosi %4 : tensor<?xf32> to tensor<?xi8>
// Convert the result back into an unranked tensor
%reshape_1 = tensor.reshape %5(%0) : (tensor<?xi8>, tensor<?xindex>) -> tensor<*xi8>
%6 = quant.scast %reshape_1 : tensor<*xi8> to tensor<*x!qalias>
return %6 : tensor<*x!qalias>
}
Case 3. Per-channel quantization, ranked input
In per-channel quantization, different scale and zero-point pairs apply to the different items of the input tensor in the dimension designated as the channel. This is accomplished through the use of a linalg.generic operation with its affine map attributes carefully designed to apply the correct scale and zero point to each element of the input tensor.
- Input IR
!qalias = !quant.uniform<i8<-8:7>:f32:1, {2.0, 3.0}>
func.func @f(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
%0 = "quant.qcast"(%arg0) : (tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias>
return %0 : tensor<4x2x5x!qalias>
}
- Output IR
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1)>
func.func @f(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
// Create tensors of scales and zero points
%cst = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
%cst_0 = arith.constant dense<0> : tensor<2xi8>
// Traverse input, scales, zero-point, and output tensors
%0 = tensor.empty() : tensor<4x2x5xi8>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%0 : tensor<4x2x5xi8>) {
^bb0(%in: f32, %in_1: f32, %in_2: i8, %out: i8):
// Divide by scale
%3 = arith.divf %in, %in_1 : f32
// Convert zero point to float
%4 = arith.sitofp %in_2 : i8 to f32
// Add zero point
%5 = arith.addf %3, %4 : f32
// Convert stored value to integer
%6 = arith.fptosi %5 : f32 to i8
// Clamp stored value
%c-8_i8 = arith.constant -8 : i8
%c7_i8 = arith.constant 7 : i8
%7 = arith.maxsi %6, %c-8_i8 : i8
%8 = arith.minsi %7, %c7_i8 : i8
linalg.yield %8 : i8
} -> tensor<4x2x5xi8>
// Cast stored values into quantized type
%2 = quant.scast %1 : tensor<4x2x5xi8> to tensor<4x2x5x!qalias>
return %2 : tensor<4x2x5x!qalias>
}
Case 4. Per-channel quantization, unranked input
Handling unranked tensors and per-channel quantization poses an additional challenge. We need to determine the position of the channel within an unranked tensor without knowing the exact number of dimensions on either side. However, we can assume that the channel dimension exists in the input tensor, or else we would be dealing with malformed IR.
The strategy for dealing with unranked tensors involves flattening the input into a 3D tensor. Everything to the left of the channel dimension is collapsed into the left-most dimension of the reshaped tensor, the channel dimension itself becomes the middle dimension, and everything to the right of the channel dimension is collapsed into the right-most dimension. The arithmetic operations related to quantization are performed on this reshaped 3D tensor. Finally, the result is reshaped back into the original unranked tensor.
- Input IR
!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
%0 = "quant.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!qalias>
return %0 : tensor<*x!qalias>
}
- Output IR
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1)>
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
// Save shape of original unranked tensor
%0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
// Calculate tensor size on the left of channel dimension
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%head, %tail = "shape.split_at"(%0, %c2) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
%1 = shape.num_elements %head : tensor<?xindex> -> index
// Calculate tensor size on the right of channel dimension
%head_0, %tail_1 = "shape.split_at"(%0, %c3) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
%2 = shape.num_elements %tail_1 : tensor<?xindex> -> index
// Reshape input to 3D tensor
%c3_2 = arith.constant 3 : index
%from_elements = tensor.from_elements %1, %c3_2, %2 : tensor<3xindex>
%reshape = tensor.reshape %arg0(%from_elements) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x3x?xf32>
// Scale and zero-point tensors
%cst = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
%cst_3 = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
%c0 = arith.constant 0 : index
// Initialize output tensor
%dim = tensor.dim %reshape, %c0 : tensor<?x3x?xf32>
%c2_4 = arith.constant 2 : index
%dim_5 = tensor.dim %reshape, %c2_4 : tensor<?x3x?xf32>
%3 = tensor.empty(%dim, %dim_5) : tensor<?x3x?xi8>
// Quantize
%4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%reshape, %cst, %cst_3 : tensor<?x3x?xf32>, tensor<3xf32>, tensor<3xi8>) outs(%3 : tensor<?x3x?xi8>) {
^bb0(%in: f32, %in_7: f32, %in_8: i8, %out: i8):
%6 = arith.divf %in, %in_7 : f32
%7 = arith.sitofp %in_8 : i8 to f32
%8 = arith.addf %6, %7 : f32
%9 = arith.fptosi %8 : f32 to i8
linalg.yield %9 : i8
} -> tensor<?x3x?xi8>
// Reshape output to original unranked tensor
%reshape_6 = tensor.reshape %4(%0) : (tensor<?x3x?xi8>, tensor<?xindex>) -> tensor<*xi8>
%5 = quant.scast %reshape_6 : tensor<*xi8> to tensor<*x!qalias>
return %5 : tensor<*x!qalias>
}
Pass --strip-func-quant-types
This module-level pass performs the following actions:
- It modifies the signature of every top-level func.func op in the module by substituting every occurrence of the !quant.uniform data type with its storage type. This applies to both function declarations and definitions.
- For each input argument in a function definition, a quant.scast op is inserted at the top of the function body, converting the new argument type into the original !quant.uniform type. All uses of the original function argument are replaced with the SSA value produced by the new quant.scast op.
- Every occurrence of the return op in the body of the function is prepended with a quant.scast op for each return value of type !quant.uniform, converting the quantized type into its storage value. The occurrence of the quantized value in the return op is replaced with the SSA value produced by the quant.scast op.
- All occurrences of func.call and func.call_indirect are adjusted to the new signature of the invoked functions, which now lack any occurrence of quantized values. For every quantized argument originally passed to a function call, an additional quant.scast op is introduced, and its integer result is passed to the function call instead of the original quantized value. Similarly, a quant.scast op is also introduced for every quantized result generated by the function call, converting the integer type returned by the new function signature back into the original quantized type. A sketch of this call-site adjustment is shown after the example below.
The following example illustrates the application of this pass on a function with 2 inputs and 1 output, all of them of a quantized type.
- Input IR
!qalias = !quant.uniform<i8:f32, 1.0>
func.func @predict(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> tensor<3x!qalias> {
%sum = "ml.add"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
return %sum : tensor<3x!qalias>
}
- Output IR
!qalias = !quant.uniform<i8:f32, 1.0>
func.func @predict(%arg0_stripped: tensor<3xi8>, %arg1_stripped: tensor<3xi8>) -> tensor<3xi8> {
// Conversion of function arguments
%arg0 = "quant.scast"(%arg0_stripped): (tensor<3xi8>) -> tensor<3x!qalias>
%arg1 = "quant.scast"(%arg1_stripped): (tensor<3xi8>) -> tensor<3x!qalias>
// Function body
%sum = "ml.add"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
// Conversion of return values
%sum_stripped = "quant.scast"(%sum): (tensor<3x!qalias>) -> tensor<3xi8>
return %sum_stripped : tensor<3xi8>
}
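To also illustrate the call-site adjustment, the hypothetical function @caller below invokes @predict from the example above. The output shown is only a sketch of the expected rewrite: the caller's quantized call operands and results are bridged to the stripped signature of @predict with additional quant.scast ops, while the caller's own signature is left untouched because it contains no quantized types.
- Input IR
!qalias = !quant.uniform<i8:f32, 1.0>
func.func @caller(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%q = quant.qcast %arg0 : tensor<3xf32> to tensor<3x!qalias>
%r = func.call @predict(%q, %q) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
%result = quant.dcast %r : tensor<3x!qalias> to tensor<3xf32>
return %result : tensor<3xf32>
}
- Output IR
!qalias = !quant.uniform<i8:f32, 1.0>
func.func @caller(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%q = quant.qcast %arg0 : tensor<3xf32> to tensor<3x!qalias>
// Conversion of quantized call arguments to the stripped signature
%q_i8 = "quant.scast"(%q) : (tensor<3x!qalias>) -> tensor<3xi8>
%r_i8 = func.call @predict(%q_i8, %q_i8) : (tensor<3xi8>, tensor<3xi8>) -> tensor<3xi8>
// Conversion of the integer call result back to the original quantized type
%r = "quant.scast"(%r_i8) : (tensor<3xi8>) -> tensor<3x!qalias>
%result = quant.dcast %r : tensor<3x!qalias> to tensor<3xf32>
return %result : tensor<3xf32>
}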
Workflow example
Consider a hypothetical user-defined machine learning dialect ml supporting quantized types. Consider also a user-defined pass --convert-ml-to-arith in charge of lowering ml ops into the arith dialect, while accounting for the possibility that op arguments are of various (possibly different) quantized types.
- In the following input IR, function @multiply_add uses hypothetical ops ml.mul and ml.add to process the 1D quantized tensors supplied as function arguments, while the result is provided as the function's return value.
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>
func.func @multiply_add(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>, %arg2: tensor<3x!qalias>) -> tensor<3x!qalias> {
%product = "ml.mul"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
%sum = "ml.add"(%product, %arg2) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
return %sum : tensor<3x!qalias>
}
- Pass --convert-ml-to-arith is expected to handle any combination of quantized types accepted by the ml op verifiers. Let us assume, however, that per-channel quantization, as used in this example, is currently not specifically dealt with by the pass, either because it is a temporarily unsupported feature, or because the infrequent occurrence of such a case makes the benefits of an optimized lowering based on quantized integer arithmetic not worth the costly implementation effort. The pass may then rely on a dequantization approach for unsupported corner cases, producing the code below. Each lowering pattern in the pass dequantizes the input operands, applies the default floating-point lowering pattern for the ml op, and quantizes the floating-point result back to match the original result type.
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>
func.func @multiply_add(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>, %arg2: tensor<3x!qalias>) -> tensor<3x!qalias> {
%arg0_f32 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
%arg1_f32 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
%product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>
%product = "quant.qcast"(%product_f32) : (tensor<3xf32>) -> tensor<3x!qalias>
%product_f32_1 = "quant.dcast"(%product) : (tensor<3x!qalias>) -> tensor<3xf32>
%arg2_f32 = "quant.dcast"(%arg2) : (tensor<3x!qalias>) -> tensor<3xf32>
%sum_f32 = arith.addf %product_f32_1, %arg2_f32 : tensor<3xf32>
%sum = "quant.qcast"(%sum_f32) : (tensor<3xf32>) -> tensor<3x!qalias>
return %sum : tensor<3x!qalias>
}
- Applying lowering patterns on individual ops produces redundant conversions that can be optimized before applying further transformations. In the code above, the quantized result of arith.mulf is immediately dequantized in preparation for the following arith.addf. Applying a canonicalization pass at this point eliminates this redundancy, leading to the following code.
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>
func.func @multiply_add(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>, %arg2: tensor<3x!qalias>) -> tensor<3x!qalias> {
%arg0_f32 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
%arg1_f32 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
%product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>
%arg2_f32 = "quant.dcast"(%arg2) : (tensor<3x!qalias>) -> tensor<3xf32>
%sum_f32 = arith.addf %product_f32, %arg2_f32 : tensor<3xf32>
%sum = "quant.qcast"(%sum_f32) : (tensor<3xf32>) -> tensor<3x!qalias>
return %sum : tensor<3x!qalias>
}
- The inserted quant ops may now be lowered with the proposed --lower-quant-ops pass. For simplicity, the code below replaces the actual arithmetic computations generated by the quant.dcast and quant.qcast rewrite patterns with function calls to @dequantize and @quantize, respectively.
func.func private @dequantize(%arg0: tensor<3xi8>) -> tensor<3xf32>
func.func private @quantize(%arg0: tensor<3xf32>) -> tensor<3xi8>
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>
func.func @multiply_add(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>, %arg2: tensor<3x!qalias>) -> tensor<3x!qalias> {
%arg0_i8 = "quant.scast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xi8>
%arg0_f32 = func.call @dequantize(%arg0_i8) : (tensor<3xi8>) -> tensor<3xf32>
%arg1_i8 = "quant.scast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xi8>
%arg1_f32 = func.call @dequantize(%arg1_i8) : (tensor<3xi8>) -> tensor<3xf32>
%product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>
%arg2_i8 = "quant.scast"(%arg2) : (tensor<3x!qalias>) -> tensor<3xi8>
%arg2_f32 = func.call @dequantize(%arg2_i8) : (tensor<3xi8>) -> tensor<3xf32>
%sum_f32 = arith.addf %product_f32, %arg2_f32 : tensor<3xf32>
%sum_i8 = func.call @quantize(%sum_f32) : (tensor<3xf32>) -> tensor<3xi8>
%sum = "quant.scast"(%sum_i8) : (tensor<3xi8>) -> tensor<3x!qalias>
return %sum : tensor<3x!qalias>
}
- At this point, all uses of quantized types are limited to the function header and the newly introduced quant.scast ops. The latter guarantee the syntactic integrity of the code at this intermediate state of the workflow, but they perform no meaningful data conversions. Therefore, no remaining op in the code relies on the quantized types to fine-tune its behavior. In preparation for eliminating the remaining occurrences of quantized types, the function header is now transformed with pass --strip-func-quant-types.
func.func private @dequantize(%arg0: tensor<3xi8>) -> tensor<3xf32>
func.func private @quantize(%arg0: tensor<3xf32>) -> tensor<3xi8>
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>
func.func @multiply_add(%arg0_stripped: tensor<3xi8>, %arg1_stripped: tensor<3xi8>, %arg2_stripped: tensor<3xi8>) -> tensor<3xi8> {
%arg0 = "quant.scast"(%arg0_stripped) : (tensor<3xi8>) -> tensor<3x!qalias>
%arg1 = "quant.scast"(%arg1_stripped) : (tensor<3xi8>) -> tensor<3x!qalias>
%arg2 = "quant.scast"(%arg2_stripped) : (tensor<3xi8>) -> tensor<3x!qalias>
%arg0_i8 = "quant.scast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xi8>
%arg0_f32 = func.call @dequantize(%arg0_i8) : (tensor<3xi8>) -> tensor<3xf32>
%arg1_i8 = "quant.scast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xi8>
%arg1_f32 = func.call @dequantize(%arg1_i8) : (tensor<3xi8>) -> tensor<3xf32>
%product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>
%arg2_i8 = "quant.scast"(%arg2) : (tensor<3x!qalias>) -> tensor<3xi8>
%arg2_f32 = func.call @dequantize(%arg2_i8) : (tensor<3xi8>) -> tensor<3xf32>
%sum_f32 = arith.addf %product_f32, %arg2_f32 : tensor<3xf32>
%sum_i8 = func.call @quantize(%sum_f32) : (tensor<3xf32>) -> tensor<3xi8>
%sum = "quant.scast"(%sum_i8) : (tensor<3xi8>) -> tensor<3x!qalias>
%sum_stripped = "quant.scast"(%sum) : (tensor<3x!qalias>) -> tensor<3xi8>
return %sum_stripped : tensor<3xi8>
}
- Quantized types now occur only in quant.scast ops, all of which should cancel each other out through the application of canonicalization patterns and DCE. This will always be the case as long as the original application of pass --convert-ml-to-arith is able to handle all occurrences of quantized types correctly.
func.func private @dequantize(%arg0: tensor<3xi8>) -> tensor<3xf32>
func.func private @quantize(%arg0: tensor<3xf32>) -> tensor<3xi8>
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>
func.func @multiply_add(%arg0_stripped: tensor<3xi8>, %arg1_stripped: tensor<3xi8>, %arg2_stripped: tensor<3xi8>) -> tensor<3xi8> {
%arg0_f32 = func.call @dequantize(%arg0_stripped) : (tensor<3xi8>) -> tensor<3xf32>
%arg1_f32 = func.call @dequantize(%arg1_stripped) : (tensor<3xi8>) -> tensor<3xf32>
%product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>
%arg2_f32 = func.call @dequantize(%arg2_stripped) : (tensor<3xi8>) -> tensor<3xf32>
%sum_f32 = arith.addf %product_f32, %arg2_f32 : tensor<3xf32>
%sum_i8 = func.call @quantize(%sum_f32) : (tensor<3xf32>) -> tensor<3xi8>
return %sum_i8: tensor<3xi8>
}
Common subexpression elimination (CSE)
The workflow above benefits from the application of a CSE and DCE sequence right before the application of pass --lower-quant-ops. This gets rid of redundant quant.qcast and quant.dcast ops, avoiding the processing cost of their rewrite patterns as well as the possibly significant additional size of the resulting IR. Because ops quant.qcast and quant.dcast are simply labeled as side-effect-free (pure) in their specification, this scenario is automatically handled by the standard CSE and DCE passes in MLIR. The following example illustrates this scenario.
- In the following input IR, function @division computes both the quotient and remainder of the given quantized inputs through the use of hypothetical ops ml.div and ml.mod.
!qalias = !quant.uniform<i8:f32, 2.0>
func.func @division(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> (tensor<3x!qalias>, tensor<3x!qalias>) {
%quotient = "ml.div"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
%remainder = "ml.mod"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
return %quotient, %remainder : tensor<3x!qalias>, tensor<3x!qalias>
}
- Let us assume that user-defined pass --convert-ml-to-arith handles quantized inputs for ml.div and ml.mod by dequantizing them, then processing the floating-point operands with arith.divf and arith.remf ops, and finally quantizing their results back into the original result types. The resulting IR includes two pairs of quant.dcast ops, both of which dequantize the same input arguments.
!qalias = !quant.uniform<i8:f32, 2.0>
func.func @division(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> (tensor<3x!qalias>, tensor<3x!qalias>) {
%arg0_f32 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
%arg1_f32 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
%quotient_f32 = arith.divf %arg0_f32, %arg1_f32 : tensor<3xf32>
%quotient = "quant.qcast"(%quotient_f32) : (tensor<3xf32>) -> tensor<3x!qalias>
%arg0_f32_1 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
%arg1_f32_1 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
%remainder_f32 = arith.remf %arg0_f32_1, %arg1_f32_1 : tensor<3xf32>
%remainder = "quant.qcast"(%remainder_f32) : (tensor<3xf32>) -> tensor<3x!qalias>
return %quotient, %remainder : tensor<3x!qalias>, tensor<3x!qalias>
}
- CSE rewrites arith.remf to consume the originally computed dequantized inputs %arg0_f32 and %arg1_f32, which turns the producers of %arg0_f32_1 and %arg1_f32_1 into dead code. DCE then gets rid of the second pair of quant.dcast ops.
!qalias = !quant.uniform<i8:f32, 2.0>
func.func @division(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> (tensor<3x!qalias>, tensor<3x!qalias>) {
%arg0_f32 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
%arg1_f32 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
%quotient_f32 = arith.divf %arg0_f32, %arg1_f32 : tensor<3xf32>
%quotient = "quant.qcast"(%quotient_f32) : (tensor<3xf32>) -> tensor<3x!qalias>
%remainder_f32 = arith.remf %arg0_f32, %arg1_f32 : tensor<3xf32>
%remainder = "quant.qcast"(%remainder_f32) : (tensor<3xf32>) -> tensor<3x!qalias>
return %quotient, %remainder : tensor<3x!qalias>, tensor<3x!qalias>
}
Application
A possible application of the presented workflow is the TensorFlow Lite (tfl) lowering pipeline. While the TensorFlow repository offers a quite advanced TFL-to-TOSA transform pipeline, support for quantized types is currently limited by multiple factors, including semantic restrictions of the tosa dialect and the lack of dedicated lowering logic emitting integer arithmetic.
Lowering a TensorFlow Lite model
To better illustrate the challenges of supporting quantized types in the tfl dialect, the following example shows the steps involved in lowering a TensorFlow Lite model using the TFL-to-TOSA pipeline.
- In this input IR, function @predict simply adds two quantized input values and returns their sum, also using a quantized format.
!qalias = !quant.uniform<i8:f32, 2.0>
func.func @predict(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> tensor<3x!qalias> {
%result = "tfl.add"(%arg0, %arg1) { fused_activation_function = "NONE“ } : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
return %result : tensor<3x!qalias>
}
- Pass --tfl-to-tosa-pipeline in the TensorFlow tool tflite-opt lowers this code to the tosa dialect. It is intended to honor the semantics of tfl ops with quantized input operands, and it does so successfully in this example. For the tfl.add op, this involves rescaling the input operands to a common intermediate quantized type, performing the addition using integer arithmetic, and rescaling the result back to the original quantized type.
!qalias = !quant.uniform<i8:f32, 2.0>
func.func @predict(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> tensor<3x!qalias> {
%0 = tosa.rescale %arg0 { ... } : (tensor<3x!qalias>) -> tensor<3xi32>
%1 = tosa.rescale %arg1 { ... } : (tensor<3x!qalias>) -> tensor<3xi32>
%2 = tosa.add %0, %1 : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
%3 = tosa.rescale %2 { ... } : (tensor<3xi32>) -> tensor<3x!qalias>
return %3 : tensor<3x!qalias>
}
- In the code above, quantized types are exclusively handled and manipulated by tosa.rescale operations, while the other generated ops operate directly on the associated integer storage types. The tosa.rescale op expresses its behavior fully through explicit attributes, without relying on implicit information contained in the quantized types of its operands or result. Therefore, quantized types no longer provide critical semantic information in the IR, and may be cast into their associated storage types. This can be accomplished by running the above pass with the additional option --tfl-to-tosa-pipeline=target-compilation-backend. All occurrences of type !quant.uniform<i8:f32, 2.0> are converted to i8, which allows the remainder of the lowering pipeline to interpret all computations as pure integer arithmetic.
func.func @predict(%arg0: tensor<3xi8>, %arg1: tensor<3xi8>) -> tensor<3xi8> {
%0 = tosa.rescale %arg0 { ... } : (tensor<3xi8>) -> tensor<3xi32>
%1 = tosa.rescale %arg1 { ... } : (tensor<3xi8>) -> tensor<3xi32>
%2 = tosa.add %0, %1 : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
%3 = tosa.rescale %2 { ... } : (tensor<3xi32>) -> tensor<3xi8>
return %3 : tensor<3xi8>
}
Limitations while lowering quantized types
While the TFL-to-TOSA pipeline correctly handles a variety of ops with quantized inputs, it currently fails to lower less frequently occurring cases, such as certain storage type bit widths or per-axis quantization. The behavior currently observed for unsupported quantized types is either a refusal to lower the affected tfl op altogether or, even worse, its substitution with a sequence of ops that yields silently wrong answers. Below are some examples.
- Operation tfl.log_softmax fails to lower when its inputs are quantized.
!qalias = !quant.uniform<i8:f32, 0.5>
func.func @main(%arg0: tensor<2x3x!qalias>) -> tensor<2x3x!qalias> {
%0 = "tfl.log_softmax"(%arg0) : (tensor<2x3x!qalias>) -> tensor<2x3x!qalias>
return %0 : tensor<2x3x!qalias>
}
test.mlir:3:8: error: 'tfl.log_softmax' op : illegal op still exists
%0 = "tfl.log_softmax"(%arg0) : (tensor<2x3x!qalias>) -> tensor<2x3x!qalias>
^
test.mlir:3:8: note: see current operation: %0 = "tfl.log_softmax"(%arg0) : (tensor<2x3xi8>) -> tensor<2x3xi8>
test.mlir:2:1: error: The following illegal operations still remain:
tfl.log_softmax (count: 1)
func.func @main(%arg0: tensor<2x3x!qalias>) -> tensor<2x3x!qalias> {
^
- Operation tfl.pad produces invalid code when lowered to the tosa dialect if its first input argument is quantized.
!qalias = !quant.uniform<i8:f32, 0.5>
func.func @main(%arg0: tensor<2x2x!qalias>) -> tensor<4x6x!qalias> {
%padding = "tfl.pseudo_const"() {value = dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%result = "tfl.pad"(%arg0, %padding) : (tensor<2x2x!qalias>, tensor<2x2xi32>) -> tensor<4x6x!qalias>
return %result : tensor<4x6x!qalias>
}
func.func @main(%arg0: tensor<2x2xi8>) -> tensor<4x6xi8> {
%0 = "tosa.const"() <{value = dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
%1 = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<i8>
%2 = tosa.pad %arg0, %0, %1 {quantization_info = #tosa.pad_quant<input_zp = 0>} : (tensor<2x2xi8>, tensor<2x2xi32>, tensor<i8>) -> tensor<4x6xi8>
return %2 : tensor<4x6xi8>
}
The examples above are simplified use cases of real ops found in pre-trained models available on Kaggle. Similar errors have been encountered for other ops, such as tfl.l2_normalization, tfl.split, or tfl.arg_min. While it is conceivable to address these issues individually, it is much harder to guarantee that all valid quantized type combinations are properly handled, especially in less frequently occurring ops not yet encountered in publicly available pre-trained models.
Overcoming limited quantization support
The features proposed in this document may help address unsupported parameter combinations of the !quant.uniform type by adopting the following general structure in a rewrite pattern for a tfl op:
- Handle custom cases of quantized input operands with an algorithm that operates directly on the storage type of the quantized inputs using integer arithmetic.
- If the input arguments are quantized but their specific type parameters were not considered above, dequantize the input arguments by converting them to floating-point values with quant.dcast ops.
- Handle the general case that emits code for floating-point inputs.
- If the original tfl op uses a quantized type, quantize the resulting floating-point value with a quant.qcast op.
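As an illustration of this structure, the sketch below shows one possible intermediate form that such a fallback pattern could produce for the quantized tfl.log_softmax example from the previous section. This is not the output of any existing pass; it assumes a hypothetical rewrite that, lacking a dedicated integer lowering (step 1), dequantizes the operand, re-emits the op on floating-point values so that the existing float lowering can take over, and quantizes the result back to the original type.
!qalias = !quant.uniform<i8:f32, 0.5>
func.func @main(%arg0: tensor<2x3x!qalias>) -> tensor<2x3x!qalias> {
// Dequantize the unsupported quantized operand (step 2)
%0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<2x3xf32>
// General floating-point case, handled by the existing float lowering (step 3)
%1 = "tfl.log_softmax"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// Quantize the result back to the original quantized type (step 4)
%2 = quant.qcast %1 : tensor<2x3xf32> to tensor<2x3x!qalias>
return %2 : tensor<2x3x!qalias>
}
The quant.dcast and quant.qcast ops introduced by the fallback can then be lowered with --lower-quant-ops, and the remaining quantized types stripped with --strip-func-quant-types, as in the workflow example above.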
An implementation may decide to offer a dequantization-based fallback strategy as an optional feature through a command-line modifier for the lowering transform pipeline.