Rethink on approach to low precision FP types

I think we need a better design for the proliferation of low precision floating point types. I don’t have a concrete proposal but am raising the issue as discussed here:

I don’t want to block anything, and many of us here are reliant on being able to extend this type hierarchy. But we need a more flexible approach for extension.

Aside from all of the repeated boilerplate, this gives me pause. This type hierarchy was never meant to expand in this way:

return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
                   Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
                   Float8E5M2FNUZType, Float8E4M3FNUZType,
                   Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
                   Float64Type, Float80Type, Float128Type>(type);

I was always wondering why all the float types are separate types / classes.

For IntegerType, we have a single parameterizable type:

def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
  ...
  let parameters = (ins "unsigned":$width, "SignednessSemantics":$signedness);
}

Could we do the same for FloatType? There could be an enum parameter to distinguish between different types.
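
As a rough illustration of that idea, here is a standalone sketch (no MLIR headers) of one float type with an enum parameter. `FloatKind`, `FloatTypeSketch` and `isFloat8` are hypothetical names made up for this example, not existing MLIR API, and the list of kinds is deliberately incomplete.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical enum parameter; MLIR's FloatType has no such parameter today.
enum class FloatKind : uint8_t { F4E2M1FN, F8E5M2, F8E4M3FN, BF16, F16, F32, F64 };

// Stand-in for a single parameterized float type, analogous to how
// IntegerType carries its width as a parameter.
struct FloatTypeSketch {
  FloatKind kind;

  // Storage width in bits for each kind listed above.
  unsigned getWidth() const {
    switch (kind) {
    case FloatKind::F4E2M1FN:
      return 4;
    case FloatKind::F8E5M2:
    case FloatKind::F8E4M3FN:
      return 8;
    case FloatKind::BF16:
    case FloatKind::F16:
      return 16;
    case FloatKind::F32:
      return 32;
    case FloatKind::F64:
      return 64;
    }
    assert(false && "unknown FloatKind");
    return 0;
  }

  bool operator==(const FloatTypeSketch &other) const {
    return kind == other.kind;
  }
};

// A "canned" check becomes a comparison on the parameter instead of an
// isa<> over many separate classes.
bool isFloat8(FloatTypeSketch type) { return type.getWidth() == 8; }
```

With this shape, matching one specific variant is an equality test on `kind`, and adding a variant means adding an enumerator and a `case`.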

That being said, maybe one benefit of the current implementation strategy is that downstream projects can define their own float types more easily?

If just classifying all floats as a variable bit-field of sign + exp + mantissa would do the trick, then this works. It would be up to implementation details how to convert between them (ex. rounding types).
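
To make the "sign + exp + mantissa" view concrete, here is a minimal sketch of a generic decoder driven only by the field widths and a bias. `BitLayout` and `decodeFinite` are hypothetical names for this example. The sketch assumes every bit pattern is a finite value, which holds for a format like Float4E2M1FN but not for formats that reserve encodings for NaN or Inf.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical description of a format as field widths plus an exponent bias.
struct BitLayout {
  unsigned exponentBits;
  unsigned mantissaBits;
  int bias;
};

// Decodes `bits` as sign | exponent | mantissa. Assumes all encodings are
// finite; NaN/Inf encodings are intentionally not modeled here.
double decodeFinite(const BitLayout &layout, uint32_t bits) {
  assert(layout.exponentBits + layout.mantissaBits < 32 &&
         "layout must fit in 32 bits including the sign");
  const uint32_t mantissaMask = (1u << layout.mantissaBits) - 1;
  const uint32_t exponentMask = (1u << layout.exponentBits) - 1;
  const uint32_t mantissa = bits & mantissaMask;
  const uint32_t exponent = (bits >> layout.mantissaBits) & exponentMask;
  const bool negative =
      ((bits >> (layout.mantissaBits + layout.exponentBits)) & 1u) != 0;

  const int mantissaShift = static_cast<int>(layout.mantissaBits);
  double magnitude;
  if (exponent == 0) {
    // Subnormal: no implicit leading one, fixed exponent of (1 - bias).
    magnitude = std::ldexp(static_cast<double>(mantissa),
                           1 - layout.bias - mantissaShift);
  } else {
    // Normal: implicit leading one is added above the mantissa bits.
    magnitude = std::ldexp(static_cast<double>(mantissaMask + 1 + mantissa),
                           static_cast<int>(exponent) - layout.bias -
                               mantissaShift);
  }
  return negative ? -magnitude : magnitude;
}
```

For example, with the layout `{2, 1, 1}` (2 exponent bits, 1 mantissa bit, bias 1) the pattern `0b0111` decodes to 6.0 and `0b0001` to 0.5. The conversion and rounding rules mentioned above would still have to live somewhere else.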

But you still want to match a particular variation, so the only thing this simplifies is functions like isFloat. In this case, isFloat8 has no meaning if your hardware only implements one type and not the others.

This is not a big problem with integers today because architectures have agreed on things like bit sizes, endianness and two’s complement, but it once was…

What do you mean by “match”? As in “match in a rewrite pattern”?

Wouldn’t match be an isa? E.g., you have a helper class for F32, so if you do an isa, the mantissa size etc. is checked as normal. Or, like integer, there is a parametrized isFloat function and then one can add canned variants. I think it just makes this check more expensive, but I’m not sure it’s on the critical path.

We were having a similar discussion in chat.

We could still retain that, the above would just ease adding new ones in the simple case where it fits the mold. And it’s out of my cache why we chose ints and float types to differ wrt parameterization (and also I’m just typing without any deep reflection).

The current rule is that the APFloat::fltSemantics is the thing that binds all of these types together as encapsulated by an enum like APFloat::Float8E8M0FNU().

I think the rule that “APFloat knows how to deal with it” is what makes something a FloatType is a good one to stick to. What I’d like to see is for the MLIR boilerplate needed to add a new variant to be closer in size and complexity to adding an enum/instance vs an entirely different FloatType subclass.

I think that requires externalizing the fltSemantics enum in some way better than how it is.

Having worked with these types a lot in downstream code, I do think that parser support is important: I’d rather have keywords for these vs some hard(er) to spell nested type syntax.

I bet we could get this down to basically a table somewhere giving each supported variant of the enum a name.
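
A minimal sketch of what such a table could look like, written as standalone C++17 with no MLIR dependencies. `FloatVariant`, `FloatVariantEntry`, `kFloatVariants` and `lookupKeyword` are hypothetical names for illustration. The keywords are meant to mirror the spellings of the existing builtin types, and only a handful of variants are listed.

```cpp
#include <cassert>
#include <optional>
#include <string_view>

// Hypothetical enumeration of supported variants (incomplete on purpose).
enum class FloatVariant { F4E2M1FN, F6E2M3FN, F8E5M2, F8E4M3FN, BF16, F16 };

// One row per variant: the enum value, its parser keyword, its bit width.
struct FloatVariantEntry {
  FloatVariant variant;
  std::string_view keyword;
  unsigned width;
};

// Adding a variant would mean adding one row here (plus the enumerator).
constexpr FloatVariantEntry kFloatVariants[] = {
    {FloatVariant::F4E2M1FN, "f4E2M1FN", 4},
    {FloatVariant::F6E2M3FN, "f6E2M3FN", 6},
    {FloatVariant::F8E5M2, "f8E5M2", 8},
    {FloatVariant::F8E4M3FN, "f8E4M3FN", 8},
    {FloatVariant::BF16, "bf16", 16},
    {FloatVariant::F16, "f16", 16},
};

// Parser-side lookup: keyword -> table row, or nullopt if unknown.
std::optional<FloatVariantEntry> lookupKeyword(std::string_view keyword) {
  for (const FloatVariantEntry &entry : kFloatVariants)
    if (entry.keyword == keyword)
      return entry;
  return std::nullopt;
}
```

A linear scan is enough for a few dozen entries. The same table could also drive printing and the C/Python binding boilerplate, though that part is speculation.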

Thanks for starting this thread, Stella.

+1. This aligns with what we discussed in the code review. However, my understanding is that we would also need a set of feature flags to model some of the representational constraints of each type (e.g., support for NaN, Inf, -0.0, bias, etc.), as these constraints can’t be automatically inferred from the sign, exponent, and mantissa tuple.
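
A small sketch of why the tuple alone is not enough: the largest finite value depends on the bias and on how non-finite values are encoded. `NonFinite`, `FloatFormat` and `maxFinite` are hypothetical names for this example. The enum is only loosely modeled on the kind of information fltSemantics carries and is not the real LLVM definition.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical flag describing how a format encodes NaN/Inf.
enum class NonFinite {
  IEEE754,        // top exponent reserved for Inf and NaN
  NanOnlyAllOnes, // no Inf; only the all-ones pattern is NaN
  NanOnlyNegZero  // no Inf; NaN reuses the negative-zero pattern
};

struct FloatFormat {
  unsigned exponentBits;
  unsigned mantissaBits;
  int bias;
  NonFinite nonFinite;
};

// Largest finite magnitude representable in the given format.
double maxFinite(const FloatFormat &format) {
  assert(format.mantissaBits > 0 && "sketch assumes at least one mantissa bit");
  unsigned maxExponentField = (1u << format.exponentBits) - 1;
  unsigned maxMantissaField = (1u << format.mantissaBits) - 1;
  switch (format.nonFinite) {
  case NonFinite::IEEE754:
    maxExponentField -= 1; // top exponent is not available for finite values
    break;
  case NonFinite::NanOnlyAllOnes:
    maxMantissaField -= 1; // all-ones exponent and mantissa is NaN
    break;
  case NonFinite::NanOnlyNegZero:
    break; // every positive encoding is finite
  }
  const double significand =
      1.0 + std::ldexp(static_cast<double>(maxMantissaField),
                       -static_cast<int>(format.mantissaBits));
  return std::ldexp(significand,
                    static_cast<int>(maxExponentField) - format.bias);
}
```

With 4 exponent bits, 3 mantissa bits and bias 7, the IEEE-style flag gives a maximum of 240 while the all-ones-NaN flag gives 448, which matches the difference between Float8E4M3 and Float8E4M3FN.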

We might also consider introducing a dedicated floating-point class specifically for these types. This would help distinguish them from the “first-class” floating-point types and establish a common foundation for them and restrict their usage to scenarios where they truly make sense.

Actually, I’m much more concerned about the complexity these types are adding to APFloat rather than to MLIR itself. We should think about creating a dedicated subclass under APFloatBase that helps encode the common weirdness and use cases of these types.

I haven’t looked at it recently but would buy that it probably needs a look too. Pretty much no software survives a doubling of complexity without some attention paid to the design. I think the principle, though, is that we have a way that we are providing systematic basic services on any of these types and an APFloat type thing is what we are using for that now. It would be nice to retain one enumeration that can be passed to a factory to get an appropriate instance providing those services.

At least APFloat is an implementation detail. We’re adding a lot of public API surface area to MLIR (C++, C and Python bindings) with each of these, and that is hard to undo the longer it bakes in.

It gets worse. The same float type can be computed in different ways on the same machine, depending on which instruction you call. For example, on Armv7, VFP operations (scalar) are IEEE compliant, while NEON operations (vector) are not, because they flush denormals to zero. So if you force a “non-denormal” semantics early on (as some default behaviour), you stop any hope of vectorization.
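
A tiny host-side illustration of the denormal point, assuming the host itself evaluates float arithmetic with subnormals enabled (the usual default without fast-math flags). `flushToZero`, `halveIEEE` and `halveFTZ` are hypothetical helpers that only emulate the two behaviours in software. They are not how VFP or NEON are actually programmed.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Emulates flush-to-zero: subnormal values become a signed zero.
float flushToZero(float value) {
  return std::fpclassify(value) == FP_SUBNORMAL ? std::copysign(0.0f, value)
                                                : value;
}

// IEEE-style halving: a subnormal result is kept.
float halveIEEE(float value) { return value * 0.5f; }

// Flush-to-zero halving: subnormal inputs and results are flushed.
float halveFTZ(float value) {
  return flushToZero(flushToZero(value) * 0.5f);
}
```

Halving the smallest normal float yields a nonzero subnormal with `halveIEEE` and exactly zero with `halveFTZ`, so the same f32 operation gives observably different results under the two semantics.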

This is a similar discussion as in LLVM for where the type semantics is encoded: the type system or the instructions. In LLVM, we’re moving everything to instructions (simple numeric types, opaque ptr, wrapping/fast-math on instructions, etc).

MLIR allows for a richer type system, but it’d also require us to convert (and validate) types when lowering. For example, a linalg.matmul on fp8 would need to lower to different implementations, depending on the underlying hardware, and that would need to be encoded in whatever IR is the output (LLVM, SPIRV).

Keeping the implementation details on the instructions, however, would allow a direct lowering and the dialect would “know” which fp8 it implements, and validate semantics locally. That’s what I mentioned above regarding “implementation detail”.

I haven’t done that work, so I’m just laying out the differences. I don’t know that a rich float type system would be a nightmare to convert across different dialects, but I have a strong feeling that it will be.

PyTorch took a pragmatic approach that is basically the same one that LLVM/MLIR took (and even mirrored the names we chose), and it is working: the framework will give names to needed types and will support basic features on them (parse/print, convert) but beyond that, everything else is done via kernels that operate on them specifically.

That was basically the intent of the original FP8 RFC and is the state we are still in. But the code could use some cleanup (my intent in starting this thread was to start to think about the code sprawl that comes from having a lot of something when only a few were planned for).

An approach like this combats the fragmentation and errors that would occur if just type punning to an opaque type, and it balances the concerns Renato is pointing out while leaving the option for the future to have more consensus around level of common support for some subset. I think it is about the best we can do right now without boxing ourselves into corners or trying to have comprehensive support for things that probably never need it (and where it would be really hard to even describe what that is).

I’m kind of skeptical that there is a meaningful generalization here at this juncture, but it is in MLIR’s interests to assign identity to the types that end up in common use, even if they don’t “do” much beyond having basic support for representing them.

I’ve been thinking about this for a bit. We could do the following:

  1. Store all FP type-specific information and functionality in fltSemantics: not only fields such as sizeInBits, nonFiniteBehavior, etc., but also pointers to functions that can (a) convert from APFloat to APInt, and (b) convert from APInt to APFloat. With that, it should be possible to make APFloat independent of the concrete FP semantics.
  2. Make fltSemantics a public struct. This would allow downstream users to create custom floating point types without having to modify LLVM. We could also organize the code a bit differently in LLVM. E.g., one file per FP semantics.
  3. In MLIR, make FloatType a type interface with one interface method: std::optional<fltSemantics> getSemantics. We no longer need to have checks like llvm::isa<Float4E2M1FNType, Float6E2M3FNType, ...>. If no semantics are provided (std::nullopt), the floating point type can still be used in MLIR, but things such as defining constants (which involves converting a parsed double to APFloat) won’t be supported.
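
The three steps above can be sketched in standalone C++17 as follows. Everything here is a hypothetical stand-in: `SemanticsSketch` is not the real fltSemantics, `FloatTypeInterfaceSketch` is not an MLIR type interface, and plain `double`/`uint64_t` take the place of APFloat/APInt. The toy format (8 bits, value = 2^(bits - 127), no sign or mantissa) exists only to give the function pointers something to do.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>

// Steps 1 and 2: a public semantics struct carrying data fields and
// conversion hooks.
struct SemanticsSketch {
  unsigned sizeInBits;
  double (*decode)(uint64_t bits);  // stand-in for APInt -> APFloat
  uint64_t (*encode)(double value); // stand-in for APFloat -> APInt
};

// Step 3: a float type is anything that can (optionally) report semantics.
struct FloatTypeInterfaceSketch {
  virtual ~FloatTypeInterfaceSketch() = default;
  virtual std::optional<SemanticsSketch> getSemantics() const = 0;
};

// Toy downstream format: powers of two only.
double decodePow2(uint64_t bits) {
  return std::ldexp(1.0, static_cast<int>(bits) - 127);
}

uint64_t encodePow2(double value) {
  assert(value > 0.0 && std::isfinite(value) &&
         "toy format holds positive finite values only");
  int field = std::ilogb(value) + 127; // rounds down to a power of two
  if (field < 0)
    field = 0;
  if (field > 254)
    field = 254;
  return static_cast<uint64_t>(field);
}

// A downstream type that provides semantics without touching upstream code.
struct Pow2Float final : FloatTypeInterfaceSketch {
  std::optional<SemanticsSketch> getSemantics() const override {
    return SemanticsSketch{8, &decodePow2, &encodePow2};
  }
};

// A type with no semantics: usable as a type, but not for parsed constants.
struct OpaqueFloat final : FloatTypeInterfaceSketch {
  std::optional<SemanticsSketch> getSemantics() const override {
    return std::nullopt;
  }
};

// Generic code asks the interface instead of running a long isa<> chain.
bool supportsDecimalConstants(const FloatTypeInterfaceSketch &type) {
  return type.getSemantics().has_value();
}
```

In this sketch, generic code such as a constant parser would call `getSemantics()` and use the `encode` hook when it is present, and report an error for types like `OpaqueFloat`.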

Any thoughts?

I like where you’re going with #1 and #2 and think that would be an improvement. Would also be good to restructure the APFloat tests to be more “cts like” vs spaghetti unit tests.

Not sure I completely follow #3 but also not opposed to continuing that thought experiment. The main benefit here would be to have kind of an “opaque float”?

I like it: it uses what can be common without requiring everything to fit the mold.

Is the alternative an interface with all the hooks on it? And would that include everything one can do with an APFloat? Also, what about adding function pointers to fltSemantics vs. having a struct on our side? (Changing where LLVM is used as the backend may be required, but LLVM codegen is not necessarily the backend on the MLIR side, so I’m wondering if making it public introduces coupling where usage may differ.)

I’m not entirely sure about #3 there, but it does feel reasonable.

I take it that, without an APFloat conversion, constants would look like %const = arith.constant 0xabcdef : #my_research.custom_float? That does seem useful for plumbing stuff through.

Making APFloat extendable downstream does seem reasonable.

+1 to this idea. Not only downstream implementation, but we could also have upstream types that implement those interfaces.

I’m wondering if this could also be used for mixed-precision floats. Conversion to other types, load/store semantics, etc. could be done by different compilers / targets in different ways, etc.

1 and 2 look like a nice direction for improvement!

I’m not entirely following 3 though, my view of FloatType support has always involved a tight coupling with APFloat support (and this is ingrained in various places, like constants/conversions/etc.). Making changes there (such as optional fltSemantics representability) feels like a separable discussion that could explore if the merits are worth it.

Outside of non-APFloat compatibility, if we had a FloatType that took fltSemantics as a parameter, all of the classes like Float4E2M1FNType could just be C++ sugar on top of one type (instead of an interface or a bunch of different unique types). Removing the 20 unique variants in favor of a single more uniform type would be a major uplift.

+1

Let’s take %0 = arith.constant 1.0 : float<e2m1fn> as an example. How do we find the corresponding fltSemantics object for float<e2m1fn>?

Presumably there would have to be some static DenseMap<StringRef, fltSemantics> in FloatType with all known floating-point types. That would prevent users from adding new floating-point types without modifying MLIR. That’s why I was suggesting a type interface.

(But a single type with a parameter would already be an improvement.)

This PR turns FloatType into a type interface. (Semantics are not optional.) I did not find any measurable compilation time changes for the few lit tests that I ran. (I was hoping that it improves performance in some cases because we no longer have the long sequence of isa in FloatType::getFloatSemantics, but it seems to have no effect in the test cases that I ran.)

I wrote a micro benchmark that parses + prints 32768 floats with random floating-point types: Microbenchmark · GitHub

Going through the type interface causes a slowdown compared to the hard-coded sequence of if checks in getFloatSemantics.

Benchmark 1: mlir-opt test.mlir -allow-unregistered-dialect
  BEFORE
  Time (mean ± σ):      43.3 ms ±   1.8 ms    [User: 31.9 ms, System: 11.4 ms]
  Range (min … max):    39.8 ms …  48.3 ms    200 runs

  AFTER
  Time (mean ± σ):      50.3 ms ±   1.8 ms    [User: 38.8 ms, System: 11.5 ms]
  Range (min … max):    47.3 ms …  55.1 ms    200 runs

We have to decide whether this slowdown is acceptable or if it’s better to stay with the current way of extending the floating-point type hierarchy for the sake of compile time performance.

Note: This slowdown happens only when parsing separate float attributes, not when parsing large tensor constants, dense elements attrs, etc. There should also be no slowdown when storing floats in the hexadecimal form or as a resource blob. Foldings/transformations that require getting the constant fp value as a C++ double likely have the same slowdown.

Personally, I would still be interested in this to be able to define custom fp types in a downstream project without having to modify upstream MLIR. (I have a use case for that.)
