[RFC] Fix floating-point `max` and `min` operations in MLIR

This proposal focuses on arith.maxf, arith.minf, vector.reduction <maxf>, and vector.reduction <minf> in MLIR. We have identified several issues with these operations, which are outlined in the following sections.

Problem 1: Missing Semantics and Confusing Naming

In LLVM, there are two ways to compute the minimum and maximum of two floating-point numbers: the minnum/maxnum intrinsics and the minimum/maximum intrinsics [1, 2]. The two pairs differ in how they handle NaNs and +0.0/-0.0. The former returns the non-NaN argument when one of the arguments is NaN and does not distinguish between +0.0 and -0.0, whereas the latter propagates the NaN and treats -0.0 as less than +0.0. The same applies to the reduction variants [9, 10].
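As a rough illustration, the two pairs can be modelled in a few lines of Python. This is a sketch of the LangRef semantics as summarized above, restricted to quiet NaNs (Python floats cannot express signaling NaNs); the function names are illustrative and are not LLVM or MLIR APIs.

```python
import math

# Hypothetical Python models of the LLVM intrinsics (quiet NaNs only).

def maxnum(a: float, b: float) -> float:
    """llvm.maxnum-like: a NaN operand is ignored unless both are NaN.
    Signed zeros are unordered, so either zero may be returned."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return a if a >= b else b

def minnum(a: float, b: float) -> float:
    """llvm.minnum-like counterpart of maxnum."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return a if a <= b else b

def maximum(a: float, b: float) -> float:
    """llvm.maximum-like: any NaN operand yields NaN, and -0.0 < +0.0."""
    if math.isnan(a) or math.isnan(b):
        return math.nan
    if a == 0.0 and b == 0.0:
        # Both operands are zeros: pick +0.0 unless both are -0.0.
        return a if math.copysign(1.0, a) > 0 else b
    return a if a > b else b

def minimum(a: float, b: float) -> float:
    """llvm.minimum-like: any NaN operand yields NaN, and -0.0 < +0.0."""
    if math.isnan(a) or math.isnan(b):
        return math.nan
    if a == 0.0 and b == 0.0:
        # Both operands are zeros: pick -0.0 unless both are +0.0.
        return a if math.copysign(1.0, a) < 0 else b
    return a if a < b else b
```

For example, maxnum(math.nan, 1.0) returns 1.0 while maximum(math.nan, 1.0) returns NaN; maximum(-0.0, 0.0) returns +0.0, whereas maxnum(-0.0, 0.0) may return either zero (this model happens to return -0.0).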

In MLIR, there is currently a single way to represent the minimum and maximum of two floating-point numbers: arith.minf/arith.maxf. The semantics of these operations follow those of the minimum/maximum intrinsics in LLVM [3]. The same applies to the reduction variants in the Vector dialect [11]. However, MLIR has no operations that model the semantics of LLVM’s minnum/maxnum.

Proposed Solution

We propose to rename arith.minf/arith.maxf to arith.minimumf/arith.maximumf to align their names with their LLVM counterparts. We would also like to introduce the corresponding arith.minnumf and arith.maxnumf operations to mirror the semantics of the minnum/maxnum intrinsics in LLVM. The same applies to the vector reduction operations in the Vector dialect: vector.reduction <maxf> and vector.reduction <minf> would be renamed to <maximumf> and <minimumf>, and the <maxf>/<minf> kinds would then model the maxnum/minnum semantics.

Problem 2: arith.minf and arith.maxf lowering to LLVM

Currently, there is a bug in the lowering of arith.maxf and arith.minf to LLVM. Although their semantics [3, 4] clearly state that NaN should be propagated and that +0.0 and -0.0 should be distinguished, they are lowered to the maxnum/minnum intrinsics without the extra handling these cases require [5]. Interestingly, the lowering of their vector reduction variants seems to have always been aligned with these semantics [6], even before it was switched to the proper intrinsics.
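The missing handling can be sketched in Python: starting from a maxnum-like primitive, a lowering that must preserve maximum semantics needs an explicit NaN check and a signed-zero fix-up. This only illustrates the kind of compare/select sequence such a lowering would need; the helper names are hypothetical and the model covers quiet NaNs only.

```python
import math

def maxnum(a: float, b: float) -> float:
    """Hypothetical model of llvm.maxnum (quiet NaNs only)."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return a if a >= b else b

def is_pos_zero(x: float) -> bool:
    """True only for +0.0 (False for -0.0 and every non-zero value)."""
    return x == 0.0 and math.copysign(1.0, x) > 0

def maximum_from_maxnum(a: float, b: float) -> float:
    """Recover llvm.maximum semantics on top of maxnum with extra checks."""
    # Check 1: maxnum drops a NaN operand, so select NaN explicitly.
    if math.isnan(a) or math.isnan(b):
        return math.nan
    r = maxnum(a, b)
    # Check 2: maxnum may return -0.0 for max(-0.0, +0.0);
    # force +0.0 when the result is a zero and either operand is +0.0.
    if r == 0.0 and (is_pos_zero(a) or is_pos_zero(b)):
        return 0.0
    return r
```

Without these two checks, a maxf applied to NaN and 1.0 yields 1.0 instead of NaN, and max(-0.0, +0.0) may yield -0.0.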

Proposed Solution

We plan to fix the existing bug by changing the lowering of arith.maxf and arith.minf from the llvm.maxnum/llvm.minnum intrinsics to the llvm.maximum/llvm.minimum intrinsics.

Problem 3: Vector reductions lowering to SPIR-V

A similar bug exists in a different combination of operations and lowerings: the vector reduction operations and their SPIR-V lowerings. These lowerings use the spirv.CL.fmax and spirv.CL.fmin operations, which, according to their documentation, behave like the llvm.maxnum/llvm.minnum intrinsics [7]. However, unlike in similar cases [8], no additional operations are inserted to correct this behavior.

Proposed Solution

We plan to fix the bug by inserting extra operations that enforce the intended semantics, as was previously done for the LLVM lowering [6].
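One possible shape for this fix, sketched in Python: the reduction remains a chain of fmax steps, and each step gets an extra check that forces NaN when either input is NaN. Here cl_fmax is a hypothetical model of spirv.CL.fmax that follows the OpenCL fmax NaN rule; the actual SPIR-V lowering may structure the checks differently, and signed-zero ordering is not handled in this sketch.

```python
import math
from functools import reduce

def cl_fmax(a: float, b: float) -> float:
    """Hypothetical model of spirv.CL.fmax: if exactly one operand is NaN,
    the other one is returned (quiet NaNs only)."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return a if a >= b else b

def reduce_fmax(xs: list[float]) -> float:
    """Plain chain of fmax steps: NaN elements are skipped (maxnum-like)."""
    return reduce(cl_fmax, xs)

def reduce_fmaximum(xs: list[float]) -> float:
    """Same chain with an extra isnan/select per step, so a single NaN
    element makes the whole reduction NaN (maximum-like)."""
    def step(acc: float, x: float) -> float:
        r = cl_fmax(acc, x)
        return math.nan if math.isnan(acc) or math.isnan(x) else r
    return reduce(step, xs)
```

In terms of the summary table, reduce_fmaximum corresponds to the intended lowering of vector.reduction <maximumf>, and reduce_fmax to that of the new vector.reduction <maxf>. Both assume a non-empty vector, as functools.reduce without an initial value raises on an empty list.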

Summary

State before the changes

Operation | LLVM lowering | SPIR-V lowering
--- | --- | ---
arith.maxf | ✗ llvm.maxnum without enforcing the desired semantics | ✓ Follows the semantics
arith.minf | ✗ llvm.minnum without enforcing the desired semantics | ✓ Follows the semantics
vector.reduction <maxf> | ✓ llvm.vector.reduce.fmaximum | ✗ Sequence of spirv.CL.fmax ops without enforcing the desired semantics
vector.reduction <minf> | ✓ llvm.vector.reduce.fminimum | ✗ Sequence of spirv.CL.fmin ops without enforcing the desired semantics

State after the changes

Operation | Desired LLVM lowering | Desired SPIR-V lowering
--- | --- | ---
arith.maximumf | llvm.maximum | Current arith.maxf lowering
arith.minimumf | llvm.minimum | Current arith.minf lowering
arith.maxnumf | llvm.maxnum | spirv.CL.fmax, or spirv.GL.FMax + additional checks to propagate the non-NaN operand (*)
arith.minnumf | llvm.minnum | spirv.CL.fmin, or spirv.GL.FMin + additional checks to propagate the non-NaN operand (*)
vector.reduction <maximumf> | llvm.vector.reduce.fmaximum | Sequence of spirv.CL.fmax ops + additional checks to propagate NaN
vector.reduction <minimumf> | llvm.vector.reduce.fminimum | Sequence of spirv.CL.fmin ops + additional checks to propagate NaN
vector.reduction <maxf> | llvm.vector.reduce.fmax | Sequence of spirv.CL.fmax ops
vector.reduction <minf> | llvm.vector.reduce.fmin | Sequence of spirv.CL.fmin ops

(*) A note on SPIR-V lowerings for the Arith operations with m**num semantics

There are two lowerings from the Arith dialect to SPIR-V: spirv.CL and spirv.GL. The spirv.CL operations behave like the llvm.m**num intrinsics when handling NaNs. However, the spirv.GL operations have undefined results when one of the operands is NaN. To ensure consistent semantics, the spirv.GL lowering should insert additional operations to enforce the correct behavior, similar to the current lowerings but with a different intent: propagating the non-NaN operand instead of the NaN.
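The extra checks for the spirv.GL path can be sketched in Python as two selects around the GL operation. Here gl_fmax_model is a hypothetical stand-in for spirv.GL.FMax that deliberately returns an arbitrary value (0.0) when an operand is NaN, to show that the wrapper's result never depends on that undefined case; all names are illustrative only.

```python
import math

def gl_fmax_model(a: float, b: float) -> float:
    """Hypothetical stand-in for spirv.GL.FMax: the result is undefined if
    an operand is NaN, modelled here as an arbitrary 0.0."""
    if math.isnan(a) or math.isnan(b):
        return 0.0
    return a if a >= b else b

def maxnum_via_gl_fmax(a: float, b: float) -> float:
    """maxnum semantics built from GL.FMax plus two isnan/select steps."""
    r = gl_fmax_model(a, b)
    r = a if math.isnan(b) else r  # b is NaN: keep a (NaN if a is NaN too)
    r = b if math.isnan(a) else r  # a is NaN: keep b (NaN if both are NaN)
    return r
```

The spirv.GL.FMin case is symmetric. Unlike the NaN-propagating checks, these selects pick the non-NaN operand, which is the difference in intent between the two kinds of checks.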

Work breakdown

1 Arith dialect

1.1 Change the lowering of arith.m**f operations to the llvm.m**imum intrinsics.

1.2 Rename arith.m**f to arith.m**imumf.

1.3 Add arith.m**numf operations.

1.4 Add arith.m**numf LLVM lowerings.

1.5 Add arith.m**numf SPIR-V lowerings.

2 Vector dialect

2.1 Rename vector.reduction <m**f> to vector.reduction <m**imumf>

2.2 Fix the SPIR-V lowering for vector.reduction <m**imumf> to propagate NaNs.

2.3 Add vector.reduction <m**f> operations with m**num semantics.

2.4 Add vector.reduction <m**f> LLVM lowerings.

2.5 Add vector.reduction <m**f> SPIR-V lowerings.

References

  1. llvm.maxnum intrinsic
  2. llvm.maximum intrinsic
  3. arith.maxf Reference
  4. arith.minf Reference
  5. arith.maxf and arith.minf LLVM Conversion tests
  6. Vector Reduction example with explicit enforcement of semantics
  7. spirv.CL.fmax Reference
  8. Vector Reduction operations SPIR-V lowering tests
  9. llvm.vector.reduce.fmax intrinsic
  10. llvm.vector.reduce.fmaximum intrinsic
  11. The implementation of the vector reduction operations has always been aligned with the llvm.vector.reduce.fm**imum intrinsics [6].

Authors: @dcaballe and @unterumarmung
CC: @kuhar, @mehdi_amini, @nicolasvasilache, @ftynse and @banach-space

Great write-up, thank you for sharing!

This is not my area of expertise, but all of this sounds very reasonable to me. Big +1 for alignment with LLVM.

-Andrzej

I also came across this issue, and I support this proposal. With respect to the current status quo of arith, I am fully in favor of alignment with LLVM.

In light of arith vs llvm, I’d like to complicate this issue a little further, though. For reference, I studied IEEE-754 on the matter but came to a sobering result: partly because implementations disagree wildly on the treatment of NaN, the standard has dropped minNum and maxNum. In other words, there’s little else to align them to.

Under the assumption that we might want arith to become a more high-level arithmetic dialect in the future, this means operations like min and max will need their own spec. Additionally, I’d consider having two different operations, as in your proposal, a downside for a high-level dialect.

Makes sense to me!

There are now three IEEE 754 standards (1985, 2008, and 2019), where 2008 and 2019 both absorb the 1987 standard for non-binary-radix floating point. Not specifying which one is meant adds to the chaos.

IEEE 754-2008 introduced the concept of MAX() and MIN() delivering the non-NaN operand (if any).

IEEE 754-2019 deprecated the 2008 semantics in favor of the “old” semantics everyone and his brother understood. 2019 added new names for the deprecated 2008 MAX() and MIN() semantics, so while they are no longer associated with MAX() and MIN(), they remain available via an intrinsic with a long name, MaximumNumber()…

I suspect that the 2008 versions will be deprecated away in some future standard.

Thank you all for the feedback!

That’s a great point! Yes, this topic is complex for a number of reasons, including changes to and compliance with IEEE-754, as well as the fact that different frameworks/front-ends may choose to align with IEEE-754 or have completely ad-hoc semantics for these operations. I would say that our goal with this RFC is more about fixing existing bugs in MLIR and aligning the representation with the most common backend compilers we support (i.e., LLVM and SPIR-V). We can always remove minNum and maxNum if LLVM does so, but minNum and maxNum semantics are used in more aggressive fp precision modes and align with the semantics in some vendors’ ISAs, so I would expect them to be around for some time.

We discussed in the past the possibility of adding more optionality to min/max, particularly regarding signaling and quiet NaNs, and having a single min/max operation with an attribute encoding such optionality. However, things can quickly become complicated, as the difference in semantics that this optionality may introduce would lead to different optimization constraints and even side effects. I am fully supportive of having more optionality (driven by actual use cases) and a more powerful representation, but that would be beyond the scope of this RFC. Again, our primary goal here is to fix the existing bugs and keep the “basic” min/max representation consistent. We can then talk about what to do next!

To be honest, I don’t think they are going anywhere, because these intrinsics got into LLVM for a reason. And I think the reason is that they match the behaviour of the standard C functions from <math.h>: fmax, fmaxf, fmaxl and fmin, fminf, fminl (see cppreference.com).

Thanks for the details, I never followed this to the end.

The point I was trying to make is more how a high-level arithmetic dialect should even come to a unified definition of functions like these. Remember that we also have

  • BFloat16, which at least in Google’s TPU paper reserved the possibility of diverging from IEEE-754 operation semantics, and does for denormals
  • Float8E4M3FNUZ and friends, which also have only one NaN representation, so no sNaN/qNaN distinction by payload is possible

In essence, my argument is similar to what was made in the arith vs llvm thread earlier: the llvm dialect already offers mappings to the standard-compliant functions. The added value of arith over llvm can only be in decoupling. We may therefore either preserve a snapshot of the current semantics, or define an agreeable common baseline, but will shoulder maintaining a conversion regardless.

Quick note: We should also take the atomic rmw minf/maxf ops into account as part of this effort: D158283 [mlir][MemRefToLLVM] Add fmin, fmax to AtomicRMW lowering
