Downstream, we recently changed our code generation pipeline to use the `bf16` type when talking about bfloat16, as opposed to referring to it as `i16` everywhere. The reason we’d had that hack in the first place, and the reason for this RFC, is that various LLVM backends (definitely AMDGPU and X86) barely support the LLVM `bfloat` type.
So, I put together a fix that operates on the LLVM conversion and adds passes after it (see the setup and the rewrites) that:
- Replaces every mention of `bf16` with `i16`
- Converts all `bf16` constants to `i16` constants by bitcast
- Implements `truncf : f32 -> bf16` with bit-bashing
- Incorrectly lowers arithmetic operations (this’ll be fixed soon)
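To make the bit-bashing concrete, here is a rough sketch of what an integer-only `f32 -> bf16` truncation can look like, written in Python rather than as IR. It assumes round-to-nearest-even and quiets NaNs; the post does not say which rounding mode the downstream pass uses, and the helper names (`f32_bits`, `truncf_f32_to_bf16_bits`) are made up for illustration.

```python
import struct


def f32_bits(x: float) -> int:
    """Return the 32-bit pattern of x stored as an IEEE-754 f32.

    Assumes x is already within f32 range (struct raises otherwise).
    """
    return struct.unpack("<I", struct.pack("<f", x))[0]


def truncf_f32_to_bf16_bits(x: float) -> int:
    """Truncate an f32 value to a bf16 bit pattern held in a 16-bit integer.

    Assumption: round-to-nearest-even, with NaNs forced quiet.
    """
    bits = f32_bits(x)
    # NaN: exponent all ones and a non-zero mantissa. Keep the top half and
    # set a mantissa bit so the result cannot collapse into infinity.
    if (bits & 0x7FFFFFFF) > 0x7F800000:
        return ((bits >> 16) | 0x0040) & 0xFFFF
    # Add 0x7FFF plus the lowest kept bit, so exact ties round to even.
    lsb = (bits >> 16) & 1
    return ((bits + 0x7FFF + lsb) >> 16) & 0xFFFF


print(hex(truncf_f32_to_bf16_bits(1.0)))  # 0x3f80
```

The same sequence (bitcast to `i32`, compare, add, shift, truncate to `i16`) is what a rewrite pattern would have to emit as integer ops. A plain `bits >> 16` would also work but rounds toward zero.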
Now that I’ve poked at some other bits of our codebase, I’ve come to the conclusion that this bf16 remover could be a generally useful utility if expanded and cleaned up a bit.
So, how would folks feel about a pass (and change) over LLVM that:
- Swaps `bf16` for `i16` everywhere, bitcasting constants
- Implements `truncf`/`extf` to f32 with bitwise integer operations
- Implements everything else (the other float casts, arithmetic, etc.) by extending to `f32`, doing the operation, and truncating the result
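As a sketch of the extend, operate, truncate scheme in the last bullet, the Python below models a `bf16` value as its 16-bit integer pattern, which is what it would become after the `i16` swap. All function names are hypothetical, round-to-nearest-even is an assumption, and since Python floats are f64 the intermediate rounding through f32 only approximates real f32 arithmetic.

```python
import operator
import struct


def f32_to_bits(x: float) -> int:
    """Round a Python float (f64) to f32 and return its 32-bit pattern."""
    try:
        return struct.unpack("<I", struct.pack("<f", x))[0]
    except OverflowError:
        # Finite but out of f32 range: saturate to +/-inf.
        return 0xFF800000 if x < 0 else 0x7F800000


def extf_bf16_to_f32(b: int) -> float:
    """extf: bf16 is the top half of an f32, so shift left by 16 and bitcast."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def truncf_f32_to_bf16(x: float) -> int:
    """truncf: round-to-nearest-even on the low 16 bits (assumed mode)."""
    bits = f32_to_bits(x)
    if (bits & 0x7FFFFFFF) > 0x7F800000:  # NaN: keep it a (quiet) NaN
        return ((bits >> 16) | 0x0040) & 0xFFFF
    return ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF


def bf16_binop(op, a: int, b: int) -> int:
    """Lower a bf16 binary op: extend both operands, operate, truncate."""
    return truncf_f32_to_bf16(op(extf_bf16_to_f32(a), extf_bf16_to_f32(b)))


ONE, TWO = 0x3F80, 0x4000  # bf16 constants, already "bitcast" to i16
print(hex(bf16_binop(operator.add, ONE, TWO)))  # 0x4040, i.e. 3.0
```

Note that `extf` is exact (every `bf16` value is representable in `f32`), so all of the rounding in this scheme happens in the final truncation.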
(Alternatively, we could try and write the `bf16` ops in software, but that’s a lot of bother and potentially slower.)
One other possibility I can see is that this would make more sense as a pass in LLVM: targets that don’t understand `bfloat` could then stick the debfloater in front of themselves, and then none of this lives in MLIR. But I thought I’d propose this here first.