[RFC] (Semi-)software bf16

Downstream, we recently changed our code generation pipeline to use the bf16 type when talking about bfloat16 as opposed to referring to it as i16 everywhere. The reason we’d had that hack in the first place, and the reason for this RFC, is that various LLVM backends (definitely AMDGPU and X86) barely support the LLVM bfloat type.

So, I put together a fix that operates on the LLVM conversion and adds passes after it (see the setup and the rewrites) that:

  • Replaces every mention of bf16 with i16
  • Converts all bf16 constants to i16 constants by bitcast
  • Implements truncf : f32 -> bf16 with bit-bashing
  • Incorrectly lowers arithmetic operations (this’ll be fixed soon)
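To make the "bit-bashing" concrete, here is a minimal Python sketch of an integer-only f32 -> bf16 truncation. It is not the code from the downstream pass: the function names are hypothetical, and the round-to-nearest-even and NaN-quieting choices are assumptions on my part (plain truncation of the low 16 bits would be the simpler alternative).

```python
import struct


def f32_bits(x: float) -> int:
    """Bit pattern of x as an IEEE-754 binary32 value (unsigned 32-bit int).

    x is assumed to already be representable in f32 range; struct raises
    OverflowError for finite values that are too large.
    """
    return struct.unpack("<I", struct.pack("<f", x))[0]


def truncf_f32_to_bf16(x: float) -> int:
    """Return the 16-bit pattern (the 'i16') holding x as a bf16."""
    bits = f32_bits(x)
    if (bits & 0x7FFFFFFF) > 0x7F800000:
        # NaN: keep the sign and top payload bits, and set the quiet bit so
        # the payload cannot collapse to zero (which would read as infinity).
        return ((bits >> 16) | 0x0040) & 0xFFFF
    # Round to nearest, ties to even: add 0x7FFF plus the lowest kept bit,
    # then drop the low 16 bits.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF
```

The same routine covers the constant rewrite: a bf16 constant 1.0 becomes the i16 constant 0x3F80.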

Now that I’ve poked at some other bits of our codebase, I’ve come to the conclusion that this bf16 remover could be a generally useful utility if expanded and cleaned up a bit.

So, how would folks feel about a pass (and accompanying change) over LLVM that:

  • Swaps bf16 for i16 everywhere, bitcasting constants
  • Implements truncf (f32 -> bf16) and extf (bf16 -> f32) with bitwise integer operations
  • Implements everything else (the other float casts, arithmetic, etc.) by extending to f32, doing the operation, and truncating the result
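As a rough illustration of the extend / operate / truncate scheme in the list above, here is a Python sketch that models bf16 values as 16-bit integers. All names are hypothetical and the rounding mode in the truncation is an assumption; the actual pass would emit the equivalent integer and f32 ops rather than call helpers like these.

```python
import math
import operator
import struct


def extf_bf16_to_f32(h: int) -> float:
    """bf16 -> f32: the 16 bits become the high half of the f32 pattern."""
    return struct.unpack("<f", struct.pack("<I", (h & 0xFFFF) << 16))[0]


def round_to_f32(x: float) -> float:
    """Round a Python float (f64) to f32, mapping overflow to infinity."""
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def truncf_f32_to_bf16(x: float) -> int:
    """f32 -> bf16 bits, assuming round-to-nearest-even and quiet NaNs."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7FFFFFFF) > 0x7F800000:
        return ((bits >> 16) | 0x0040) & 0xFFFF
    return ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF


def bf16_binary_op(op, a: int, b: int) -> int:
    """Extend both operands to f32, apply op, truncate the result to bf16."""
    # Python computes in f64; rounding that result to f32 stands in for
    # doing the operation in f32 directly.
    wide = round_to_f32(op(extf_bf16_to_f32(a), extf_bf16_to_f32(b)))
    return truncf_f32_to_bf16(wide)
```

For example, `bf16_binary_op(operator.add, 0x3F80, 0x4000)` adds the bf16 values 1.0 and 2.0 and gives the bit pattern for 3.0.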

(Alternatively, we could try to write the bf16 ops in software, but that’s a lot of bother and potentially slower.)

One other possibility I can see is that this would make more sense as a pass in LLVM: targets that don’t understand bfloat could then stick the debfloater in front of themselves, and none of this would live in MLIR. Still, I thought I’d propose this here first.

A coworker of mine has put together the pass I proposed here, and I’d like to bump this up to request review on the revision D126444, "[mlir]Implement SoftwareBF16 to handle the bf16 type".