LLVM IR supports basic floating-point types such as f16, bf16, and f32, but many of the low-precision FP types that have been added to MLIR over the years (such as f8E4M3FN, f4E2M1FN, …) are not supported. Arithmetic operations on such types (such as arith.addf) fail to lower to LLVM, or produce invalid IR in some cases.
LLVM has a software implementation (APFloat) of all floating-point types that are supported by MLIR today. This RFC proposes to add a fallback lowering to --convert-arith-to-llvm that calls into the runtime library of the execution engine, so that arithmetic on low-precision FP types can be executed on CPU.
I am mainly interested in this feature from a verification point of view: I have a compilation flow that compiles programs for GPUs and would like to implement a second (simpler) lowering that runs on CPUs to check the correctness of the compiler. The APFloat software implementation will be quite slow, but fast enough for small verification test cases.
Current state: Unsupported FP types are converted to integer types of the same bitwidth during conversion to LLVM. arith operations on such types are not supported (or are lowered to invalid IR).
Proposed lowering: In the case of arith.addf, instead of emitting an llvm.fadd, a call to the runtime library is inserted if the operation has FP types that are unsupported in LLVM.
New functions are added to the runtime library: roughly one for each arith operation. Each runtime library function takes the FP operands in the form of i64 values (irrespective of the actual FP bitwidth), an int32_t that specifies the APFloatBase::Semantics enum value (i.e., the FP type), and all required flags such as rounding modes. For FP types with fewer than 64 bits, the bit pattern is zero-extended to i64, so that a single runtime library function works for all FP bitwidths up to 64.
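As a sketch, the emitted LLVM dialect IR for an f4E2M1FN addition could look roughly like this (the runtime function name and the concrete semantics enum value are illustrative, not part of this proposal):
llvm.func @_mlir_apfloat_addf(i64, i64, i32) -> i64
// f4E2M1FN operands have been type-converted to i4; zero-extend the bit patterns to i64.
%lhs = llvm.zext %a : i4 to i64
%rhs = llvm.zext %b : i4 to i64
// APFloatBase::Semantics enum value identifying f4E2M1FN (value is illustrative).
%sem = llvm.mlir.constant(17 : i32) : i32
%res = llvm.call @_mlir_apfloat_addf(%lhs, %rhs, %sem) : (i64, i64, i32) -> i64
// Truncate the result back to the 4-bit storage width.
%c = llvm.trunc %res : i64 to i4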
Because there will be many of these extra runtime library functions, we may want to hide them behind a compilation flag.
The new lowering during --convert-arith-to-llvm is on an opt-in basis (pass flag), so that the inefficient software implementation lowering is not accidentally triggered.
Example
// RUN: mlir-opt %s --convert-to-llvm
// Put rhs into separate function so that it won't be constant-folded.
func.func @foo() -> f4E2M1FN {
  %cst = arith.constant 5.0 : f4E2M1FN
  return %cst : f4E2M1FN
}

func.func @entry() {
  %a = arith.constant 1.0 : f4E2M1FN
  %b = func.call @foo() : () -> (f4E2M1FN)
  %c = arith.addf %a, %b : f4E2M1FN
  vector.print %c : f4E2M1FN
  return
}
I think this is a nice fallback, useful for testing and verification. My only concern would be if folks expected a native lowering (e.g., they added a pattern or partial support) and it didn’t trigger. Perhaps a flag to guard / opt out?
You mean mixed lowering? Like someone adds their own lowering for arith.subf but wants to use this fallback for everything else (including arith.addf)? Seems unlikely: if you’re going to implement a lowering, I don’t think you’re going to do it piecemeal. It also seems tough to support this kind of granular opt-out via pass args (I guess you could pass a comma-separated list…). Maybe the C++ API for adding the pass could take a callback?
Not piecewise per op, but in total: either you are fine with using the fallback or not. So if you are testing the lowering, you’d only have the supported ops in the tests and you’d want it to fail if it needs the fallback. Additionally, some folks may never want to fall back when lowering to LLVM, as it’s not meant as a high-performance path, and they may want to handle it with their own pass or error out instead.
I guess we want this at a different granularity than passes. Like, does ConvertToLLVMInterface pick these patterns up? Do we have a function to populate a set with them separately from the rest?
One separate question from my side: so far, the various runtime libraries we have under lib/ExecutionEngine did not depend on LLVM libraries. If the new library needs APFloat, it would have to depend on LLVMSupport. Are there any negative repercussions of that? I am vaguely thinking about release/binary distributions.
On CPUs, this is mostly important for vector types, not scalars, which will always be slow.
Vector types would have “native” support on some, but not all, micro-architectural variants for some, but not all, operations. Those operations can also return a different type (e.g., i8 → i32) in some, but not all, cases. So this guard would have to be a “per op, per type, per u-arch” kind of thing.
Without target information, it is impossible to predict what will happen even in the best of cases (i.e., when you know what to do). Of course, a fixed pipeline could make that choice orthogonally to the IR, but that still needs to convert some, but not all, arith ops to naive LLVM.
Alternatively, a specialized lowering can happen before the final arith-to-llvm, producing a form that is not naively lowered; when that pass finally runs, it will only see/change the leftovers.
Yes, sorry I missed opt-in there this morning. That’s exactly what I wanted.
Yes, so what I was after is just for this to be able to fail instead of silently falling back. Matthias’s suggestion does that, which was my only concern.
That’s a good question: I’d have this not be on the interface. A separate populate method? Otherwise it feels like ConvertToLLVMInterface becomes a pass registry with options, while I think of it as the last-stage conversion. So I’d do the same as above (lower to runtime calls) and then run via the interface conversion.
Wouldn’t this be on JitRunner or the like? ExecutionEngine needs to be able to support registered libraries, but it’s up to whoever bundles up ExecutionEngine to decide which libraries to include.
Why not use the existing patterns within the Arith dialect that emulate unsupported floats at the IR level rather than relying on runtime calls? These live in Arith/Transforms/EmulateUnsupportedFloats.cpp and can be extended based on need.
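For context, a minimal sketch of what that emulation does for one op, assuming f32 is chosen as the wider target type:
%c = arith.addf %a, %b : f8E5M2
becomes roughly:
// Extend both operands to the wider legal type, compute there, truncate back.
%0 = arith.extf %a : f8E5M2 to f32
%1 = arith.extf %b : f8E5M2 to f32
%2 = arith.addf %0, %1 : f32
%c = arith.truncf %2 : f32 to f8E5M2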
I don’t see why this needs to touch the LLVM conversion. You could have an MLIR-level pass, similar to EmulateUnsupportedFloats, except that it would rewrite operations on unsupported float types to calls to something like _mlir_addf_f8E5M2 or _mlir_extf_f4E2M1FN_f16 (using appropriate bitcasts). Then, you could link in implementations of those functions that call APFloat.
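As a sketch, such a pass might rewrite
%c = arith.addf %a, %b : f8E5M2
into something like (the per-type function name follows the naming scheme above and is illustrative):
func.func private @_mlir_addf_f8E5M2(i8, i8) -> i8
// Bitcast the FP operands to integers of the same width, call the runtime, cast back.
%0 = arith.bitcast %a : f8E5M2 to i8
%1 = arith.bitcast %b : f8E5M2 to i8
%2 = func.call @_mlir_addf_f8E5M2(%0, %1) : (i8, i8) -> i8
%c = arith.bitcast %2 : i8 to f8E5M2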
This removes the need for special handling in core compiler constructs and makes this just one of the strategies available for dealing with weird float types.
(Heck, if you want to overengineer it a bit, make an apfloat dialect that lowers to C++ invokes and such in LLVM, so you can just link in the APFloat class directly)
That is, to clarify, I’m not opposed to a thing that rewrites arith ops to APFloat calls (on a set of types of your choice) but I think that’s nontrivial and violates the “point” of a *-to-llvm pass.
It’s like a pun on to-llvm: it’s not to LLVM IR but to llvm::APFloat. Anyway, one reason it should be in arith-to-llvm (instead of arith-to-apfloat) is to prevent someone from trying to use the pass in arbitrary places in a pipeline (because with arith-to-apfloat it ceases to be clear that this should be the last stop before LLVM IR). There’s also the issue that Alex brings up: a user will need to be mindful that these func.calls are implemented by a library that needs to be linked. But anyway, those are documentation issues, not technical issues, so I’m +1 on it being a separate pass.
We lower to func.calls into a library a decent amount of the time: anything that conjures up malloc() or the various gpu-to-* passes.
And yeah, arith-to-apfloat being part of the last steps (in the rough vicinity of stuff like math-to-llvm or arith-to-amdgpu) makes more sense than putting it in arith-to-llvm, which gets used all over the place.
Dialect conversion is not dissimilar to Instruction Selection, and when lowering to LLVM we need to have target-specific lowerings. The convert-to-llvm pass is structured to allow targets to inject custom logic. That said, LLVM SelectionDAG also has enough logic by default to break down unsupported constructs into more primitive ones, and this kind of expansion pattern seems to me to fit that category as well. This is also why we have the concept of “priority” for patterns: we can have these fallbacks available in convert-to-llvm to make it all “just work” by default, while allowing target-specific lowerings to intercept with higher priority.
Edit: actually, while a default expansion “emulating” something using raw LLVM primitives makes sense, lowering to specific runtime calls seems more like an opt-in mechanism.
TBH, I haven’t seen much of a problem in lowering odd-size floating-point operations with target intrinsics/operations (during conversion to LLVM). But I can see this as an extra option we can leverage. It doesn’t prevent users from lowering to target intrinsics either: in the worst case, we can always run this conversion (to function calls) first and replace those (or just some of those) function calls with target intrinsics. Assuming those function calls preserve all of the information/semantics of the original operations, that is.
Why not just always lower to LLVM instructions/intrinsics and change LLVM to lower unimplemented operations to libcalls, or to inline code for simpler cases (e.g., fabs)?
I believe the problem here is that LLVM doesn’t even have fp8/fp4 types. So at the end of the MLIR pipeline we’d eventually generate invalid LLVM IR, such as an fpext from i8 to float. In other words, we couldn’t even lower to LLVM IR in the first place without doing something on the MLIR side, and this RFC is proposing a potential way to do that.
Having this lowering as a separate --arith-to-apfloat pass makes sense to me. Some examples of how this could be used:
(1) --convert-to-llvm --arith-to-apfloat: All unsupported arith ops that the first pass failed to lower to LLVM are lowered to APFloat-based emulation by the second.
(2) --arith-to-apfloat --convert-to-llvm: All arith arithmetic is lowered to APFloat-based emulation. The --convert-to-llvm pass picks up all remaining ops such as arith.select.
If users want more fine-grained control, we can expose the lowering patterns in populate… functions.
One open question is whether --arith-to-apfloat should emit LLVM dialect ops or func/arith dialect ops. In the latter case, example (1) would require a second --convert-to-llvm.
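For illustration, the func/arith flavor of the emitted code might look like this (the function name and the semantics constant are illustrative, as above):
func.func private @_mlir_apfloat_addf(i64, i64, i32) -> i64
// f4E2M1FN addition routed through a generic APFloat-based runtime function.
%0 = arith.bitcast %a : f4E2M1FN to i4
%1 = arith.extui %0 : i4 to i64
%2 = arith.bitcast %b : f4E2M1FN to i4
%3 = arith.extui %2 : i4 to i64
%sem = arith.constant 17 : i32  // APFloatBase::Semantics value (illustrative)
%4 = func.call @_mlir_apfloat_addf(%1, %3, %sem) : (i64, i64, i32) -> i64
%5 = arith.trunci %4 : i64 to i4
%c = arith.bitcast %5 : i4 to f4E2M1FN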
I was surprised to see that my prototype worked without adding any new linker rules in the CMake configuration. It seems like we are already linking with the LLVM support library. I actually couldn’t find such a CMake rule, but we have this in the Bazel configuration: