Background
The Open Compute Project has defined a new standard for block-scaled microfloat (MX) data types.
The standard defines datatypes consisting of a block of microfloats (such as the 4-bit E2M1 type or the 6-bit E3M2 and E2M3 types) along with a shared scale, which holds 8 exponent bits and no mantissa.
The intended usage of these microfloats is as part of a block-scaled type: a block of K microfloats (with K=32 required by the specification) along with their shared scale. This is intended to enable efficient matrix multiplication and compressed storage of tensors, most typically the weight tensors of machine learning models.
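For concreteness, here's a rough numpy sketch of how a block-scaled type behaves numerically: 32 elements share one power-of-two scale derived from the block's largest magnitude. The function names are mine, and the rounding is simplified relative to the conversion rules the spec actually pins down.

```python
import numpy as np

# The eight non-negative values representable in E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_mx_block(block: np.ndarray) -> tuple[int, np.ndarray]:
    """Quantize 32 floats into E2M1 elements plus a shared power-of-two scale."""
    assert block.shape == (32,)
    amax = np.max(np.abs(block))
    # Shared exponent: floor(log2(amax)) minus E2M1's largest exponent (2, since
    # its maximum value is 1.5 * 2^2 = 6), so the largest element stays in range.
    exp = 0 if amax == 0 else int(np.floor(np.log2(amax))) - 2
    scaled = block / 2.0 ** exp
    # Round each element to the nearest representable E2M1 magnitude (simplified
    # round-to-nearest; the spec defines its own rounding and saturation rules).
    idx = np.argmin(np.abs(np.abs(scaled)[:, None] - E2M1_VALUES[None, :]), axis=1)
    elements = np.sign(scaled) * E2M1_VALUES[idx]
    return exp, elements

def dequantize_mx_block(exp: int, elements: np.ndarray) -> np.ndarray:
    """Recover approximate values: element * 2^shared_exponent."""
    return elements * 2.0 ** exp

exp, elems = quantize_mx_block(np.linspace(-1.0, 1.0, 32))
print(exp, dequantize_mx_block(exp, elems))
```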
Since there is already software emulation for these types in PyTorch, and since hardware vendors will likely add support for the MX types in the future (having signed on to the spec), it would be useful to have these types representable in MLIR's core libraries in the interest of interoperability.
The types
The components of a microscale block are somewhat like floats, in that they have sign, exponent, and mantissa bits. However, due to their small size, the individual components have no infinities or NaNs. When a NaN is needed, it applies to the entire block of microfloats and is signaled by setting the shared scale to 0xff.
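To make the element encoding concrete, here's a small decoder for E2M1 (the function name is mine, not from any library). Every one of the 16 bit patterns decodes to a finite value, so NaN can only exist at the block level, via the scale.

```python
def decode_e2m1(bits: int) -> float:
    """Decode a 4-bit E2M1 pattern: 1 sign bit, 2 exponent bits, 1 mantissa bit."""
    sign = -1.0 if bits & 0b1000 else 1.0
    exp = (bits >> 1) & 0b11
    mantissa = bits & 0b1
    if exp == 0:
        # Subnormal: no implicit leading one, so the only magnitudes are 0.0 and 0.5.
        return sign * 0.5 * mantissa
    # Normal: implicit leading one, exponent bias of 1.
    return sign * (1.0 + 0.5 * mantissa) * 2.0 ** (exp - 1)

# All 16 patterns decode to +/- {0, 0.5, 1, 1.5, 2, 3, 4, 6} -- no inf, no NaN.
print(sorted(decode_e2m1(b) for b in range(16)))

# A NaN exists only at block granularity: a scale byte of 0xff marks the whole
# block as NaN regardless of the element bits.
BLOCK_SCALE_NAN = 0xFF
```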
The alternatives
Between their very limited range, sub-byte length, lack of special values, and the fact that the OCP specification doesn't contemplate operations on individual microfloats, it's not clear that individual microfloat types (like a hypothetical FloatE2M1FN) should be added to MLIR's FloatType hierarchy (which would imply adding them to APFloat as a prerequisite).
However, adding those scalars to FloatType would enable microfloats to be handled with much of MLIR's existing infrastructure. For example, a block of 4-bit floats could be a vector<32 x fE2M1FN> instead of needing a custom type.
An alternative approach would be to define a new microscaling dialect and add custom types like !microscaling.fe2m1, !microscaling.block<32xfe2m1>, and !microscaling.scale. When combined with a sufficiently general bitcast operation, this would allow defining operations on microscale types, which will, at the end of the day, still lower to an appropriately-sized integer (or a vector of them), since, to my knowledge, there are no plans to add these microfloats to LLVM's or SPIR-V's type systems.
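As a sketch of what that bottoming-out looks like, the Python below (with illustrative names and an assumed low-nibble-first layout) packs a block of 32 E2M1 element encodings into 16 bytes plus a scale byte, which is the kind of integer representation such a block would lower to.

```python
def pack_e2m1_block(elements: list[int], scale: int) -> tuple[bytes, int]:
    """Pack 32 4-bit element encodings into 16 bytes (low nibble first) plus a scale byte."""
    assert len(elements) == 32 and all(0 <= e < 16 for e in elements)
    assert 0 <= scale <= 0xFF
    packed = bytes((elements[2 * i + 1] << 4) | elements[2 * i] for i in range(16))
    # In IR terms, something like (vector<16xi8>, i8) -- plain integers all the way down.
    return packed, scale

packed, scale = pack_e2m1_block(list(range(16)) * 2, 0x7F)
print(packed.hex(), hex(scale))
```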
However, because these custom types would sit outside the float type hierarchy, they wouldn't be permissible in many high-level dialects like TOSA, creating substantial friction when generating high-level descriptions of model fragments whose inputs or outputs use microscaled floats.
I’m not seeing a clear best path forward on this design question, so I’m writing to get the opinions of other people who’ll be hooking up microfloat support in MLIR at some future time.