[RFC] Explicitly specifying NaN and signed zero semantics for minf and maxf

As seen in [Proposed breaking change/RFC] Remove min and max from -arith-expand-ops, and in the history of ⚙ D140856 [mlir][Arith] Remove expansions of integer min and max ops and ⚙ D137786 Lower arith.min/max to llvm.intr.minimum/maximum, there is much headache to be had from how various platforms implement floating-point min and max and, more specifically, from their behavior with regard to NaN and signed zero.

Currently, arith.minf and arith.maxf specify they will propagate NaN and leave signed zero behavior undefined.

However, not all the lowerings (for instance, -arith-to-llvm) actually implement these semantics.

Ideally, MLIR would be able to

  1. Not generate NaN checks that users don’t need/don’t care about and
  2. Support multiple different NaN/signed zero semantics when the user does care

I propose adding the following two enums (they'll probably go in Arith, but since other dialects may want them for their ops, a more general location might make sense)

(preserved for continuity of discussion, see [RFC] Explicitly specifying NaN and signed zero semantics for minf and maxf - #16 by krzysz00 for new proposal)

```cpp
enum NaNSemantics {
  Any = 0,
  Platform = 1,
  PropagateNaN = 2,
  PropagateOther = 3
};
```

and

```cpp
enum SignedZeroSemantics {
  Any = 0,
  Platform = 1,
  Compare = 2
};
```

(end preserved bits)

and corresponding attributes. These attributes will become arguments to arith.minf and arith.maxf.

In both enums, Any means that the user is willing to accept any reasonable semantics for NaN or signed zero, which allows us to, for example, constant-fold min(NaN, 5) -> NaN even if that's not what the fmin instruction on their target would do.
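To make the divergence concrete, here is a small illustration (not MLIR code; the function name is mine): IEEE-754 minNum, as exposed by `std::fmin`, returns the non-NaN operand, so a compiler that folds min(NaN, 5) -> NaN under Any semantics produces a different result than a target whose native fmin follows minNum.

```cpp
#include <cassert>
#include <cmath>

// One possible "propagate NaN" folding choice a compiler could make
// under Any semantics for a floating-point min.
float fold_propagate_nan(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return NAN; // fold min(NaN, x) -> NaN
  return a < b ? a : b;
}
```

By contrast, `std::fmin(NAN, 5.0f)` evaluates to `5.0f`, so whether the fold happens at compile time or the native op runs at runtime changes the observed result.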

On the other hand, Platform means that the user wants their min and max operations lowered to some target-specific operation. Since this means that the user expects some consistent semantics (but we don’t know what they are), Platform semantics impede constant folding away min and max involving NaN or stuff like min(-0, +0).

Lowerings of arith.minf/maxf will use these attributes to select appropriate operations on their target.

Lowerings may use fast math flags to, for example, skip NaN checks when the operation is tagged as being NaN-free.
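A sketch of what such a lowering might look like on a hypothetical target whose native min is propagate-other (function names are illustrative, not actual MLIR lowering code): propagate-NaN semantics require explicit NaN checks around the native op, and the nnan fast-math flag lets those checks be dropped.

```cpp
#include <cmath>

// PropagateNaN lowered onto a propagate-other native min: explicit
// NaN checks select the NaN operand before falling through to fmin.
float lower_min_propagate_nan(float a, float b) {
  if (std::isnan(a)) return a; // explicit NaN check
  if (std::isnan(b)) return b;
  return std::fmin(a, b);      // native op handles the non-NaN case
}

// With the nnan fast-math flag, operands are assumed NaN-free, so the
// checks can be skipped and the native op used directly.
float lower_min_nnan(float a, float b) { return std::fmin(a, b); }
```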

Higher-level dialects like vector or tosa that have some notion of floating-point min and max will

  • Set particular values for these attributes if they have defined the NaN and signed zero semantics of their ops or
  • Expose these attributes themselves, allowing dialects further up the chain to choose them or
  • Do both, having an overridable default NaN/signed zero semantics

This work could also be a good time to pull fastmath flags up through those dialects.

Thanks for considering and spelling this case out. The way the proposal differentiates between Any and Platform makes sense to me.

Can we further assume that whatever the platform semantics are, they give consistent results? E.g., can we freely duplicate minf ops?

Thanks for disentangling this! As discussed before, this approach sounds good to me. A couple of comments:

  1. I think the representation of the NaN/±0.0 handling in the vector dialect (i.e., vector.reduction, vector.multi_reduction) should be addressed at the same time as in the arith dialect. Scalar min/max ops and vector min/max reduction ops go hand-in-hand, and we shouldn’t change the representation of the former without changing the latter. Otherwise, we would end up in a situation where scalar code can’t be vectorized due to the NaN/±0.0 handling representational gap between the two.

  2. Could you please elaborate a bit more on the Platform use case? I would expect these NaN/±0.0 semantics to be defined by the frontend and not by the platform or target. If the platform/target happens to have a native operation for the requested semantics, then the native operation should be used. Otherwise, the requested semantics should be emulated somehow. If the requested semantics are Any, then the platform/target could choose if using its native operations is the right thing to do. Similarly, Any semantics could be turned into the platform/target specific semantics (i.e., PropagateNaN, PropagateOther) by a platform/target specific pass. I don’t think lowering decisions should be encoded in the operations themselves, but maybe I’m missing something.

  3. Perhaps renaming NaNSemantics::Any to NaNSemantics::PropagateAny would make it clearer.

Platform implies there is one true operation, when in fact these are distinct operations and a platform can have both. You’ve also ignored signaling NaNs. I also think changing behavior based on the platform is not helpful behavior

  1. Re higher-level dialects, my thought was that their reduction operations should take the same sort of attributes
  2. PropagateAny is probably a reasonable renaming
  3. The usecase I imagined for Platform (which could also be Target or Native or any number of similar names) is this: I am a person writing some code that wants fmax or fmin. I want this code to compile down to whatever my target hardware’s native implementation of fmin/fmax is because I either want those exact NaN semantics or because I really don’t care about the NaN semantics but do want them to be consistent within my program.

That is, the way I defined Any above is roughly equivalent to “max(NaN, …) is permitted to cause nasal demons” while Platform is “I want some defined behavior, I don’t care what it is”.

Though, having written that out, we can simplify down to PropagateAny and checking the fastmath flags during folding.

It’s still not clear to me how we would define the exact semantics for Platform. For example, if we assume it may have some internal state that decides the order of NaNs/0s, I don’t see that we could even fold or reorder minf/maxf of regular non-NaN/0 numbers – it could be that all inputs somehow feed this internal state. Is it any different than making Platform an external function or a target-specific intrinsic?

Platform semantics mean you can’t constant-fold exactly because we don’t know how a NaN will be handled. To allow things like reordering, I’m willing to impose “the thing you lower to needs to operate consistently” … but, then again, when it comes to reordering floating-point operations or CSE-ing them … have we been making the assumption that no one ever does call fiddle_floating_point_flags() while doing those optimizations?

Because I can see today’s MLIR doing the rewrite from

```
%2 = flt_op(%0, %1)
...A
mess.with.floating.point.mode.registers
...B
%3 = flt_op(%0, %1)
...C
```

to

```
%2 = flt_op(%0, %1)
...A
mess.with.floating.point.mode.registers
...B
...C
```

The main reason I defined a Platform case is that I wanted to avoid the case where MLIR constant-folds a PropagateAny operation one way but your execution environment would’ve done the other thing, leading to “my code has different results between constant folding and not” cases.

Platform semantics mean you can’t constant-fold exactly because we don’t know how a NaN will be handled.

Why is this a desirable feature to expose? Can you elaborate on the use-case a bit more?

Answering myself: this would ensure consistency, i.e. the same result regardless of whether compiler optimizations lead to constant folding happening or not. Whereas with Any, the result of the computation could change based on optimizations (if the platform does not behave the same way as the constant folding).

On the other hand, there’s the question of which of the permissive modes we actually want. That is, does most code that lowers to MLIR come from a source that defines particular NaN semantics? (Ex. does Tosa define them, or do systems like XLA or PyTorch specify how their min and max interact with NaN?)

… having looked, PyTorch, for example, does provide multiple flavors of NaN handling in its fmin/fmax-equivalents, but doesn’t seem to specify ±0 behavior.

So it’s possible that the only modes we actually need to support are the ones where someone specifies the NaN semantics, because anyone who doesn’t know which ones they want should go think about that.

… On the other hand, you get cases like the SPIR-V standard, which, if the comments in the SPIR-V lowerings are correct, explicitly says that min/max on NaN is unspecified, and that’s something I’d think people would want to represent in their programs too.

Note that as specified, Any would disallow one optimization that is allowed by Platform: duplicating the operation. If we have %x = fmin(...) and x is used twice, then with Platform it is fine to re-compute x for each use, but with Any that would be incorrect, since different invocations of fmin on the same inputs can produce different results (depending on whether they are const-folded or not).

Good point, and I think it goes beyond constant folding: other kind of pattern-matching could take advantage of Any going one way or another.

This seems similar in spirit to separate IR dialects. There is a wide array of programming models and devices in MLIR, and they all have different ideas about floating-point operations. Maybe out of the box MLIR offers IEEE semantics, with some mechanism to introduce other math dialects.

Instead of a+b you will do sycl->add(a, b).

@andykaylor

+1 for this. I think we could use this in XLA, which does allow users to specify flags on what kind of NaN semantics they want to use. I am also a bit confused between Any and Platform (do we actually need both, or are fine with just one of them? - skimming through the comments, it seems we definitely want Platform, but do we need Any?), but would need to think about it more to properly understand the tradeoffs.

XLA specifies NaN semantics: propagate NaN by default, but allows users to specify a flag if they want to generate faster code by allowing propagate other. It does not specify +/-0.0 semantics, and generally treats signed zeros the same (though take the 0 semantics statement with a grain of salt, I’m drawing this conclusion from my casual conversations with people, don’t know what actually happens in code).

Thanks for all the good discussion, y’all!

It seems like Any and Platform are confusing, hard to reason about, and don’t necessarily capture the case I was aiming for.

So therefore, version 2 of the enums

```cpp
enum NanSemantics {
  Unknown,
  PropagateNan,
  PropagateOther,
};

enum SignedZeroSemantics {
  Unknown,
  Ordered
};
```

In these new versions, Unknown semantics (which could also be called Undefined or Unspecified semantics) means that the min/max operation could implement either propagation behavior (or in the case of min/max(-0, +0) could return either result) and so constant folding the problematic cases is not permitted.

When lowering from Arith, Unknown semantics may be replaced by any fmin/fmax implementation.
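As one reading of how a lowering could use the attribute (this mapping is my sketch, not committed code): LLVM's llvm.minimum intrinsic propagates NaN, while llvm.minnum returns the non-NaN operand, so a selection helper might look like this. The enum is re-declared here only to keep the sketch self-contained.

```cpp
#include <cstring>

enum class NanSemantics { Unknown, PropagateNan, PropagateOther };

// Hypothetical helper: pick an LLVM min intrinsic matching the
// requested NaN semantics.
const char *minIntrinsicFor(NanSemantics s) {
  switch (s) {
  case NanSemantics::PropagateNan:
    return "llvm.minimum"; // NaN in, NaN out
  case NanSemantics::PropagateOther:
    return "llvm.minnum";  // returns the non-NaN operand
  case NanSemantics::Unknown:
    // Any implementation is acceptable; a target would pick its cheapest.
    return "llvm.minnum";
  }
  return nullptr;
}
```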


or perhaps Unordered?

Arguably, if someone requests unordered NaNs, this seems perfectly fine.

Unordered doesn’t feel right for NaN semantics, but it could work for signed zero.

I’d think that we’ll want to assume that maxf and minf are pure and can be subject to duplication or common subexpression elimination - I can’t think of any mode bits that affect their operation (aside from signaling vs quiet nan, which is a bit of a minefield anyway)

I don’t think so. This is not even about minf semantics any more: let’s say we convert the result of the minf into an integer x. Then x == x should always be true (for non-poison/undef x). To ensure this remains the case, non-deterministic operations (such as minf with Any semantics) must never be duplicated.

IOW, using C-like syntax, the following two programs are not the same, and transforming the first into the second is wrong:

```c
int x = (int)minf(a, b);
bool t = x == x;                             // always true for any non-poison/undef integer x
bool u = (int)minf(a, b) == (int)minf(a, b); // can be false
```

I don’t think so. This is not even about minf semantics any more: let’s say we convert the result of the minf into an integer x . Then x == x should always be true (for non-poison/undef x ).

I’m not sure if working with C helps to understand this better here. If minf(x, NaN) may return NaN in the first place, then we can’t expect equality even before casting to int, no? And if we cast to int, are we always guaranteed the same bit-pattern?

From C §6.3.1.4:

If the value being converted is outside the range of values that can be represented, the behavior is undefined.