The LLVMIR dialect in MLIR currently has support for `FastMath` flags (`nnan`, `ninf`, `nsz`, …) that map directly to the LLVM equivalents for floating point instructions (e.g. `fadd`, `fsub`, …). I would argue that MLIR would benefit from supporting the `FastMath` concepts in other places, outside of the LLVMIR dialect. (More on the “where” in MLIR below.) Specifically:
- There may be target-specific IRs that are not LLVM IR (see Tensor Codegen Thoughts, MLIR ODM 2020/01/23, slide 6).
- `FastMath` semantics in many cases map directly to higher level constructs like vectors and tensors, and transformations may wish to leverage this behavior before lowering to the `LLVMIR` dialect.
There are already some cases in MLIR where ambiguity about `fast-math`-related floating point behavior exists (or did exist):
- Canonicalization of `x + (+0.0)` in `tosa` (folding it to `x` is only valid if signed zeros can be ignored, since `-0.0 + +0.0` is `+0.0`).
- The conversion of `complex::MulOp` in `ComplexToLLVM.cpp` adopts the “naive” finite math lowering approach that would result from fast-math optimizations, whereas the same conversion (to the `Arithmetic` dialect) in `ComplexToStandard.cpp` generates the (dozens of) extra instructions needed to correctly handle Inf/NaN values.
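To make both cases concrete, here is a minimal standalone C++ sketch. It is plain C++ rather than MLIR, the function names are hypothetical, and it assumes default IEEE 754 semantics (i.e. compiled without `-ffast-math`).

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <limits>

// Folding `x + (+0.0)` to `x` needs `nsz`: under default IEEE 754
// rounding, (-0.0) + (+0.0) == +0.0, so the sign of a negative zero is lost.
bool addPosZeroPreservesSign(double x) {
  return std::signbit(x + 0.0) == std::signbit(x);
}

// By contrast, `x + (-0.0)` leaves the sign of every non-NaN x unchanged
// (including -0.0), so that fold is valid without any fast-math flag.
bool addNegZeroPreservesSign(double x) {
  return std::signbit(x + -0.0) == std::signbit(x);
}

// Naive complex multiply, (a+bi)(c+di) = (ac-bd) + (ad+bc)i, with no
// special handling of Inf/NaN operands.
std::complex<double> naiveMul(std::complex<double> x, std::complex<double> y) {
  const double a = x.real(), b = x.imag();
  const double c = y.real(), d = y.imag();
  return {a * c - b * d, a * d + b * c};
}

// Example input: (inf + 0i) * (1 + inf i). Mathematically the product is
// an infinity, but 0 * inf = NaN makes the naive real part NaN.
std::complex<double> naiveInfExample() {
  const double inf = std::numeric_limits<double>::infinity();
  return naiveMul({inf, 0.0}, {1.0, inf});
}
```

`naiveInfExample()` returns a NaN real part and an infinite imaginary part. Input like this is what the extra Inf/NaN handling instructions are there to cover. A fast-math lowering is allowed to assume such inputs never occur.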
The documentation for the floating point instructions in the `Arithmetic` dialect (example here) suggests `FastMath` attributes as a “distant future” TODO.
Considerations:
It seems sensible to build on what LLVM and clang have done here.
- Initially, the MLIR `FastMathFlags` would be identical to the LLVM `fast-math` flags. This would allow for straightforward lowering to LLVM IR.
- LLVM allows `fast-math` flags on floating point instructions (e.g. `fadd`, `fmul`, `fdiv`, …) as well as on the `phi`, `select`, and `call` instructions (which, in MLIR, are not in the `arith` dialect). However, it seems feasible to restrict the scope of the `FastMathFlags` attribute to the `arith` dialect:
  - LLVM optimizations of `call` instructions with `fast-math` flags seem to be limited to optimizations that leverage specific knowledge of the meaning of LLVM intrinsics or known library calls (e.g. `sqrt()`) that may be inlined or optimized during LTO. The existing MLIR framework should be able to accomplish similar optimizations on known functions (if desired) without requiring `fast-math` attribute support for the `std.call` operation.
  - There is no corresponding `phi` node in MLIR.
  - LLVM does perform some optimizations on `select` instructions with floating point arguments when `fast-math` flags are present, and MLIR has a `select` operation in the `std` dialect. In spite of this, I think it makes more sense to confine `FastMathFlags` to the `arith` dialect rather than clutter the `std` dialect, which is to be replaced at some point.
- Should MLIR have (a) a `fast-math` attribute with a default value of “no-fast-math-optimizations”, or (b) an optional `fast-math` attribute?
  - clang seems to have gone through some evolution in terms of interpreting an unset `fast-math` bit as “unspecified” vs. “intentionally unset to forbid optimizations.” With an optional attribute, a pipeline in MLIR could, for example, set `fast-math` on all operations that don't have the attribute present, and leave other (numerically sensitive) operations that do have a specific `fast-math` attribute untouched. (An alternative with more granularity would be a set of optional boolean attributes.) I think the optional `BitEnumAttr` approach would provide some flexibility.
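The following is a minimal sketch of why option (b) keeps “unspecified” distinct from “intentionally unset”. It is plain C++, not MLIR. `FloatOp` and `setFastMathWhereUnspecified` are hypothetical stand-ins for an `arith` op and a pipeline step, and the bit values only mirror LLVM's flag names as an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical flag values mirroring LLVM's fast-math flag names
// (an assumption, not an existing MLIR API).
enum class FastMathFlags : uint32_t {
  none = 0, nnan = 1, ninf = 2, nsz = 4, arcp = 8,
  contract = 16, afn = 32, reassoc = 64, fast = 127,
};

// Stand-in for an `arith` floating point op. `std::nullopt` means the
// attribute is absent ("unspecified"); FastMathFlags::none means it is
// present and intentionally forbids fast-math optimizations.
struct FloatOp {
  std::string name;
  std::optional<FastMathFlags> fmf;
};

// Hypothetical pipeline step: enable `fast` on every op without the
// attribute, leaving ops with an explicit attribute (even `none`) alone.
void setFastMathWhereUnspecified(std::vector<FloatOp> &ops) {
  for (FloatOp &op : ops)
    if (!op.fmf.has_value())
      op.fmf = FastMathFlags::fast;
}
```

If option (a) were used instead, the first two ops below would both carry `none`, and the pipeline could no longer tell which one was marked as numerically sensitive:

`{{"arith.addf", std::nullopt}, {"arith.mulf", FastMathFlags::none}}`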
Proposed changes:
- Creation of a `FastMathFlags` attribute type (more specifically, a `BitEnumAttr`) in the `arith` dialect
- Addition of an optional `FastMathFlags` attribute to floating point operations in the `arith` dialect
- Addition of a `FastMath` interface to the floating point operations in the `arith` dialect, patterned after the existing interface in the LLVMIR dialect. This interface would be used (primarily) to modify the `FastMathFlags` of operations that support it.
- Development of passes that use the `FastMathFlagsInterface` to add/modify `fast-math` flags on supporting operations
- Progressive addition of `fast-math`-aware transforms/folding implementations for floating point `arith` operations
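This is a rough, standalone C++ sketch of how the proposed pieces might fit together. It includes a bit enum standing in for the `FastMathFlags` `BitEnumAttr` and a hypothetical `FastMathFlagsInterface` for reading and updating the flags. It also includes a helper a pass could use to add flags, and one fast-math-aware fold check. All names, signatures, and bit values here are assumptions for illustration. In MLIR itself, the attribute would presumably be declared in ODS/TableGen rather than written by hand.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical bit enum standing in for a `FastMathFlags` BitEnumAttr.
// Flag names follow LLVM; the bit values are an assumption.
enum class FastMathFlags : uint32_t {
  none = 0,
  nnan = 1,
  ninf = 2,
  nsz = 4,
  arcp = 8,
  contract = 16,
  afn = 32,
  reassoc = 64,
  fast = 127, // all of the above, as LLVM's `fast` implies every flag
};

constexpr FastMathFlags operator|(FastMathFlags a, FastMathFlags b) {
  return static_cast<FastMathFlags>(static_cast<uint32_t>(a) |
                                    static_cast<uint32_t>(b));
}

constexpr FastMathFlags operator&(FastMathFlags a, FastMathFlags b) {
  return static_cast<FastMathFlags>(static_cast<uint32_t>(a) &
                                    static_cast<uint32_t>(b));
}

// True if every bit in `bits` is also set in `set`.
constexpr bool hasAll(FastMathFlags set, FastMathFlags bits) {
  return (set & bits) == bits;
}

// Hypothetical interface for ops carrying an optional FastMathFlags attribute.
class FastMathFlagsInterface {
public:
  virtual ~FastMathFlagsInterface() = default;
  virtual std::optional<FastMathFlags> getFastMathFlags() const = 0;
  virtual void setFastMathFlags(FastMathFlags flags) = 0;
};

// Stand-in for a floating point add op (e.g. `arith.addf`).
class AddFOp : public FastMathFlagsInterface {
public:
  std::optional<FastMathFlags> getFastMathFlags() const override {
    return flags_;
  }
  void setFastMathFlags(FastMathFlags flags) override { flags_ = flags; }

private:
  std::optional<FastMathFlags> flags_; // empty: attribute not present
};

// What a pass might do through the interface: OR extra flags into
// whatever the op already has (treating an absent attribute as `none`).
void addFlags(FastMathFlagsInterface &op, FastMathFlags extra) {
  op.setFastMathFlags(op.getFastMathFlags().value_or(FastMathFlags::none) |
                      extra);
}

// A fast-math-aware fold check: `x + (+0.0)` -> `x` is only legal with `nsz`.
bool canFoldAddPositiveZero(const FastMathFlagsInterface &op) {
  std::optional<FastMathFlags> fmf = op.getFastMathFlags();
  return fmf.has_value() && hasAll(*fmf, FastMathFlags::nsz);
}
```

Because the fold queries only the interface, the same check would work for any floating point op that implements it. That is the main reason for routing pass and folder access through an interface rather than through each op class.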