The arith dialect has ops for converting FP values to FP types with larger or smaller bitwidth. However, we are currently missing an op to convert between types with the same bitwidth, e.g., bf16 -> f16.
I would like to add this new op:
%0 = arith.fptofp %a : f16 to bf16
In the current design, the op supports arbitrary FP types for operand / result. Should we deprecate the old arith.extf / arith.truncf ops?
Some things to consider:
The op name arith.truncf never made sense: it’s not a truncation.
arith.extf / arith.truncf are widely used. It may be better to leave them as is.
I’d be concerned about the redundancy; I would rather set up a migration plan: deprecate them and plan to remove them. We can ease the migration by writing a manual C++ op class for TruncFOp/ExtFOp that would just be a shim for the new op, for example.
LLVM IR supports only a very limited set of FP types. There is only a single conversion pair where both bitwidths are the same: bf16 ↔ f16. The pattern in the above PR lowers these to bf16 ↔ f32 ↔ f16.
The general idea here is to allow arbitrary FP conversions on the arith dialect. Whether a specific conversion is possible in one step is architecture-specific. E.g., PTX supports e8m0 → bf16, but not e8m0 → f16. The legalization pass for the specific architecture can decide what kind of LLVM ops / intrinsics / NVVM ops / … to generate.
Is there a singular way to lower these? E.g., if converting from A to B, is there only one way? I am not sure that the intermediate f32 in bf16 → f32 → f16 doesn’t pin down more precise numerical behavior than a direct bf16 → f16, whose behavior would otherwise be left underspecified.
For f32 → bf16 there mostly is. For many of these conversions, the lowering truncates mantissa or exponent bits with some rounding behavior, and extension zero-pads the mantissa or exponent bits.
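As a concrete illustration of that description (a self-contained Python sketch, not code from the PR; the helper names are made up), f32 → bf16 is mantissa truncation with round-to-nearest-even, and bf16 → f32 is lossless zero-padding:

```python
import struct

def f32_to_bf16_bits(x: float) -> int:
    """f32 -> bf16: drop the low 16 mantissa bits, rounding to nearest-even."""
    b = struct.unpack("<I", struct.pack("<f", x))[0]
    if (b & 0x7F800000) == 0x7F800000 and (b & 0x007FFFFF):
        return (b >> 16) | 0x0040            # NaN: keep a mantissa bit set
    # Adding this bias before the shift implements round-to-nearest, ties-to-even.
    return (b + (0x7FFF + ((b >> 16) & 1))) >> 16

def bf16_bits_to_f32(h: int) -> float:
    """bf16 -> f32 is pure extension: zero-pad the low 16 bits (lossless)."""
    return struct.unpack("<f", struct.pack("<I", h << 16))[0]
```

Every exact bf16 value round-trips unchanged (the low 16 bits are zero, so the rounding bias never carries), while halfway cases round to the even neighbor.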
For bf16 ↔ f16, I think converting via f32 is bit-accurate in each rounding mode.
f16 = E5 M10
bf16 = E8 M7
f32 = E8 M23
f32 has enough bits in exponent and mantissa to hold each bf16 and f16 value without rounding. Therefore, bf16/f16 → f32 upcasting is lossless.
The downcast from f32 → bf16/f16 may require rounding, as specified by the rounding mode on the arith.fptofp op. Whether you convert bf16 -> f32 -> f16 or bf16 -> f16 (assuming that your hardware can do it) makes no difference: the same bit pattern is produced. EDIT: I have to double-check this part for every rounding mode.
To answer your question, you could also lower the bf16 -> f16 conversion as bf16 -> f64 -> f16, but it would make no difference from a mathematical point of view.
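To make the “same bit pattern” claim checkable, here is a self-contained Python sketch (illustrative only, not the PR’s lowering; helper names are invented) that implements f32 → f16 rounding by bit manipulation and verifies, exhaustively over all 2^16 bf16 patterns, that going through f32 matches correctly rounding the exact value to f16 (using CPython’s own round-to-nearest-even half conversion as the reference):

```python
import struct

def rshift_rne(x: int, s: int) -> int:
    """Shift x right by s bits, rounding to nearest, ties to even."""
    q, r = x >> s, x & ((1 << s) - 1)
    half = 1 << (s - 1)
    if r > half or (r == half and (q & 1)):
        q += 1
    return q

def f32_bits_to_f16_bits(b32: int) -> int:
    """Round an IEEE binary32 bit pattern to binary16 (round-to-nearest-even)."""
    sign = (b32 >> 16) & 0x8000
    e = (b32 >> 23) & 0xFF
    m = b32 & 0x007FFFFF
    if e == 0xFF:                        # inf or NaN
        return sign | 0x7C00 | (0x0200 if m else 0)
    if e == 0:                           # zero / f32 subnormal: underflows f16 to +-0
        return sign
    exp = e - 127
    M = m | 0x00800000                   # 24-bit significand incl. implicit bit
    if exp < -14:                        # f16 subnormal range (or underflow to 0)
        return sign | rshift_rne(M, 13 + (-14 - exp))
    q = rshift_rne(M, 13)                # 11-bit significand incl. implicit bit
    if q == 0x800:                       # rounded up to 2.0: bump the exponent
        q, exp = 0x400, exp + 1
    if exp > 15:
        return sign | 0x7C00             # overflow to inf
    return sign | ((exp + 15) << 10) | (q & 0x3FF)

def f16_bits_direct(v: float) -> int:
    """Correctly rounded f16 bits of the exact value (CPython rounds ties to even)."""
    try:
        return struct.unpack("<H", struct.pack("<e", v))[0]
    except OverflowError:                # finite value rounds past f16 max
        return 0x7C00 if v > 0 else 0xFC00

def is_nan16(h: int) -> bool:
    return (h & 0x7C00) == 0x7C00 and (h & 0x03FF) != 0

# Exhaustive check over every bf16 bit pattern: bf16 -> f32 -> f16 equals
# rounding the exact bf16 value directly to f16.
for h in range(1 << 16):
    b32 = h << 16                        # bf16 -> f32 is exact zero-padding
    via_f32 = f32_bits_to_f16_bits(b32)
    direct = f16_bits_direct(struct.unpack("<f", struct.pack("<I", b32))[0])
    if is_nan16(via_f32) and is_nan16(direct):
        continue                         # NaN payloads may differ; both are NaN
    assert via_f32 == direct, hex(h)
```

This only exercises the default round-to-nearest-even mode; the other rounding modes would need the same kind of exhaustive check, as noted in the EDIT above.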
LLVM IR allows arbitrary conversions between LLVM-supported FP types of different bitwidths. I.e., the verifier allows arbitrary FP types for fptrunc / fpext, as long as they are valid LLVM types (not f4E2M1FN etc.). There’s no reason to choose bf16 -> f64 -> f16 over bf16 -> f32 -> f16. Both are supported by LLVM IR and both produce the same bit pattern, but the latter uses fewer bits and is, therefore, preferable. (Please correct me if I’m wrong.)
This PR is mainly about the modeling of FP conversions in the arith dialect. I’m not proposing any change to the LLVM dialect. There are multiple f8 and f4 types in MLIR. LLVM doesn’t support them, but some other backends (e.g., NVVM / PTX) do.
Right, these conversions can already be written with arith.extf / arith.truncf. But, e.g., this conversion cannot currently be expressed with a single op in the arith dialect: e8m0 -> e5m2.
That’s what I mean, too. I.e., not moving the existing arith ops.
As @krzysz00 mentioned in the PR, we could (for now at least) only add fpconvert (or whatever name) to mean just the same-bitwidth conversions, and leave extf and truncf alone in the PR. That’d be easy enough to merge on its own, without needing this discussion to reach consensus.
FWIW, this is also how SPIR-V does it: OpFConvert. That seems the cleanest to me: no need to worry whether upcasting to a larger type is faithful, whether the larger type is legal at all, or whether it accidentally lives on after lowering.
We could still expose TruncFOp/ExtFOp as C++ classes for backwards compatibility in transformations / conversions, similar to how @rengolin introduced it during the named matmul op refactoring.
I changed the PR so that only FP types with the same bitwidth are allowed for fptofp. Deprecating / removing extf and truncf can be a second step if we want to go that route.
I also verified that f16 -> f32 -> bf16 produces the exact same bit pattern as f16 -> bf16.
That does not seem OK to me, actually: we should have a concrete plan to remove the redundancy before we introduce it.
It does not have to be the same PR, but at least this should be part of the plan we execute on.
if (srcType.getWidth() != dstType.getWidth())
  return emitError("result element type ")
         << dstType << " must have the same bitwidth as operand element type "
         << srcType;
OK, you want to limit this to same bitwidth now, I had missed it…
That said, the example in the op documentation (and the PR) is “f16 to bf16”: it is now equivalent to truncf(extf f16 -> f32) -> bf16. Are we able to canonicalize that to fptofp?
Yes, that’s a valid canonicalization. Should I add it?
The general rule here would be something along the lines of: a -> b -> c can be canonicalized to a -> c if mantissa_bits(b) >= mantissa_bits(a) and exponent_bits(b) >= exponent_bits(a).
In addition, a and b must have the same semantics wrt. hasZero, hasSignedRepr, hasInf, hasNaN, etc. (Unless fast-math flags indicate that the input is not inf/NaN/etc.)
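The rule above can be sketched as a tiny predicate. This is a hypothetical illustration with hard-coded format parameters; a real implementation would query the FloatType / APFloat semantics instead:

```python
# (exponent_bits, mantissa_bits, has_inf, has_nan) -- illustrative values only
FORMATS = {
    "f16":  (5, 10, True, True),
    "bf16": (8, 7, True, True),
    "f32":  (8, 23, True, True),
    "f64":  (11, 52, True, True),
}

def can_fold_through(a: str, b: str) -> bool:
    """May a -> b -> c be folded to a -> c? The intermediate b must hold every
    value of a exactly and agree on the special-value semantics."""
    ea, ma, *special_a = FORMATS[a]
    eb, mb, *special_b = FORMATS[b]
    return eb >= ea and mb >= ma and special_a == special_b
```

For example, folding through f32 is always safe for f16 and bf16 inputs, but folding through bf16 is not safe for an f16 input (bf16 has fewer mantissa bits), and folding through f16 is not safe for a bf16 input (f16 has fewer exponent bits).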
To mirror a comment from the PR: I don’t necessarily agree that we should be removing extf and truncf at all. Extension and truncation (and these sorts of conversions) have meaningfully different properties in a lot of code paths, and having them as separate ops makes them easy to match on. They’ve also got meaningfully different semantics to worry about: extension doesn’t really care about rounding mode the way truncation does, for example. Finally, we have extension and truncation as separate operations for integers, so why not for floats?