[RFC] Improvements in the 'quant' dialect

Introduction

This RFC proposes a refinement of the MLIR quant dialect and its tooling. The proposal includes:

  • An improved specification for ops quant.dcast, quant.qcast, and quant.scast with a custom assembly format, clear semantics, strict syntactic verification, and canonicalization.

  • New pass --lower-quant-ops to lower ops quant.qcast and quant.dcast to other core dialects, with support for statically shaped, dynamically shaped, and unranked tensors, as well as per-layer and per-channel quantization.

  • New pass --strip-func-quant-types to eliminate quantization information from function headers.

  • A workflow example illustrating how to integrate the proposed passes and canonicalization patterns in a transform pipeline that lowers a generic machine learning dialect with quantization support.

  • A specific application in the domain of the TensorFlow Lite (tfl) dialect, based on the design of a fallback strategy to lower currently unsupported quantized types.

  • A working implementation with comprehensive documentation and unit tests, available on this pull request.

The quant dialect

This section describes the quant dialect according to the features proposed in this document. Its content is used as the TableGen documentation for the dialect in the associated pull request. It includes the following material:

  • An introduction to the quant dialect.

  • A thorough description of the !quant.uniform type as currently supported by the proposed lowering passes, including per-layer and per-channel quantization variants.

  • A section titled Per-axis quantization integrity, which describes the requirements for using the !quant.uniform type with per-axis quantization. This section is referenced in the documentation for each operation to clarify that these syntactic requirements apply to all operations using per-axis quantization.

Introduction

The quant dialect offers a framework for defining and manipulating
quantized values. Central to this framework is the !quant.uniform data
type, used to represent quantized values. This dialect also provides a
suite of operations to handle and convert quantized values between their
original floating-point representations and the optimized, lower bit-width
integer representations. The quant dialect is instrumented with
transformation passes to lower these operations into other core MLIR
dialects, while also flattening all occurrences of quantized types into
their integer counterparts.

The !quant.uniform type

The quantization process establishes a relationship between two types of
values: an expressed value and a stored value. The former refers to the
floating-point representation used in an original machine learning model,
capturing the precise numerical characteristics needed for accurate
calculations. The latter is the simplified integer representation that
resides in memory after quantization. The !quant.uniform data type
encodes the necessary information for (lossy) round-trip conversion between
an expressed and a stored value.

The !quant.uniform type has two variants: per-layer quantization and
per-channel (or per-axis) quantization. In per-layer quantization, the
quantization information affects an entire tensor uniformly. Conversely, in
per-channel quantization, the data type encodes the specific tensor axis
that serves as the channel and includes quantization information for each
individual channel within the tensor. Below are the specific syntactic and
semantic considerations for each modality.

Per-layer quantization

This is the general syntax of the !quant.uniform type representing
per-layer quantization:

`!quant.uniform` `<`
  storedType (`<` storageMin `:` storageMax `>`)? `:`
  expressedType `,`
  scale (`:` zeroPoint)?
`>`

The type contains the following parameters:

  • storedType: Integer type of the value stored in memory (referred to as
    storageType in the operation semantics below). This type conveys the
    bit width and signedness of the quantized stored value. Signed integer
    types are represented as 'i' bitWidth (e.g., i8), while unsigned
    integer types are represented as 'u' bitWidth (e.g., u8).

  • storageMin, storageMax: Optional bounds for the stored value. If
    given, they must be within the range of storedType. If omitted, the
    entire range of storedType is allowed (e.g., -128...127 for i8 or
    0...255 for u8).

  • expressedType: Floating-point type of the value expressed by this
    quantized type.

  • scale: Floating-point value of type expressedType used in the
    conversion between stored and expressed values.

  • zeroPoint: Optional integer value of type storedType used in the
    conversion between stored and expressed values. If omitted, the default
    is 0.
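As an illustration of how the default bounds follow from storedType, the Python sketch below computes the full range implied by a bit width and signedness, and checks that explicit storageMin/storageMax bounds fall inside it. The helper names storage_type_range and resolve_storage_bounds are hypothetical and not part of the dialect.

```python
def storage_type_range(bit_width: int, is_signed: bool) -> tuple[int, int]:
    """Full [min, max] range of an integer storage type (hypothetical helper)."""
    if is_signed:
        return -(1 << (bit_width - 1)), (1 << (bit_width - 1)) - 1
    return 0, (1 << bit_width) - 1


def resolve_storage_bounds(bit_width, is_signed, storage_min=None, storage_max=None):
    """Apply defaults for omitted bounds and check explicit ones (hypothetical helper)."""
    type_min, type_max = storage_type_range(bit_width, is_signed)
    lo = type_min if storage_min is None else storage_min
    hi = type_max if storage_max is None else storage_max
    # Each explicit bound must lie within the range of storedType.
    for bound in (lo, hi):
        if not type_min <= bound <= type_max:
            raise ValueError(f"bound {bound} outside [{type_min}, {type_max}]")
    return lo, hi


# !quant.uniform<i8:f32, 3.0>: no bounds given, so the full i8 range applies.
print(resolve_storage_bounds(8, True))
# !quant.uniform<u16<0:1023>:f32, 1.23:512>: explicit bounds within u16.
print(resolve_storage_bounds(16, False, 0, 1023))
```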

Type conversions, rounding methods, and clamping actions aside, the
relationship between the expressed and stored values as encoded in a
quantized type is denoted by the following formula:

expressedValue = (storedValue - zeroPoint) * scale

Operations quant.qcast (quantize cast) and quant.dcast (dequantize
cast) can be used to quantize a floating-point value and dequantize a
stored value, respectively. See the documentation for these operations for
details on how the quantization and dequantization processes are influenced
by the !quant.uniform type parameters.

Here are some examples of the use of !quant.uniform with per-layer
quantization:

// An 8-bit signed integer type is used to represent a 32-bit float. No
// clamping information is provided, so the full [-128, 127] range is
// available. The scale is set to 3.0, and the zero point takes its default
// 0 value.
!quant.uniform<i8:f32, 3.0>

// A 16-bit unsigned integer type is used to represent a 32-bit float. Out
// of the 16 bits, only 10 are used, according to the 0..1023 clamping
// range. The type sets the scale to 1.23 and the zero point to 512.
!quant.uniform<u16<0:1023>:f32, 1.23:512>

Per-channel quantization

The general syntax of the !quant.uniform type representing per-channel
quantization is as follows:

`!quant.uniform` `<`
  storedType (`<` storageMin `:` storageMax `>`)? `:`
  expressedType `:`
  channelAxis `,`
  `{`
    scale0 (`:` zeroPoint0)? `,`
    scale1 (`:` zeroPoint1)? ...
  `}`
`>`

In this data type, there are multiple pairs of scale and zeroPoint
values. The channelAxis field represents the dimension of the containing
tensor acting as the channel. The size of the tensor along this dimension
is expected to match the number of provided scale-zeroPoint pairs, and
a given pair i applies to all elements in the tensor whose index along
dimension channelAxis is i. A quantized data type using per-channel
quantization is always expected to be contained within a tensor type.

Here are some examples:

// A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
// floats. Dimension 1 of the tensor acts as the channel dimension. Its
// size 3 matches the number of provided scale values. Tensor elements at
// positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
// 5.0, respectively.
tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>

// A 2D dynamically sized tensor contains 16-bit unsigned integers
// representing 32-bit floats. Dimension 0 of the tensor acts as the
// channel dimension. Since 2 scale and zero-point values are provided, the
// size of dimension 0 is expected to be 2 at runtime. Tensor elements
// [0][*] use scale 2.0 and zero point 10, while elements [1][*] use scale
// 3.0 and zero point 20.
tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>

Per-axis quantization integrity

When type !quant.uniform contains per-axis quantization information, the
rules below are enforced. These rules guarantee that the quantization
information encoded in the data type is applicable to the context in which
the quantized type is used. For efficiency, these rules are actively
enforced by the verifiers of quant dialect ops, but they must be
respected in any context in which the !quant.uniform data type is used,
such as the header of a func.func op, or the input of an arithmetic
operation.

  • A quantized type with per-channel quantization information must be the
    element type of a tensor container type, and may not occur directly as
    the data type of a scalar value.
// Incorrect. Type !quant.uniform specifies per-channel quantization for a
// scalar type.
%result = quant.qcast %input : f32 to !quant.uniform<i8:f32:0, {1.0, 2.0}>

// Correct. Type `!quant.uniform` with per-channel quantization is wrapped
// in a `tensor` type.
%result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform<i8:f32:0, {1.0, 2.0}>>
  • If the tensor containing the !quant.uniform type is ranked, its rank
    must be greater than the channel axis specified in the quantized type.
// Incorrect. The tensor rank (2) is not greater than the channel axis in
// the quantized type (3).
%result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:3, {1.0, 2.0}>>

// Correct. The tensor rank (2) is now greater than the channel axis (1):
%result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:1, {1.0, 2.0}>>
  • If the axis dimension in the containing tensor is static, its size must
    be equal to the number of scales present in the quantized type.
// Incorrect. The channel axis is 1, and the size of dimension 1 in the
// containing tensor is 3. However, there are 4 scale values present in the
// quantized type.
%result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>

// Correct. The quantized type now includes 3 scale values, matching the
// size of dimension 1 of the result tensor.
%result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
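The three integrity rules can be sketched as a small Python check. As a modeling assumption for this sketch only, a container is represented as None for a scalar, the string "*" for an unranked tensor, or a list of dimensions where None marks a dynamic dimension; check_per_axis_integrity and is_valid are hypothetical names.

```python
def check_per_axis_integrity(shape, channel_axis: int, num_scales: int) -> None:
    """Raise ValueError if a per-channel quantized type is used in an invalid context.

    shape: None for a scalar, "*" for an unranked tensor, or a list of dims
    where None marks a dynamic dimension (a modeling choice for this sketch).
    """
    if shape is None:
        raise ValueError("per-channel quantized type must be a tensor element type")
    if shape == "*":
        return  # unranked: nothing more can be checked statically
    if len(shape) <= channel_axis:
        raise ValueError(f"rank {len(shape)} must be greater than channel axis {channel_axis}")
    dim = shape[channel_axis]
    if dim is not None and dim != num_scales:
        raise ValueError(f"channel dimension size {dim} does not match {num_scales} scales")


def is_valid(shape, channel_axis: int, num_scales: int) -> bool:
    try:
        check_per_axis_integrity(shape, channel_axis, num_scales)
        return True
    except ValueError:
        return False


# The incorrect and correct examples above, in order:
print(is_valid(None, 0, 2), is_valid([2], 0, 2))          # False True
print(is_valid([1, 2], 3, 2), is_valid([1, 2], 1, 2))     # False True
print(is_valid([None, 3], 1, 4), is_valid([None, 3], 1, 3))  # False True
```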

Operation quant.dcast

This section proposes a thorough specification for operation quant.dcast, with updated syntax, semantics, and canonicalization. This documentation is included in the TableGen operation definition on the associated pull request.

Syntax

The following custom assembly format is introduced:

operation ::= `quant.dcast` $input attr-dict `:` type($input) `to` type($result)

Semantics

Convert an input quantized value into its expressed floating-point value.
The dequantization process consists of the following steps:

def dequantize(quantizedValue: quantizedType) -> expressedType:
    storedValue = reinterpretCast(quantizedValue, storageType)
    storedValueFloat = convertIntToFloat(storedValue, expressedType)
    zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
    expressedValue = (storedValueFloat - zeroPointFloat) * scale
    return expressedValue

Here, storageType, expressedType, scale, and zeroPoint are obtained
from the corresponding parameters encoded in quantizedType. For
per-channel quantization, the appropriate scale and zeroPoint values
are used for each tensor element computation according to the channel the
element belongs to.
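For the per-layer scalar case, the pseudocode above translates almost line by line into runnable Python. In this sketch a quantized value is modeled by its stored integer, so reinterpretCast becomes a no-op, and the UniformQuantizedType dataclass is a hypothetical stand-in for the type parameters. Python floats are 64-bit, so results only approximate f32 arithmetic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class UniformQuantizedType:
    """Hypothetical model of the per-layer !quant.uniform parameters."""
    scale: float
    zero_point: int = 0


def dequantize(stored_value: int, qtype: UniformQuantizedType) -> float:
    # reinterpretCast is a no-op here: the quantized value is modeled
    # by its stored integer.
    stored_value_float = float(stored_value)      # convertIntToFloat
    zero_point_float = float(qtype.zero_point)    # convertIntToFloat
    return (stored_value_float - zero_point_float) * qtype.scale


# !quant.uniform<u16<0:1023>:f32, 1.23:512>
qtype = UniformQuantizedType(scale=1.23, zero_point=512)
print(dequantize(612, qtype))  # approximately 123.0
print(dequantize(512, qtype))  # 0.0
```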

Verification

The operation must satisfy the following syntactic constraints:

  • Operand input must be a scalar or tensor of type !quant.uniform.

  • The result type must be a floating-point scalar or tensor.

  • The expressedType parameter of the !quant.uniform type of the input
    must match the floating-point type of the result.

  • The operand and result types must be both scalars or both tensors. If
    tensors, they must be both ranked or both unranked. If ranked, both must
    have the same shape, including matching static and dynamic dimensions.

  • If the operand uses per-channel quantization, its !quant.uniform type
    must adhere to the Per-axis quantization integrity guidelines.

Examples

  • Dequantize a scalar quantized value
%result = quant.dcast %input : !quant.uniform<i8:f32, 2.0> to f32
  • Dequantize a dynamically shaped tensor of quantized values
%result = quant.dcast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xf32>
  • Dequantize an unranked tensor using per-axis quantization information
%result = quant.dcast %input : tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>> to tensor<*xf32>

Canonicalization

If the operand of a quant.dcast op is produced by a quant.qcast op and the operand type of quant.qcast matches the result type of quant.dcast, the value produced by quant.dcast is replaced with the operand of quant.qcast. In the likely case that quant.dcast is the only consumer of quant.qcast, this transformation renders quant.qcast dead code, allowing for further simplification.

  • Input IR
%1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 2.0>
%2 = quant.dcast %1 : !quant.uniform<i8:f32, 2.0> to f32
  • Output IR
%1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 2.0>  // Possibly dead code
%2 = %0  // I.e., uses of %2 replaced with %0

Notice that applying this canonicalization pattern may produce code with slightly different numerical results, due to the lossy nature of the quant.qcast op. While the input IR converts a floating-point value into a lower-precision integer representation only to immediately reverse the process, the output IR lets the original floating-point value flow through to subsequent consumers. The underlying assumption is that, in the context of a quantized model, such a slight numerical mismatch is tolerable, and that avoiding the rounding error incurred by unnecessarily quantizing an intermediate result is, in fact, desirable.
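To make the mismatch concrete, the Python sketch below evaluates both versions of the IR above for an input of 5.3 with !quant.uniform<i8:f32, 2.0>. It assumes that convertFloatToInt truncates toward zero, as arith.fptosi does in the IR emitted by --lower-quant-ops; the qcast and dcast helpers are illustrative models, not dialect APIs.

```python
def qcast(x: float, scale: float, zero_point: int = 0, lo: int = -128, hi: int = 127) -> int:
    # Truncation toward zero mirrors arith.fptosi (an assumption of this sketch);
    # clamping follows conversion, as in the quantize pseudocode.
    return max(lo, min(hi, int(x / scale + zero_point)))


def dcast(q: int, scale: float, zero_point: int = 0) -> float:
    return (q - zero_point) * scale


x = 5.3
round_trip = dcast(qcast(x, 2.0), 2.0)  # input IR: qcast followed by dcast
canonicalized = x                       # output IR: uses of %2 replaced with %0
print(round_trip, canonicalized)        # 4.0 5.3
```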

Operation quant.qcast

This section proposes a thorough specification for operation quant.qcast, with updated syntax, semantics, and canonicalization. This documentation is included in the TableGen operation definition on the associated pull request.

Syntax

The following custom assembly format is introduced:

operation ::= `quant.qcast` $input attr-dict `:` type($input) `to` type($result)

Semantics

Convert a floating-point value to a quantized type. The quantization
process consists of the following steps:

def quantize(expressedValue: expressedType) -> quantizedType:
    zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
    scaledValue = expressedValue / scale
    storedValueFloat = scaledValue + zeroPointFloat
    storedValue = convertFloatToInt(storedValueFloat, storageType)
    storedValueClamped = clamp(storedValue, storageMin, storageMax)
    quantizedValue = reinterpretCast(storedValueClamped, quantizedType)
    return quantizedValue

Here, storageType, storageMin, storageMax, expressedType, scale,
and zeroPoint are obtained from the corresponding parameters encoded in
quantizedType. For per-channel quantization, the appropriate scale and
zeroPoint values are used for each tensor element computation according
to the channel the element belongs to.
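The per-layer scalar case of the pseudocode above can be run directly in Python. The pseudocode does not name a rounding mode for convertFloatToInt; this sketch assumes truncation toward zero, matching arith.fptosi in the lowering emitted by --lower-quant-ops. The function name quantize and its keyword parameters are hypothetical. The example type is the one used in Case 1 of that pass.

```python
def quantize(expressed_value: float, scale: float, zero_point: int,
             storage_min: int, storage_max: int) -> int:
    """Per-layer quantization following the pseudocode step by step."""
    zero_point_float = float(zero_point)                      # convertIntToFloat
    scaled_value = expressed_value / scale
    stored_value_float = scaled_value + zero_point_float
    stored_value = int(stored_value_float)                    # truncates toward zero, like arith.fptosi
    return max(storage_min, min(storage_max, stored_value))   # clamp


# !quant.uniform<i8<-8:7>:f32, 2.0:10>
params = dict(scale=2.0, zero_point=10, storage_min=-8, storage_max=7)
print(quantize(-7.0, **params))   # -3.5 + 10 = 6.5 -> 6
print(quantize(3.0, **params))    # 1.5 + 10 = 11.5 -> 11 -> clamped to 7
print(quantize(-40.0, **params))  # -20 + 10 = -10 -> clamped to -8
```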

Verification

The operation must satisfy the following syntactic constraints:

  • Operand input must be a floating-point scalar or tensor.

  • The result type must be a scalar or tensor of type !quant.uniform.

  • The expressedType parameter in the !quant.uniform type of the result
    must match the floating-point type of the input.

  • The operand and result types must be both scalars or both tensors. If
    tensors, they must be both ranked or both unranked. If ranked, both must
    have the same shape, including matching static and dynamic dimensions.

  • If the result uses per-channel quantization, its !quant.uniform type
    must adhere to the Per-axis quantization integrity guidelines.

Examples

  • Quantize a scalar floating-point value
%result = quant.qcast %input : f32 to !quant.uniform<i8:f32, 2.0>
  • Quantize a dynamically shaped tensor of floating-point values
%result = quant.qcast %input : tensor<?xf32> to tensor<?x!quant.uniform<i8:f32, 2.0>>
  • Quantize an unranked tensor using per-axis quantization information
%result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>

Canonicalization

If the operand of a quant.qcast op is produced by a quant.dcast op and the operand type of quant.dcast matches the result type of quant.qcast, the SSA value produced by quant.qcast is replaced with the operand of quant.dcast. In the likely case that quant.qcast was the only consumer of quant.dcast, this transformation renders quant.dcast dead code, allowing for further simplification.

  • Input IR
%1 = quant.dcast %0 : !quant.uniform<i8:f32, 2.0> to f32
%2 = quant.qcast %1 : f32 to !quant.uniform<i8:f32, 2.0>
  • Output IR
%1 = quant.dcast %0 : !quant.uniform<i8:f32, 2.0> to f32  // Possibly dead code
%2 = %0  // I.e., uses of %2 replaced with %0
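The folding rule shared by the quant.dcast, quant.qcast, and quant.scast canonicalizations can be modeled on a toy IR. In MLIR itself this would be written as a C++ canonicalization pattern or fold hook; the Python below only illustrates the matching condition, and the Op dataclass, INVERSE table, and fold_inverse_casts function are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Op:
    name: str      # e.g. "quant.qcast"
    result: str    # SSA name of the result, e.g. "%2"
    operand: str   # SSA name of the single operand
    in_type: str
    out_type: str


# Pairs of casts that cancel out when types line up.
INVERSE = {"quant.qcast": "quant.dcast", "quant.dcast": "quant.qcast",
           "quant.scast": "quant.scast"}


def fold_inverse_casts(ops: list[Op]) -> dict[str, str]:
    """Return {replaced SSA value: replacement} for foldable cast pairs."""
    producers = {op.result: op for op in ops}
    replacements = {}
    for op in ops:
        producer = producers.get(op.operand)
        if (producer is not None
                and INVERSE.get(op.name) == producer.name
                and producer.in_type == op.out_type):
            replacements[op.result] = producer.operand
    return replacements


q = "!quant.uniform<i8:f32, 2.0>"
ops = [Op("quant.dcast", "%1", "%0", q, "f32"),
       Op("quant.qcast", "%2", "%1", "f32", q)]
print(fold_inverse_casts(ops))  # {'%2': '%0'}
```

The producer op is left in place; as in the output IR above, it becomes dead code once all uses of the folded value are replaced.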

Operation quant.scast

This section proposes a thorough specification for operation quant.scast, with updated syntax, semantics, and canonicalization. This documentation is included in the TableGen operation definition on the associated pull request.

Syntax

The following custom assembly format is introduced:

operation ::= `quant.scast` $input attr-dict `:` type($input) `to` type($result)

Semantics

Convert a value from a quantized type to the corresponding signless integer
storage type, or vice versa. This conversion simply involves a
reinterpretation of the input bits and does not involve any data
manipulation.

This operation is semantically equivalent to builtin.unrealized_conversion_cast. It shares the following properties:

  • It performs a bitwise cast with no data conversion.

  • There is no rewrite pattern that converts this operation into lower level dialects. It acts as a temporary type conversion of values handled by interacting transform passes (i.e., --lower-quant-ops and --strip-func-quant-types). After these passes complete, all occurrences of this operation should cancel each other out through the application of canonicalization patterns and dead code elimination.
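Because quant.scast only reinterprets bits, a cast between, say, tensor<?x!quant.uniform<u8:f32, 1.0>> and tensor<?xi8> leaves every byte unchanged; whether a byte reads as 200 or -56 depends only on whether downstream ops treat the signless value as unsigned (e.g., arith.maxui) or signed (e.g., arith.maxsi). The Python sketch below illustrates this with the standard struct module; the helper names are hypothetical.

```python
import struct


def reinterpret_u8_as_i8(value: int) -> int:
    """Read the bits of an unsigned 8-bit value as two's complement."""
    (raw,) = struct.unpack("<b", struct.pack("<B", value))
    return raw


def reinterpret_i8_as_u8(value: int) -> int:
    """Read the bits of a signed 8-bit value as unsigned."""
    (raw,) = struct.unpack("<B", struct.pack("<b", value))
    return raw


print(struct.pack("<B", 200).hex())  # c8
print(reinterpret_u8_as_i8(200))     # -56
print(reinterpret_i8_as_u8(-56))     # 200
```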

Verification

The following syntactic restrictions must be met:

  • Operand input must be a scalar or tensor of a signless integer or
    !quant.uniform type.

  • The result must be a scalar or tensor of a signless integer or
    !quant.uniform type.

  • If the operand is a scalar or tensor of a signless integer type, the
    result must be a scalar or tensor of type !quant.uniform, and vice versa.

  • The operand and result must be both scalars or both tensors. If tensors,
    they must be both ranked or both unranked. If ranked, both must have the
    same shape, including matching static and dynamic dimensions.

  • The width of the storageType parameter of the quantized type of the
    operand or result must match the width of the signless integer type of
    the operand or result.

  • If the operand or result uses per-channel quantization, its
    !quant.uniform type must adhere to the Per-axis quantization integrity
    guidelines.
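The type-pairing, width, and shape constraints above (excluding per-axis integrity, which is checked separately) can be sketched in Python. As modeling assumptions, an element type is ("int", width) for a signless integer or ("quant", storage_width) for !quant.uniform, and a shape is None (scalar), "*" (unranked), or a list of dims with None for dynamic dimensions. With this encoding, the shape rule reduces to plain equality. verify_scast is a hypothetical name.

```python
def verify_scast(in_elem, in_shape, out_elem, out_shape) -> list[str]:
    """Return the violated constraints (an empty list means the op is valid)."""
    errors = []
    kinds = {in_elem[0], out_elem[0]}
    if kinds != {"int", "quant"}:
        errors.append("exactly one of operand/result must be quantized")
    elif in_elem[1] != out_elem[1]:
        errors.append("storage width must match integer width")
    # Both scalars, both unranked, or both ranked with identical
    # static and dynamic dimensions.
    if in_shape != out_shape:
        errors.append("operand and result shapes must match")
    return errors


print(verify_scast(("quant", 8), [None], ("int", 8), [None]))  # []
print(verify_scast(("quant", 8), None, ("int", 16), None))
print(verify_scast(("int", 8), "*", ("int", 8), "*"))
```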

Examples

  • Cast a scalar quantized value into its storage type
%result = quant.scast %input : !quant.uniform<i8:f32, 2.0> to i8
  • Cast a dynamically shaped tensor of quantized values into their storage type
%result = quant.scast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xi8>
  • Cast an unranked tensor of signless integers into a quantized type using per-channel quantization
%result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>

Canonicalization

In a sequence of two quant.scast ops where the second consumes the value produced by the first, the result of the second may be replaced with the operand of the first if the operand type of the first matches the result type of the second. This transformation may render the first op dead.

  • Input IR
%1 = quant.scast %0 : i8 to !quant.uniform<i8:f32, 2.0>
%2 = quant.scast %1 : !quant.uniform<i8:f32, 2.0> to i8
  • Output IR
%1 = quant.scast %0 : i8 to !quant.uniform<i8:f32, 2.0>  // Possibly dead code
%2 = %0  // I.e., uses of %2 replaced with %0

Pass --lower-quant-ops

This pass lowers all instances of quant.dcast and quant.qcast within a function body. Although the pass marks these operations as illegal, it does not declare the entire quant dialect illegal, as the lowering patterns are expected to emit quant.scast operations.

The pass handles quant.dcast and quant.qcast operations similarly, but for simplicity, the examples below will focus solely on the lowering of quant.qcast. The pass differentiates between four cases, each resulting in significantly different code structures, which are detailed next.

Case 1. Per-layer quantization, scalar or ranked input

The IR pattern shown below is emitted when the quantized type of the lowered op uses per-layer quantization. A similar sequence of ops is generated regardless of whether the input is a scalar value or a ranked tensor.

  • Input IR
!qalias = !quant.uniform<i8<-8:7>:f32, 2.0:10>
func.func @f(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> {
  %0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias>
  return %0 : tensor<3x5x!qalias>
}
  • Output IR
func.func @f(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> {

  // Create scale tensor.
  // NOTE: All 'arith.constant' + 'tensor.splat' ops will be canonicalized into
  // a single 'arith.constant' for statically shaped tensors.
  %cst = arith.constant 2.000000e+00 : f32
  %splat = tensor.splat %cst : tensor<3x5xf32>

  // Divide by scale
  %0 = arith.divf %arg0, %splat : tensor<3x5xf32>

  // Create zero point float tensor
  %c10_i8 = arith.constant 10 : i8
  %splat_0 = tensor.splat %c10_i8 : tensor<3x5xi8>
  %1 = arith.sitofp %splat_0 : tensor<3x5xi8> to tensor<3x5xf32>

  // Add zero point
  %2 = arith.addf %0, %1 : tensor<3x5xf32>

  // Convert stored value to integer
  %3 = arith.fptosi %2 : tensor<3x5xf32> to tensor<3x5xi8>

  // Clamp stored value
  %c-8_i8 = arith.constant -8 : i8
  %c7_i8 = arith.constant 7 : i8
  %splat_1 = tensor.splat %c-8_i8 : tensor<3x5xi8>
  %splat_2 = tensor.splat %c7_i8 : tensor<3x5xi8>
  %4 = arith.maxsi %3, %splat_1 : tensor<3x5xi8>
  %5 = arith.minsi %4, %splat_2 : tensor<3x5xi8>

  // Cast stored value to quantized type
  %6 = quant.scast %5 : tensor<3x5xi8> to tensor<3x5x!qalias>
  return %6 : tensor<3x5x!qalias>
}

Case 2. Per-layer quantization, unranked input

The main challenge in adapting the IR emitted in the previous case to unranked tensors is generating scale and zero-point tensor splats whose total size and dimensions match the operation input, as required for type compatibility in the arithmetic computations. The strategy used here is to first flatten the input to a 1D dynamically sized tensor, query its total size, generate 1D tensor splats of the constant scale and zero point, run the arithmetic computations on 1D tensors, and reshape the result back into an unranked tensor with the original shape.

Shape-related computations leverage the shape dialect to maximize opportunities for simplifying the IR patterns that emerge from consecutive quantization ops. Additional canonicalization patterns were recently added to the shape and tensor dialects to facilitate such simplifications (see New canonicalization patterns for shape.shape_of and tensor.reshape).
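The flattening strategy is sound for per-layer quantization because the arithmetic is elementwise and a row-major reshape to 1D and back preserves element order. The Python sketch below models tensors as nested lists (an assumption for illustration) and mirrors the steps of the IR below; shape_of, flatten, reshape, and quantize_unranked are hypothetical helpers. Truncation mirrors arith.fptosi, and no clamping is applied because the !qalias used in this case has the full i8 range.

```python
from math import prod


def shape_of(t) -> list[int]:
    """Shape of a nested-list tensor (cf. shape.shape_of)."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return shape


def flatten(t) -> list:
    """Row-major flattening to 1D (cf. tensor.reshape to a 1D tensor)."""
    if not isinstance(t, list):
        return [t]
    return [x for sub in t for x in flatten(sub)]


def reshape(flat: list, shape: list[int]):
    """Rebuild a nested-list tensor of the given shape from row-major data."""
    if not shape:
        return flat[0]
    step = prod(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def quantize_unranked(t, scale: float, zero_point: int):
    shape = shape_of(t)                              # %0 = shape.shape_of
    flat = flatten(t)                                # reshape to 1D
    assert len(flat) == prod(shape)                  # %1 = shape.num_elements
    q = [int(x / scale + zero_point) for x in flat]  # per-layer arithmetic on 1D data
    return reshape(q, shape)                         # reshape back to original shape


print(quantize_unranked([[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]], 2.0, 10))
# [[11, 12, 13], [14, 15, 16]]
```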

  • Input IR
!qalias = !quant.uniform<i8:f32, 2.0:10>
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
  return %0 : tensor<*x!qalias>
}
  • Output IR
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {

  // Compute shape and size of input tensor
  %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
  %1 = shape.num_elements %0 : tensor<?xindex> -> index

  // Reshape input to 1D dynamically sized tensor
  %from_elements = tensor.from_elements %1 : tensor<1xindex>
  %reshape = tensor.reshape %arg0(%from_elements) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

  // Now we know how to lower a ranked tensor
  %cst = arith.constant 2.000000e+00 : f32
  %c10_i8 = arith.constant 10 : i8
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %reshape, %c0 : tensor<?xf32>
  %splat = tensor.splat %cst[%dim] : tensor<?xf32>
  %2 = arith.divf %reshape, %splat : tensor<?xf32>
  %splat_0 = tensor.splat %c10_i8[%dim] : tensor<?xi8>
  %3 = arith.sitofp %splat_0 : tensor<?xi8> to tensor<?xf32>
  %4 = arith.addf %2, %3 : tensor<?xf32>
  %5 = arith.fptosi %4 : tensor<?xf32> to tensor<?xi8>

  // Convert the result back into an unranked tensor
  %reshape_1 = tensor.reshape %5(%0) : (tensor<?xi8>, tensor<?xindex>) -> tensor<*xi8>

  %6 = quant.scast %reshape_1 : tensor<*xi8> to tensor<*x!qalias>
  return %6 : tensor<*x!qalias>
}

Case 3. Per-channel quantization, ranked input

In per-channel quantization, different scale and zero-point pairs apply to the different items of the input tensor in the dimension designated as the channel. This is accomplished through the use of a linalg.generic operation with its affine map attributes carefully designed to apply the correct scale and zero point to each element of the input tensor.

  • Input IR
!qalias = !quant.uniform<i8<-8:7>:f32:1, {2.0, 3.0}>
func.func @f(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
  %0 = quant.qcast %arg0 : tensor<4x2x5xf32> to tensor<4x2x5x!qalias>
  return %0 : tensor<4x2x5x!qalias>
}
  • Output IR
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1)>

func.func @f(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {

  // Create tensors of scales and zero points
  %cst = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
  %cst_0 = arith.constant dense<0> : tensor<2xi8>

  // Traverse input, scales, zero-point, and output tensors
  %0 = tensor.empty() : tensor<4x2x5xi8>
  %1 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%0 : tensor<4x2x5xi8>) {
  ^bb0(%in: f32, %in_1: f32, %in_2: i8, %out: i8):

    // Divide by scale
    %3 = arith.divf %in, %in_1 : f32

    // Convert zero point to float
    %4 = arith.sitofp %in_2 : i8 to f32

    // Add zero point
    %5 = arith.addf %3, %4 : f32

    // Convert stored value to integer
    %6 = arith.fptosi %5 : f32 to i8

    // Clamp stored value
    %c-8_i8 = arith.constant -8 : i8
    %c7_i8 = arith.constant 7 : i8
    %7 = arith.maxsi %6, %c-8_i8 : i8
    %8 = arith.minsi %7, %c7_i8 : i8
    linalg.yield %8 : i8
  } -> tensor<4x2x5xi8>

  // Cast stored values into quantized type
  %2 = quant.scast %1 : tensor<4x2x5xi8> to tensor<4x2x5x!qalias>
  return %2 : tensor<4x2x5x!qalias>
}

Case 4. Per-channel quantization, unranked input

Handling unranked tensors and per-channel quantization poses an additional challenge. We need to determine the position of the channel within an unranked tensor without knowing the exact number of dimensions on either side. However, we can assume that the channel dimension exists in the input tensor, or else we would be dealing with malformed IR.

The strategy for dealing with unranked tensors involves flattening the input into a 3D tensor. All dimensions to the left of the channel dimension are collapsed into the left-most dimension of the reshaped tensor, the channel dimension itself becomes the middle dimension, and all dimensions to the right of the channel dimension are collapsed into the right-most dimension. The arithmetic operations related to quantization are performed on this 3D tensor, and the result is finally reshaped back into the original unranked shape.
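The collapsed shape can be computed from the runtime shape alone, and the channel coordinate (d1) of each element in the 3D view coincides with its coordinate along the original channel axis. The Python sketch below checks this for a hypothetical 4x2x3x5 input with channel axis 2 (3 channels, as in the !qalias below); collapse_around_axis, channel_in_collapsed, and channel_in_original are hypothetical helpers, and row-major layout is assumed.

```python
from math import prod


def collapse_around_axis(shape: list[int], channel_axis: int) -> list[int]:
    """Collapse a shape into [left, channel, right] around the channel axis.

    Mirrors the IR: shape.split_at at channel_axis and channel_axis + 1,
    followed by shape.num_elements on the outer pieces.
    """
    left = prod(shape[:channel_axis])
    right = prod(shape[channel_axis + 1:])
    return [left, shape[channel_axis], right]


def channel_in_collapsed(flat_index: int, collapsed: list[int]) -> int:
    """d1 coordinate of a row-major element in the 3D view."""
    _, channels, right = collapsed
    return (flat_index // right) % channels


def channel_in_original(flat_index: int, shape: list[int], axis: int) -> int:
    """Unravel a row-major flat index and return its coordinate on `axis`."""
    for dim in reversed(range(len(shape))):
        coord = flat_index % shape[dim]
        if dim == axis:
            return coord
        flat_index //= shape[dim]


shape = [4, 2, 3, 5]
collapsed = collapse_around_axis(shape, 2)
print(collapsed)  # [8, 3, 5]
```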

  • Input IR
!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
  return %0 : tensor<*x!qalias>
}
  • Output IR
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1)>

func.func @f(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {

  // Save shape of original unranked tensor
  %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>

  // Calculate tensor size on the left of channel dimension
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %head, %tail = "shape.split_at"(%0, %c2) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
  %1 = shape.num_elements %head : tensor<?xindex> -> index

  // Calculate tensor size on the right of channel dimension
  %head_0, %tail_1 = "shape.split_at"(%0, %c3) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
  %2 = shape.num_elements %tail_1 : tensor<?xindex> -> index

  // Reshape input to 3D tensor
  %c3_2 = arith.constant 3 : index
  %from_elements = tensor.from_elements %1, %c3_2, %2 : tensor<3xindex>
  %reshape = tensor.reshape %arg0(%from_elements) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x3x?xf32>

  // Scale and zero-point tensors
  %cst = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
  %cst_3 = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
  %c0 = arith.constant 0 : index

  // Initialize output tensor
  %dim = tensor.dim %reshape, %c0 : tensor<?x3x?xf32>
  %c2_4 = arith.constant 2 : index
  %dim_5 = tensor.dim %reshape, %c2_4 : tensor<?x3x?xf32>
  %3 = tensor.empty(%dim, %dim_5) : tensor<?x3x?xi8>

  // Quantize
  %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%reshape, %cst, %cst_3 : tensor<?x3x?xf32>, tensor<3xf32>, tensor<3xi8>) outs(%3 : tensor<?x3x?xi8>) {
  ^bb0(%in: f32, %in_7: f32, %in_8: i8, %out: i8):
    %6 = arith.divf %in, %in_7 : f32
    %7 = arith.sitofp %in_8 : i8 to f32
    %8 = arith.addf %6, %7 : f32
    %9 = arith.fptosi %8 : f32 to i8
    linalg.yield %9 : i8
  } -> tensor<?x3x?xi8>

  // Reshape output to original unranked tensor
  %reshape_6 = tensor.reshape %4(%0) : (tensor<?x3x?xi8>, tensor<?xindex>) -> tensor<*xi8>
  %5 = quant.scast %reshape_6 : tensor<*xi8> to tensor<*x!qalias>
  return %5 : tensor<*x!qalias>
}

Pass --strip-func-quant-types

This module-level pass performs the following actions:

  • It modifies the signature of every top-level func.func op in the module by substituting every occurrence of the !quant.uniform data type with its storage type. This applies to both function declarations and definitions.

  • For each input argument in a function definition, a quant.scast op is inserted at the top of the function body, converting the new argument type into the original !quant.uniform type. All uses of the original function argument are replaced with the SSA value produced by the new quant.scast op.

  • Every occurrence of the return op in the body of the function is preceded by a new quant.scast op for each return value of type !quant.uniform, converting the quantized type into its storage type. The occurrence of the quantized value in the return op is replaced with the SSA value produced by the quant.scast op.

  • All occurrences of func.call and func.call_indirect are adjusted to the new signature of the invoked functions, which now lack any occurrence of quantized values. For every quantized argument originally passed to a function call, an additional quant.scast is introduced, and its integer result is passed to the function call instead of the original quantized value. Similarly, a quant.scast op is also introduced for every quantized result generated by the function call, converting the integer type returned by the new function signature back into the original quantized type.
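The signature rewrite in the first step can be sketched as a textual substitution in Python. This is only an illustration under stated assumptions: types appear in expanded form (aliases such as !qalias are not resolved), and each quantized type is replaced with the signless integer of the same storage width (so u8 becomes i8), consistent with the quant.scast verification rules. The regex and the strip_quant_types/strip_signature helpers are hypothetical.

```python
import re

# Matches an expanded !quant.uniform type and captures its storage bit width.
# The optional group skips storage bounds such as <-8:7>.
QUANT_UNIFORM = re.compile(r"!quant\.uniform<[iu](\d+)(?:<[^>]*>)?:[^>]*>")


def strip_quant_types(type_str: str) -> str:
    """Replace each !quant.uniform with the signless integer of its storage width."""
    return QUANT_UNIFORM.sub(lambda m: f"i{m.group(1)}", type_str)


def strip_signature(inputs: list[str], results: list[str]):
    return ([strip_quant_types(t) for t in inputs],
            [strip_quant_types(t) for t in results])


q = "tensor<3x!quant.uniform<i8:f32, 1.0>>"
print(strip_signature([q, q], [q]))
# (['tensor<3xi8>', 'tensor<3xi8>'], ['tensor<3xi8>'])
```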

The following example illustrates the application of this pass on a function with 2 inputs and 1 output, all of them of a quantized type.

  • Input IR
!qalias = !quant.uniform<i8:f32, 1.0>

func.func @predict(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> tensor<3x!qalias> {
  %sum = "ml.add"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
  return %sum : tensor<3x!qalias>
}
  • Output IR
!qalias = !quant.uniform<i8:f32, 1.0>

func.func @predict(%arg0_stripped: tensor<3xi8>, %arg1_stripped: tensor<3xi8>) -> tensor<3xi8> {

  // Conversion of function arguments
  %arg0 = "quant.scast"(%arg0_stripped): (tensor<3xi8>) -> tensor<3x!qalias>
  %arg1 = "quant.scast"(%arg1_stripped): (tensor<3xi8>) -> tensor<3x!qalias>

  // Function body
  %sum = "ml.add"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>

  // Conversion of return values
  %sum_stripped = "quant.scast"(%sum): (tensor<3x!qalias>) -> tensor<3xi8>
  return %sum_stripped : tensor<3xi8>
}
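The type rewriting performed on function headers can be approximated with the toy Python model below. It is a sketch under stated assumptions, not the pass implementation: types are plain strings, the regex only handles the simple non-nested !quant.uniform forms used in this section, type aliases such as !qalias are not expanded, and strip_quant and strip_signature are hypothetical helper names.

```python
import re
from typing import List, Tuple

# Matches a non-nested quantized element type, e.g.
#   !quant.uniform<i8:f32, 1.0>  or  !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>
# and captures its storage type (here "i8").
QUANT_RE = re.compile(r"!quant\.uniform<(\w+):[^<>]*>")

def strip_quant(type_str: str) -> str:
    """Replace every !quant.uniform element type with its storage type."""
    return QUANT_RE.sub(lambda m: m.group(1), type_str)

def strip_signature(inputs: List[str], results: List[str]) -> Tuple[List[str], List[str], List[str]]:
    """Return the stripped signature plus the quant.scast ops (as text) that
    would be inserted at the top of the function body."""
    new_inputs = [strip_quant(t) for t in inputs]
    new_results = [strip_quant(t) for t in results]
    casts = [
        f'%arg{i} = "quant.scast"(%arg{i}_stripped) : ({new}) -> {old}'
        for i, (old, new) in enumerate(zip(inputs, new_inputs))
        if old != new  # only quantized arguments need a cast
    ]
    return new_inputs, new_results, casts

q = "tensor<3x!quant.uniform<i8:f32, 1.0>>"
ins, outs, casts = strip_signature([q, q], [q])
print(ins, outs)  # ['tensor<3xi8>', 'tensor<3xi8>'] ['tensor<3xi8>']
```

Call sites and return ops would need the symmetric treatment described in the list above, which this sketch leaves out.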

Workflow example

Consider a hypothetical user-defined machine learning dialect ml that supports quantized types. Consider also a user-defined pass --convert-ml-to-arith in charge of lowering ml ops into the arith dialect, while accounting for the possibility that op arguments have various (possibly different) quantized types.

  • In the following input IR, function @multiply_add uses hypothetical ops ml.mul and ml.add to process the 1D quantized tensors supplied as function arguments, while the result is provided as the function’s return value.
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>

func.func @multiply_add(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>, %arg2: tensor<3x!qalias>) -> tensor<3x!qalias> {
  %product = "ml.mul"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
  %sum = "ml.add"(%product, %arg2) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
  return %sum : tensor<3x!qalias>
}
  • Pass --convert-ml-to-arith is expected to handle any combination of quantized types accepted by the ml op verifiers. Let us assume, however, that per-channel quantization, as used in this example, is currently not specifically dealt with by the pass, either because it is a temporarily unsupported feature, or because such cases occur so infrequently that an optimized lowering based on quantized integer arithmetic is not worth the costly implementation effort. The pass may then rely on a dequantization approach for unsupported corner cases, producing the code below. Each lowering pattern in the pass dequantizes input operands, applies the default floating-point lowering pattern for the ml op, and quantizes the floating-point result back to match the original result type.
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>

func.func @multiply_add(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>, %arg2: tensor<3x!qalias>) -> tensor<3x!qalias> {

  %arg0_f32 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
  %arg1_f32 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
  %product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>
  %product = "quant.qcast"(%product_f32) : (tensor<3xf32>) -> tensor<3x!qalias>

  %product_f32_1 = "quant.dcast"(%product) : (tensor<3x!qalias>) -> tensor<3xf32>
  %arg2_f32 = "quant.dcast"(%arg2) : (tensor<3x!qalias>) -> tensor<3xf32>
  %sum_f32 = arith.addf %product_f32_1, %arg2_f32 : tensor<3xf32>
  %sum = "quant.qcast"(%sum_f32) : (tensor<3xf32>) -> tensor<3x!qalias>

  return %sum : tensor<3x!qalias>
}
  • Applying lowering patterns on individual ops produces redundant conversions that can be optimized before applying further transformations. In the code above, the quantization of the result of arith.mulf is immediately dequantized in preparation for the following arith.addf. Applying a canonicalization pass at this point eliminates this redundancy, leading to the following code.
!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>

func.func @multiply_add(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>, %arg2: tensor<3x!qalias>) -> tensor<3x!qalias> {

  %arg0_f32 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
  %arg1_f32 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
  %product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>

  %arg2_f32 = "quant.dcast"(%arg2) : (tensor<3x!qalias>) -> tensor<3xf32>
  %sum_f32 = arith.addf %product_f32, %arg2_f32 : tensor<3xf32>
  %sum = "quant.qcast"(%sum_f32) : (tensor<3xf32>) -> tensor<3x!qalias>

  return %sum : tensor<3x!qalias>
}
  • The inserted quant ops may now be lowered with the proposed --lower-quant-ops pass. For simplicity, the code below replaces the actual arithmetic computations generated by the quant.dcast and quant.qcast rewrite patterns with function calls to @dequantize and @quantize, respectively.
func.func private @dequantize(%arg0: tensor<3xi8>) -> tensor<3xf32>
func.func private @quantize(%arg0: tensor<3xf32>) -> tensor<3xi8>

!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>

func.func @multiply_add(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>, %arg2: tensor<3x!qalias>) -> tensor<3x!qalias> {

  %arg0_i8 = "quant.scast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xi8>
  %arg0_f32 = func.call @dequantize(%arg0_i8) : (tensor<3xi8>) -> tensor<3xf32>

  %arg1_i8 = "quant.scast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xi8>
  %arg1_f32 = func.call @dequantize(%arg1_i8) : (tensor<3xi8>) -> tensor<3xf32>

  %product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>

  %arg2_i8 = "quant.scast"(%arg2) : (tensor<3x!qalias>) -> tensor<3xi8>
  %arg2_f32 = func.call @dequantize(%arg2_i8) : (tensor<3xi8>) -> tensor<3xf32>

  %sum_f32 = arith.addf %product_f32, %arg2_f32 : tensor<3xf32>

  %sum_i8 = func.call @quantize(%sum_f32) : (tensor<3xf32>) -> tensor<3xi8>
  %sum = "quant.scast"(%sum_i8) : (tensor<3xi8>) -> tensor<3x!qalias>

  return %sum : tensor<3x!qalias>
}
  • At this point, all uses of quantized types are limited to the function header and the newly introduced quant.scast ops. These ops guarantee the syntactic integrity of the code at this intermediate state of the workflow, but they perform no meaningful data conversions. Therefore, no remaining op in the code relies on the quantized types to fine-tune its behavior. In preparation for eliminating the remaining occurrences of quantized types, the function header is now transformed with pass --strip-func-quant-types.
func.func private @dequantize(%arg0: tensor<3xi8>) -> tensor<3xf32>
func.func private @quantize(%arg0: tensor<3xf32>) -> tensor<3xi8>

!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>

func.func @multiply_add(%arg0_stripped: tensor<3xi8>, %arg1_stripped: tensor<3xi8>, %arg2_stripped: tensor<3xi8>) -> tensor<3xi8> {
  %arg0 = "quant.scast"(%arg0_stripped) : (tensor<3xi8>) -> tensor<3x!qalias>
  %arg1 = "quant.scast"(%arg1_stripped) : (tensor<3xi8>) -> tensor<3x!qalias>
  %arg2 = "quant.scast"(%arg2_stripped) : (tensor<3xi8>) -> tensor<3x!qalias>

  %arg0_i8 = "quant.scast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xi8>
  %arg0_f32 = func.call @dequantize(%arg0_i8) : (tensor<3xi8>) -> tensor<3xf32>

  %arg1_i8 = "quant.scast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xi8>
  %arg1_f32 = func.call @dequantize(%arg1_i8) : (tensor<3xi8>) -> tensor<3xf32>

  %product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>

  %arg2_i8 = "quant.scast"(%arg2) : (tensor<3x!qalias>) -> tensor<3xi8>
  %arg2_f32 = func.call @dequantize(%arg2_i8) : (tensor<3xi8>) -> tensor<3xf32>

  %sum_f32 = arith.addf %product_f32, %arg2_f32 : tensor<3xf32>

  %sum_i8 = func.call @quantize(%sum_f32) : (tensor<3xf32>) -> tensor<3xi8>
  %sum = "quant.scast"(%sum_i8) : (tensor<3xi8>) -> tensor<3x!qalias>

  %sum_stripped = "quant.scast"(%sum) : (tensor<3x!qalias>) -> tensor<3xi8>
  return %sum_stripped : tensor<3xi8>
}
  • At this point, quantized types only occur in quant.scast ops, all of which should cancel each other out through the application of canonicalization patterns and DCE. This will always be the case as long as the original application of pass --convert-ml-to-arith handles all occurrences of quantized types correctly.
func.func private @dequantize(%arg0: tensor<3xi8>) -> tensor<3xf32>
func.func private @quantize(%arg0: tensor<3xf32>) -> tensor<3xi8>

!qalias = !quant.uniform<i8:f32:0, {2.0, 3.0, 4.0}>

func.func @multiply_add(%arg0_stripped: tensor<3xi8>, %arg1_stripped: tensor<3xi8>, %arg2_stripped: tensor<3xi8>) -> tensor<3xi8> {
  %arg0_f32 = func.call @dequantize(%arg0_stripped) : (tensor<3xi8>) -> tensor<3xf32>
  %arg1_f32 = func.call @dequantize(%arg1_stripped) : (tensor<3xi8>) -> tensor<3xf32>
  %product_f32 = arith.mulf %arg0_f32, %arg1_f32 : tensor<3xf32>

  %arg2_f32 = func.call @dequantize(%arg2_stripped) : (tensor<3xi8>) -> tensor<3xf32>
  %sum_f32 = arith.addf %product_f32, %arg2_f32 : tensor<3xf32>

  %sum_i8 = func.call @quantize(%sum_f32) : (tensor<3xf32>) -> tensor<3xi8>
  return %sum_i8: tensor<3xi8>
}
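The two canonicalization steps used in this workflow, folding dcast(qcast(x)) into x and scast(scast(x)) into x, followed by removal of unused side-effect-free ops, can be modeled with the toy Python rewriter below. It is a hedged sketch, not the MLIR C++ patterns: SSA values and types are plain strings, and the Op class and canonicalize function are hypothetical. Note that the first fold skips a quantize/dequantize round trip, so the folded intermediate value keeps its full floating-point precision.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass
class Op:
    result: str          # SSA value defined by the op
    name: str            # e.g. "quant.dcast", "arith.addf", "func.call"
    operands: List[str]
    in_type: str         # type of operands[0]
    out_type: str        # result type

# Ops treated as side-effect free (erasable when unused) in this toy model.
# func.call is deliberately absent: calls are not assumed to be pure.
PURE = {"quant.dcast", "quant.qcast", "quant.scast", "arith.mulf", "arith.addf"}
FOLDABLE = {("quant.dcast", "quant.qcast"), ("quant.scast", "quant.scast")}

def canonicalize(ops: List[Op], live_outs: List[str]) -> Tuple[List[Op], List[str]]:
    producer: Dict[str, Op] = {}
    replace: Dict[str, str] = {}

    def resolve(v: str) -> str:
        while v in replace:
            v = replace[v]
        return v

    for op in ops:
        op.operands = [resolve(v) for v in op.operands]  # rewire folded uses
        producer[op.result] = op
        src = producer.get(op.operands[0]) if op.operands else None
        # Fold cast(cast(x)) -> x only when x already has the outer result type.
        if src is not None and src.in_type == op.out_type and (op.name, src.name) in FOLDABLE:
            replace[op.result] = src.operands[0]

    # Walk backwards, erasing pure ops whose results are never used.
    live = [resolve(v) for v in live_outs]
    used, kept = set(live), []
    for op in reversed(ops):
        if op.name in PURE and op.result not in used:
            continue
        used.update(op.operands)
        kept.append(op)
    return kept[::-1], live

Q, F = "tensor<3x!qalias>", "tensor<3xf32>"
body = [
    Op("%arg0_f32", "quant.dcast", ["%arg0"], Q, F),
    Op("%arg1_f32", "quant.dcast", ["%arg1"], Q, F),
    Op("%product_f32", "arith.mulf", ["%arg0_f32", "%arg1_f32"], F, F),
    Op("%product", "quant.qcast", ["%product_f32"], F, Q),
    Op("%product_f32_1", "quant.dcast", ["%product"], Q, F),
    Op("%arg2_f32", "quant.dcast", ["%arg2"], Q, F),
    Op("%sum_f32", "arith.addf", ["%product_f32_1", "%arg2_f32"], F, F),
    Op("%sum", "quant.qcast", ["%sum_f32"], F, Q),
]
kept, live = canonicalize(body, ["%sum"])
print([op.result for op in kept])
# ['%arg0_f32', '%arg1_f32', '%product_f32', '%arg2_f32', '%sum_f32', '%sum']
```

Running the same function on the scast chains produced by --strip-func-quant-types removes every quant.scast op and leaves only the func.call and arith ops.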

Common subexpression elimination (CSE)

The workflow above benefits from the application of a CSE and DCE sequence right before the application of pass --lower-quant-ops. This gets rid of redundant quant.qcast and quant.dcast ops, avoiding the processing cost of their rewrite patterns, as well as the possibly significant additional size of the resulting IR. By simply labeling ops quant.qcast and quant.dcast as side-effect-free (pure) in their specification, this scenario is automatically handled by the standard CSE and DCE passes in MLIR. The following example illustrates this scenario.

  • In the following input IR, function @division computes both the quotient and remainder of the given quantized inputs through the use of hypothetical ops ml.div and ml.mod.
!qalias = !quant.uniform<i8:f32, 2.0>

func.func @division(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> (tensor<3x!qalias>, tensor<3x!qalias>) {
  %quotient = "ml.div"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
  %remainder = "ml.mod"(%arg0, %arg1) : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
  return %quotient, %remainder : tensor<3x!qalias>, tensor<3x!qalias>
}
  • Let us assume that user-defined pass --convert-ml-to-arith handles quantized inputs for ml.div and ml.mod by dequantizing them, then processing floating-point operands with arith.divf and arith.remf ops, and finally quantizing their results back into the original result types. The resulting IR includes 2 pairs of quant.dcast ops, both of which dequantize the same input arguments.
!qalias = !quant.uniform<i8:f32, 2.0>

func.func @division(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> (tensor<3x!qalias>, tensor<3x!qalias>) {
  %arg0_f32 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
  %arg1_f32 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
  %quotient_f32 = arith.divf %arg0_f32, %arg1_f32 : tensor<3xf32>
  %quotient = "quant.qcast"(%quotient_f32) : (tensor<3xf32>) -> tensor<3x!qalias>

  %arg0_f32_1 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
  %arg1_f32_1 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
  %remainder_f32 = arith.remf %arg0_f32_1, %arg1_f32_1 : tensor<3xf32>
  %remainder = "quant.qcast"(%remainder_f32) : (tensor<3xf32>) -> tensor<3x!qalias>

  return %quotient, %remainder : tensor<3x!qalias>, tensor<3x!qalias>
}
  • CSE makes arith.remf consume the originally computed dequantized inputs %arg0_f32 and %arg1_f32, which turns the producers of %arg0_f32_1 and %arg1_f32_1 into dead code. DCE then removes the second pair of quant.dcast ops.
!qalias = !quant.uniform<i8:f32, 2.0>

func.func @division(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> (tensor<3x!qalias>, tensor<3x!qalias>) {
  %arg0_f32 = "quant.dcast"(%arg0) : (tensor<3x!qalias>) -> tensor<3xf32>
  %arg1_f32 = "quant.dcast"(%arg1) : (tensor<3x!qalias>) -> tensor<3xf32>
  %quotient_f32 = arith.divf %arg0_f32, %arg1_f32 : tensor<3xf32>
  %quotient = "quant.qcast"(%quotient_f32) : (tensor<3xf32>) -> tensor<3x!qalias>

  %remainder_f32 = arith.remf %arg0_f32, %arg1_f32 : tensor<3xf32>
  %remainder = "quant.qcast"(%remainder_f32) : (tensor<3xf32>) -> tensor<3x!qalias>

  return %quotient, %remainder : tensor<3x!qalias>, tensor<3x!qalias>
}
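The CSE-then-DCE sequence above can be sketched with a small Python model in which ops are (result, name, operands, type) tuples. This is an illustration under assumptions, not MLIR's CSE or DCE passes: the PURE set reflects the proposal's labeling of quant.dcast and quant.qcast as pure, and cse, dce, and cse_then_dce are hypothetical names.

```python
from typing import Dict, List, Tuple

# (result, op name, operands, result type), in program order.
Op = Tuple[str, str, Tuple[str, ...], str]

# Ops this toy model treats as pure; quant.dcast/qcast are pure per the proposal.
PURE = {"quant.dcast", "quant.qcast", "arith.divf", "arith.remf"}

def cse(ops: List[Op]) -> Tuple[List[Op], Dict[str, str]]:
    """Redirect uses of duplicate pure ops to the first equivalent op."""
    seen: Dict[tuple, str] = {}
    replace: Dict[str, str] = {}
    out: List[Op] = []
    for result, name, operands, ty in ops:
        operands = tuple(replace.get(v, v) for v in operands)
        key = (name, operands, ty)
        if name in PURE and key in seen:
            replace[result] = seen[key]  # later uses read the first result
        else:
            seen.setdefault(key, result)
        out.append((result, name, operands, ty))
    return out, replace

def dce(ops: List[Op], live_outs: List[str]) -> List[Op]:
    """Erase pure ops whose results are never used, walking backwards."""
    used = set(live_outs)
    kept: List[Op] = []
    for op in reversed(ops):
        result, name, operands, _ = op
        if name in PURE and result not in used:
            continue
        used.update(operands)
        kept.append(op)
    return kept[::-1]

def cse_then_dce(ops: List[Op], live_outs: List[str]) -> List[Op]:
    ops, replace = cse(ops)
    return dce(ops, [replace.get(v, v) for v in live_outs])

Q, F = "tensor<3x!qalias>", "tensor<3xf32>"
division = [
    ("%arg0_f32", "quant.dcast", ("%arg0",), F),
    ("%arg1_f32", "quant.dcast", ("%arg1",), F),
    ("%quotient_f32", "arith.divf", ("%arg0_f32", "%arg1_f32"), F),
    ("%quotient", "quant.qcast", ("%quotient_f32",), Q),
    ("%arg0_f32_1", "quant.dcast", ("%arg0",), F),
    ("%arg1_f32_1", "quant.dcast", ("%arg1",), F),
    ("%remainder_f32", "arith.remf", ("%arg0_f32_1", "%arg1_f32_1"), F),
    ("%remainder", "quant.qcast", ("%remainder_f32",), Q),
]
result = cse_then_dce(division, ["%quotient", "%remainder"])
print([op[0] for op in result])
# ['%arg0_f32', '%arg1_f32', '%quotient_f32', '%quotient', '%remainder_f32', '%remainder']
```

Including the result type in the CSE key mirrors the requirement that two ops can only be merged if they produce the same type.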

Application

A possible application of the presented workflow is the Tensorflow Lite (tfl) lowering pipeline. While the Tensorflow repository offers a fairly advanced TFL-to-TOSA transform pipeline, support for quantized types is currently limited by multiple factors, including semantic restrictions of the tosa dialect and the lack of dedicated lowering logic emitting integer arithmetic.

Lowering a Tensorflow Lite model

To better illustrate the challenges of supporting quantized types in the tfl dialect, the following example shows the steps involved in lowering a Tensorflow Lite model using the TFL-to-TOSA pipeline.

  • In this input IR, function @predict simply adds two input quantized values and returns their sum, also using a quantized format.
!qalias = !quant.uniform<i8:f32, 2.0>

func.func @predict(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> tensor<3x!qalias> {
  %result = "tfl.add"(%arg0, %arg1) { fused_activation_function = "NONE" } : (tensor<3x!qalias>, tensor<3x!qalias>) -> tensor<3x!qalias>
  return %result : tensor<3x!qalias>
}
  • Pass --tfl-to-tosa-pipeline in Tensorflow tool tflite-opt lowers this code to the tosa dialect. It is intended to honor the semantics of tfl ops with quantized input operands, and it does so successfully in this example. For the tfl.add op, this involves rescaling input operands to a common intermediate quantized type, performing an addition using integer arithmetic, and rescaling the result back to the original quantized type.
!qalias = !quant.uniform<i8:f32, 2.0>

func.func @predict(%arg0: tensor<3x!qalias>, %arg1: tensor<3x!qalias>) -> tensor<3x!qalias> {
  %0 = tosa.rescale %arg0 { ... } : (tensor<3x!qalias>) -> tensor<3xi32>
  %1 = tosa.rescale %arg1 { ...  } : (tensor<3x!qalias>) -> tensor<3xi32>
  %2 = tosa.add %0, %1 : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
  %3 = tosa.rescale %2 { ... } : (tensor<3xi32>) -> tensor<3x!qalias>
  return %3 : tensor<3x!qalias>
}
  • In the code above, quantized types are exclusively handled and manipulated by tosa.rescale operations, while other generated ops operate directly on the associated integer storage types. The tosa.rescale op expresses its behavior fully through explicit attributes, without relying on implicit information contained in the quantized types of its operands or result. Therefore, quantized types no longer provide critical semantic information in the IR, and may be cast into their associated storage types. This can be accomplished by running the above pass with additional option --tfl-to-tosa-pipeline=target-compilation-backend. All occurrences of type !quant.uniform<i8:f32, 2.0> are converted to i8, which allows the remainder of the lowering pipeline to interpret all computations as pure integer arithmetic.
func.func @predict(%arg0: tensor<3xi8>, %arg1: tensor<3xi8>) -> tensor<3xi8> {
  %0 = tosa.rescale %arg0 { ... } : (tensor<3xi8>) -> tensor<3xi32>
  %1 = tosa.rescale %arg1 { ... } : (tensor<3xi8>) -> tensor<3xi32>
  %2 = tosa.add %0, %1 : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
  %3 = tosa.rescale %2 { ... } : (tensor<3xi32>) -> tensor<3xi8>
  return %3 : tensor<3xi8>
}
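To see why the stripped IR still computes a quantized addition, the rescale, add, rescale sequence can be modeled numerically. This Python sketch is only an approximation under stated assumptions: the tosa.rescale attributes are elided in the IR above, so the 10-bit shift used for the intermediate scale is a hypothetical choice, and a float multiplier stands in for TOSA's integer multiplier and shift encoding.

```python
from typing import List

I8_MIN, I8_MAX = -128, 127
I32_MIN, I32_MAX = -(2**31), 2**31 - 1

def rescale(values: List[int], multiplier: float, lo: int, hi: int) -> List[int]:
    """Toy rescale: scale by a real multiplier, round, and saturate to [lo, hi]."""
    return [max(lo, min(hi, round(v * multiplier))) for v in values]

def quantized_add(a: List[int], b: List[int],
                  in_scale: float = 2.0, out_scale: float = 2.0, shift: int = 10) -> List[int]:
    """Model of rescale -> integer add -> rescale on i8 storage values (zero point 0)."""
    mid_scale = in_scale / (1 << shift)  # hypothetical common intermediate scale
    a32 = rescale(a, in_scale / mid_scale, I32_MIN, I32_MAX)   # tosa.rescale
    b32 = rescale(b, in_scale / mid_scale, I32_MIN, I32_MAX)   # tosa.rescale
    s32 = [x + y for x, y in zip(a32, b32)]                      # tosa.add
    return rescale(s32, mid_scale / out_scale, I8_MIN, I8_MAX)  # tosa.rescale

print(quantized_add([1, -3, 100], [2, 5, 100]))  # [3, 2, 127]
```

With scale 2.0 on both sides, the storage values 1 and 2 represent the real values 2.0 and 4.0, whose sum 6.0 is stored as 3; the third element saturates at the i8 maximum.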

Limitations while lowering quantized types

While the TFL-to-TOSA pipeline correctly handles a variety of ops with quantized inputs, it currently fails on less frequently occurring cases, such as certain storage type bit widths or per-axis quantization. For unsupported quantized types, the pipeline currently either refuses to lower the affected tfl op altogether or, even worse, substitutes it with a sequence of ops that silently yields wrong answers. Below are some examples.

  • Operation tfl.log_softmax fails to lower when its inputs are quantized.
!qalias = !quant.uniform<i8:f32, 0.5>
func.func @main(%arg0: tensor<2x3x!qalias>) -> tensor<2x3x!qalias> {
  %0 = "tfl.log_softmax"(%arg0) : (tensor<2x3x!qalias>) -> tensor<2x3x!qalias>
  return %0 : tensor<2x3x!qalias>
}
test.mlir:3:8: error: 'tfl.log_softmax' op : illegal op still exists
  %0 = "tfl.log_softmax"(%arg0) : (tensor<2x3x!qalias>) -> tensor<2x3x!qalias>
       ^
test.mlir:3:8: note: see current operation: %0 = "tfl.log_softmax"(%arg0) : (tensor<2x3xi8>) -> tensor<2x3xi8>
test.mlir:2:1: error: The following illegal operations still remain:
        tfl.log_softmax (count: 1)

func.func @main(%arg0: tensor<2x3x!qalias>) -> tensor<2x3x!qalias> {
^
  • Operation tfl.pad produces invalid code when lowered to the tosa dialect with a quantized first input argument.
!qalias = !quant.uniform<i8:f32, 0.5>
func.func @main(%arg0: tensor<2x2x!qalias>) -> tensor<4x6x!qalias> {
  %padding = "tfl.pseudo_const"() {value = dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  %result = "tfl.pad"(%arg0, %padding) : (tensor<2x2x!qalias>, tensor<2x2xi32>) -> tensor<4x6x!qalias>
  return %result : tensor<4x6x!qalias>
}
func.func @main(%arg0: tensor<2x2xi8>) -> tensor<4x6xi8> {
  %0 = "tosa.const"() <{value = dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
  %1 = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<i8>
  %2 = tosa.pad %arg0, %0, %1 {quantization_info = #tosa.pad_quant<input_zp = 0>} : (tensor<2x2xi8>, tensor<2x2xi32>, tensor<i8>) -> tensor<4x6xi8>
  return %2 : tensor<4x6xi8>
}

The examples above are simplified use cases of real ops found in pre-trained models available on Kaggle. Similar errors have been encountered for other ops, such as tfl.l2_normalization, tfl.split, or tfl.arg_min. While it is conceivable to address these issues individually, it is much harder to guarantee that all valid quantized type combinations are properly handled, especially in less frequently occurring ops not yet encountered in publicly available pre-trained models.

Overcoming limited quantization support

The features proposed in this document may help address unsupported parameter combinations of the !quant.uniform type by adopting the following general structure in a rewrite pattern for a tfl op:

  • Handle custom cases of quantized input operands with an algorithm that operates directly on the storage type of the quantized inputs using integer arithmetic.

  • If input arguments are quantized but their specific type parameters were not considered above, dequantize input arguments by converting them to floating-point values with quant.dcast ops.

  • Handle the general case that emits code for floating-point inputs.

  • If the original tfl op uses a quantized type, quantize the resulting floating-point value with a quant.qcast op.

An implementation may decide to offer a dequantization-based fallback strategy as an optional feature through a command-line modifier for the lowering transform pipeline.
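The structure above can be sketched numerically for a single elementwise add on 1-D tensors. This Python model is hypothetical throughout (Uniform, dequantize, quantize, and lower_add are illustrative names, not tfl or quant APIs), it assumes round-half-to-even for quantization, and it treats only per-layer types with zero point 0 as the custom integer case, falling back to dequantize, compute, and quantize for everything else, including per-axis types. The allow_fallback flag stands in for the optional command-line modifier.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

@dataclass(frozen=True)
class Uniform:
    """Hypothetical stand-in for an i8 !quant.uniform type on a 1-D tensor."""
    scales: Sequence[float]      # one entry (per-layer) or one per channel
    zero_points: Sequence[int]
    axis: Optional[int] = None   # None means per-layer quantization
    qmin: int = -128             # i8 storage range
    qmax: int = 127

    def params(self, index: int) -> Tuple[float, int]:
        # With axis 0 on a 1-D tensor, the channel is the element index.
        c = 0 if self.axis is None else index
        return self.scales[c], self.zero_points[c]

def dequantize(q: List[int], t: Uniform) -> List[float]:  # role of quant.dcast
    out = []
    for i, v in enumerate(q):
        scale, zp = t.params(i)
        out.append((v - zp) * scale)
    return out

def quantize(x: List[float], t: Uniform) -> List[int]:  # role of quant.qcast
    out = []
    for i, v in enumerate(x):
        scale, zp = t.params(i)
        # Python's round() rounds half to even; the rounding mode is an assumption.
        out.append(max(t.qmin, min(t.qmax, round(v / scale) + zp)))
    return out

def lower_add(a: List[int], b: List[int], t: Uniform, allow_fallback: bool = True):
    """Emulate a quantized add whose operands and result all have type t."""
    # Custom integer case: per-layer with zero point 0, where
    # s*a + s*b == s*(a + b), so storage values can be added directly.
    if t.axis is None and t.zero_points[0] == 0:
        return [max(t.qmin, min(t.qmax, x + y)) for x, y in zip(a, b)], "integer"
    if not allow_fallback:
        raise NotImplementedError("quantized type not supported by integer lowering")
    # Fallback: dequantize, compute in floating point, quantize the result.
    xf, yf = dequantize(a, t), dequantize(b, t)
    return quantize([x + y for x, y in zip(xf, yf)], t), "fallback"

per_axis = Uniform(scales=(2.0, 3.0, 4.0), zero_points=(0, 0, 0), axis=0)
print(lower_add([1, 2, 3], [1, 1, 1], per_axis))  # ([2, 3, 4], 'fallback')
```

For per-layer types where both paths apply, the integer path and the fallback agree, which is a convenient property to check in tests of such a pattern.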


@stellaraccident @River707 @clattner @sjarus @jpienaar @abattery @vinograd47 @liufengdb @jingpu @mehdi_amini

Hi everyone. Just for some background, our MLIR team at MathWorks is involved in developing a Tensorflow Lite to MLIR lowering pipeline. While our pipeline reuses some of the existing TFL-to-TOSA rewrite patterns from the Tensorflow repository, our priority is providing near-full support for the tfl dialect specification rather than targeting a specific dialect. As an essential part of our support for quantized models, we rely on the types and ops currently present in the quant dialect, whose spec and related tooling we’re proposing to improve in this RFC.

This RFC seems to also address some of the concerns pointed out in RFC: Removing the ops from the `quant` dialect - #28 by stellaraccident regarding the lack of maintenance of the quant dialect. The proposed features are intended to keep the dialect in its current minimal state, while complementing it with the instrumentation necessary to make it fully functional in a relevant workflow.

Any feedback you may have will be greatly appreciated. I will be submitting a PR with the corresponding implementation within the next few days. Thank you!


Thank you for the RFC. I’ll need to review it in detail after the US holidays, but it says many of the right things for this type of quantization.

I think if proceeding with this, I would take this chance to cede de facto code ownership of the dialect. My team doesn’t use this style of quantization anymore, and we don’t typically work with models where the quant params would work as static type parameters.

However, the TFL style here still has a lot of practical use and I certainly have no objection to supporting it well.

I think that the primary thing you are likely to run into is whether or not the lowering to in tree dialects gets you where you want to go. Might be ok for many TFL models. I’d try to avoid using arith for anything fancy on tensors.

In terms of alternatives, there are probably a couple, which might even be better for everyone:

  • Start an llvm incubator repo that includes a standalone fork of the TFL tooling (not tied to the tf repo), along with opinionated lowerings to an intermediate quant form like this and on to backend.
  • Build the above in another existing project that is focused on ML frontends. I’d have to think on it a bit more, but torch-mlir has been headed that way, having recently picked up both ONNX interop and the various quant legalizations that come from there (i.e., the “torch” prefix is showing its age, as the project has a pretty good foundation for general ML frontend pipelines and tooling).

In other words, this might be the point to just take the TFL mlir dialects/serializers and build them in a project that can be actually owned by the community who has an interest in this. If that were done, this mid level quant stuff would naturally roll into that project.

(Apologies for typos and edits - on Mobile)

To clarify why I think that a more holistic project focused on TFL interop would be a good thing: there is clearly demand for such a thing, you’ll find more people tuned into what this layer needs in something more directly adjacent to that, and this kind of thing typically benefits from a lot of CI and testing – which is something that the MLIR project proper is not set up to support. It’s been a shame for a long time that this use case has required stitching components from multiple projects vs being built well in one place. If it were me, and I had these use cases, I would take the chance to rectify that vs just trying to uplift this one piece in isolation.

My two cents. I’ll also support incremental evolution of the quant dialect in tree if there are a community of folks who will own it.

I think it’s great to see new energy and owners here. Not to say that boring here isn’t bad 🙂 I have not looked at the details above, but just a comment about locations.

+1 on considering having an integration point with TFL interop; this is independent of whether you’d have sufficient upstream for end-to-end testing and usage. I’d like a testable and usable pipeline, and I also appreciate that some of the tooling for TFLite may be difficult to incorporate without another spot. In the ideal case, this may even just be the interop and CI so that most development need not happen there. -1 for torch-mlir as such a spot (don’t think that worked well for TCP, nor is it good visibility, nor encouraging layering, complicated update process with increased maintenance burden).

I would think the biggest benefit of a simple interop repo is that one could pip install things for testing while most of the compiler work could be in tree, visible and contributable by all. It’s a sliding scale how much else is there; I’d like even more of the TFL parts outside the TF repo. I’ve told folks such and drawn diagrams as to how, but it’s not my decision to enable.

Yeah, I wasn’t particularly fond of this myself – we’re already just barely keeping the CI together. I mainly mentioned it because the core project no longer has a c++ dependency on torch, and the incremental cost of hosting more dialects and CI bots is not very high. As long as we weren’t landing pieces without use but actual functioning integrations, I could be convinced. But mainly, I was trying to be chivalrous and not be like “I think this should live somewhere else” without providing an alternative with some adjacent fellow travelers and support. I do believe that a lot of people are struggling here and was just trying to point a way out…

If it were me, with a product use case for this, I would create a new project and: fork the TFL dialects/serialization/flatbuffer conversion from TF, build out the quant mid level there, invite the tosa serialization tools in, fork the bits and bobs for tosa/TFL interop from iree, and stand up a real CI/build for it.

I don’t personally have that use case, but I understand it and would help smooth the path for people who do.

Thanks for this detailed RFC! I’ll also do a detailed review over the next day or two. I much appreciate the focus on the TFL->TOSA pathway. We’ve wrestled with at least some of these problems and this could be really useful.

cc: @eric-k

Adding a few summary comments following our offline conversation:

  • Having an incubator project for the TFLite parts as @stellaraccident suggested is potentially very helpful but needs insight around licensing and ownership to get off the ground. @jpienaar do you have any suggestions on how to pursue this? We would like to invite any other parties working with TFLite to collaborate on this endeavor.
  • Such a separate project may also need to hold TFLite kernels and the runtime for testability, not just the dialect and MLIR translator. But these may also be accessed from prebuilt Tensorflow instances as long as it’s possible to align hashes without much trouble.
  • In the RFC pseudocode we discussed how to define configurable rounding modes around the quantize/dequantize steps in case frameworks define a particular mode here.
  • The TFLite to TOSA legalizations are backed by an extensive functional testing infrastructure that is open source, so it would be ideal to expand it to cover legalization pathways that are currently unsupported. This may be impacted by the problems with landing changes into the Tensorflow repository - as covered in the first comment.

@rafaelubal @sabauma @eric-k please add anything I might have missed.

There’s not really a licensing issue. The code in TF is Apache licensed and the copyright is held by Google via their CLA. Any other OSS project can take/incorporate what it wants so long as it is attributed properly (and not license incompatible, which it wouldn’t be for things being discussed here). It’s of course always a good practice to exercise good hygiene – typically done for something like this by organizing the source into a directory tree that makes it clear what came from where.

I also happen to know that Google tends to be open minded about these things, and if someone asks nicely whether community code that has been incorporated into one of their projects can have a license refinement made so that it can be contributed elsewhere, they have a process for that and can typically help smooth this. I’d advise to decide what you want, know that the license allows that to fork without further permission, and ask nicely if things can be updated to make it better. But be specific on what you want to see happen first.

This is what I often refer to as these projects growing a number of dependencies and test integrations. You want to encourage that, which is why you want a frontend integration like this to be a leaf project (vs part of some monorepo or spanning three different codebases in an obtuse way).

I’d strongly advise that this be developed tightly with the needs of the TFL legalization you want to see happen vs as an independent piece. There are a lot of quantization approaches in the wild, and they all have their representational peculiarities. I don’t think at this level there is a useful generalization. And the way TFL does it is, in many ways, an approach that later things did not duplicate in a way that will help you create a common quant lowering utility.


Hi, all. Thank you very much for your feedback. Here are some updates:

  • The code for this RFC is now available in the following pull request: [mlir] Improvements to the 'quant' dialect by rafaelubalmw · Pull Request #100667 · llvm/llvm-project · GitHub

  • I have updated the RFC with more detailed documentation and some minor design decisions that were made during the implementation phase.

  • @stellaraccident @jpienaar: The idea of splitting the Tensorflow MLIR infrastructure into a community-maintained repository also sounds very appealing to us at MathWorks, Stella. This topic has been brought up in the monthly TOSA meetings before, and I believe there is a plan to revisit it in the near future. If it becomes a reality, it would be worth considering the migration of this infrastructure over to such a repository, or to a satellite counterpart, as suggested. Such migration should not take too much effort, given that the proposed implementation is mostly centralized within the corresponding dialect directory. In the meantime, the community can start benefiting from the refined dialect spec and tooling even if just for narrow purposes.

  • @sjarus: Support for configurable rounding modes is a valuable feature, Suraj. Given the complexity of the current implementation, and the abundance of details in need of careful review as it currently stands, let me propose postponing this feature to a follow-up RFC and pull request. Does that sound reasonable to you?

  • Regarding functional testing, we also had the need to implement an in-house infrastructure, aimed at complementing FileCheck tests by verifying the runtime correctness of the code emitted in complex lowering passes, such as those proposed in this RFC. From our offline conversation, it looks like this has been yet another chunk of replicated work across our organizations. We’d be happy to share our contributions on this front as well.

Thank you all again for your thoughtful comments and encouragement.


Thank you – I think it is time for me to resign as a defacto reviewer of this piece and pass the torch. Can we possibly get a collaborator actively involved with the tflite/tosa/quant work to do a first pass, detailed review. I’d be happy to take a final pass, but ideally we use this as a chance for someone invested in this path to step in and be informed from the ground up.

I’m happy to be the principal reviewer, but am on vacation this week so my time in front of a computer is sporadic.


I think it’s good to see the collab. Re making it a community repo, I’d be happy to schedule a second meeting there, but not sure how much actual pushing I could do 🙂 (interested folks can message; the more folks interested and willing to help, the easier to motivate).