Rethink on approach to low precision FP types

The FloatType interface change has been merged.

I am now looking at a few further cleanups on the MLIR side. Some fundamental types get special treatment in MLIR, and I’m wondering whether low-precision FP types should get that treatment, especially given that the list is expected to grow further.

Do we need to have all floating-point types in the Builder API?

class Builder {
  // Types.
  FloatType getFloat4E2M1FNType();
  FloatType getFloat6E2M3FNType();
  FloatType getFloat6E3M2FNType();
  FloatType getFloat8E5M2Type();
  FloatType getFloat8E4M3Type();
  FloatType getFloat8E4M3FNType();
  FloatType getFloat8E5M2FNUZType();
  FloatType getFloat8E4M3FNUZType();
  FloatType getFloat8E4M3B11FNUZType();
  FloatType getFloat8E3M4Type();
  FloatType getFloat8E8M0FNUType();
  FloatType getBF16Type();
  FloatType getF16Type();
  FloatType getTF32Type();
  FloatType getF32Type();
  FloatType getF64Type();
  FloatType getF80Type();
  FloatType getF128Type();
  // ...
};

Users can always write b.getType<Float8E4M3FNType>() or Float8E4M3FNType::get(b.getContext()) instead of b.getFloat8E4M3FNType(). My thinking is to remove all but the most commonly used floating-point types (and the ones that are valid LLVM types) from the Builder API. That would leave: BF16, F16, TF32, F32, F64, F80, F128.

We could even go as far as removing all get...Type() functions from the Builder API. But that would be quite a drastic change, affecting lots of code. So probably not…
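
To make the equivalence of the three spellings concrete, here is a small self-contained sketch. It deliberately does not include any MLIR headers: ToyContext, ToyBuilder and ToyFloat8E4M3FNType are hypothetical stand-ins that only mimic the shape of the real classes (a static get(ctx) on the type, a generic getType<Ty>() on the builder). The point is just that a dedicated getter is a thin convenience wrapper, so dropping it costs callers a slightly longer spelling and nothing else.

```cpp
#include <cassert>

// Toy stand-ins for illustration only; these are NOT the real MLIR classes.
struct ToyContext {};

struct ToyFloat8E4M3FNType {
  const ToyContext *ctx;
  // Mirrors the shape of `SomeType::get(context)`.
  static ToyFloat8E4M3FNType get(const ToyContext *ctx) { return {ctx}; }
  bool operator==(const ToyFloat8E4M3FNType &other) const {
    return ctx == other.ctx;
  }
};

class ToyBuilder {
public:
  explicit ToyBuilder(const ToyContext *ctx) : ctx(ctx) {}
  const ToyContext *getContext() const { return ctx; }

  // Generic accessor: works for any type that exposes a static get(ctx).
  template <typename Ty>
  Ty getType() const {
    return Ty::get(ctx);
  }

  // Dedicated accessor: pure convenience, adds no capability.
  ToyFloat8E4M3FNType getFloat8E4M3FNType() const {
    return getType<ToyFloat8E4M3FNType>();
  }

private:
  const ToyContext *ctx;
};

// Returns true when all three spellings yield the same type.
inline bool spellingsAgree(const ToyBuilder &b) {
  auto viaDedicated = b.getFloat8E4M3FNType();
  auto viaGeneric = b.getType<ToyFloat8E4M3FNType>();
  auto viaStatic = ToyFloat8E4M3FNType::get(b.getContext());
  return viaDedicated == viaGeneric && viaGeneric == viaStatic;
}

// Small usage helper: asserts that the spellings agree for a fresh builder.
inline void demoSpellings() {
  ToyContext ctx;
  ToyBuilder b(&ctx);
  assert(spellingsAgree(b));
}
```

In this model, removing getFloat8E4M3FNType() from ToyBuilder would leave the other two spellings working unchanged, which is the migration path proposed above for the low-precision types.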

Do we need to cache all floating-point types in MLIRContext?

Or just the most frequently used ones? (Same list as above: BF16, F16, TF32, F32, F64, F80, F128.)

We currently cache these types for faster lookup:

class MLIRContextImpl {
  /// Cached Type Instances.
  Float4E2M1FNType f4E2M1FNTy;
  Float6E2M3FNType f6E2M3FNTy;
  Float6E3M2FNType f6E3M2FNTy;
  Float8E5M2Type f8E5M2Ty;
  Float8E4M3Type f8E4M3Ty;
  Float8E4M3FNType f8E4M3FNTy;
  Float8E5M2FNUZType f8E5M2FNUZTy;
  Float8E4M3FNUZType f8E4M3FNUZTy;
  Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
  Float8E3M4Type f8E3M4Ty;
  Float8E8M0FNUType f8E8M0FNUTy;
  BFloat16Type bf16Ty;
  Float16Type f16Ty;
  FloatTF32Type tf32Ty;
  Float32Type f32Ty;
  Float64Type f64Ty;
  Float80Type f80Ty;
  Float128Type f128Ty;
  IndexType indexTy;
  IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
  NoneType noneType;
  // ...
};
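
As a rough illustration of the trade-off, the sketch below models a context that keeps one frequently used type in a member field and sends everything else through a hash-based uniquer on every request. This is a toy model under stated assumptions, not MLIR's actual implementation: ToyContext, ToyTypeStorage, the string keys and the lookup counter are all hypothetical and exist only to make the two paths visible.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>

// Toy model only; names and structure are hypothetical, not MLIR's.
struct ToyTypeStorage {
  std::string name;
};

class ToyContext {
public:
  // The cached type is resolved once, when the context is created.
  ToyContext() : f32Cached(lookupOrCreate("f32")) {}

  // Fast path: a frequently used type is read from a plain member.
  const ToyTypeStorage *getF32() const { return f32Cached; }

  // Slow path: uncached types go through the uniquer on every call.
  const ToyTypeStorage *getByName(const std::string &name) {
    return lookupOrCreate(name);
  }

  // Number of uniquer lookups performed so far.
  int numLookups() const { return lookups; }

private:
  const ToyTypeStorage *lookupOrCreate(const std::string &name) {
    ++lookups;
    // emplace is a no-op if the key exists; element addresses stay stable.
    auto it = uniquer.emplace(name, ToyTypeStorage{name}).first;
    return &it->second;
  }

  // Declared before f32Cached so they are initialized first.
  std::unordered_map<std::string, ToyTypeStorage> uniquer;
  int lookups = 0;
  const ToyTypeStorage *f32Cached;
};
```

Both paths return the same unique instance, so dropping a cached field would not change behavior in this model. It only trades one pointer-sized member per type for a hash lookup per request, which seems an acceptable cost for rarely used types.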

Any thoughts?