The `FloatType` interface change has been merged.
I am now looking at a few further cleanups on the MLIR side. Some fundamental types get special treatment in MLIR (dedicated `Builder` getters and cached instances in `MLIRContext`), and I'm wondering whether low-precision FP types should get that treatment, especially given that the list is expected to grow further.
Do we need to have all floating-point types in the `Builder` API?
```cpp
class Builder {
  // Types.
  FloatType getFloat4E2M1FNType();
  FloatType getFloat6E2M3FNType();
  FloatType getFloat6E3M2FNType();
  FloatType getFloat8E5M2Type();
  FloatType getFloat8E4M3Type();
  FloatType getFloat8E4M3FNType();
  FloatType getFloat8E5M2FNUZType();
  FloatType getFloat8E4M3FNUZType();
  FloatType getFloat8E4M3B11FNUZType();
  FloatType getFloat8E3M4Type();
  FloatType getFloat8E8M0FNUType();
  FloatType getBF16Type();
  FloatType getF16Type();
  FloatType getTF32Type();
  FloatType getF32Type();
  FloatType getF64Type();
  FloatType getF80Type();
  FloatType getF128Type();
  // ...
};
```
Users can always write `b.getType<Float8E4M3FNType>()` or `Float8E4M3FNType::get(b.getContext())` instead of `b.getFloat8E4M3FNType()`. My thinking is to remove all but the most commonly used floating-point types (and the ones that are valid LLVM types) from the `Builder` API. That is, the following would remain: `BF16`, `F16`, `TF32`, `F32`, `F64`, `F80`, `F128`.
We could even go as far as removing all `get...Type()` functions from the `Builder` API, but that would be quite a drastic change affecting lots of code. So probably not…
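The dedicated getters add no capability. In MLIR, `Builder::getType<Ty>(args...)` forwards to `Ty::get(context, args...)`, so all three spellings above yield the same uniqued type. The sketch below illustrates this with a toy model. `Context`, `TypeStorage`, `Builder` and the minimal `Float8E4M3FNType` here are hypothetical stand-ins that only mimic the shape of the MLIR API. They are not the real classes.

```cpp
#include <cassert>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>

// Hypothetical toy stand-ins; they only mimic the shape of the MLIR API.
struct TypeStorage {};

class Context {
public:
  // Returns the unique storage for type T, creating it on first use.
  template <typename T>
  const TypeStorage *getOrCreate() {
    std::unique_ptr<TypeStorage> &slot = uniqued[std::type_index(typeid(T))];
    if (!slot)
      slot = std::make_unique<TypeStorage>();
    return slot.get();
  }

private:
  std::unordered_map<std::type_index, std::unique_ptr<TypeStorage>> uniqued;
};

// A parameterless type: equality is pointer equality on the uniqued storage.
struct Float8E4M3FNType {
  const TypeStorage *impl;
  static Float8E4M3FNType get(Context *ctx) {
    return Float8E4M3FNType{ctx->getOrCreate<Float8E4M3FNType>()};
  }
  bool operator==(const Float8E4M3FNType &other) const {
    return impl == other.impl;
  }
};

class Builder {
public:
  explicit Builder(Context *ctx) : context(ctx) {}
  Context *getContext() const { return context; }

  // Generic accessor: forwards to Ty::get(context, args...).
  template <typename Ty, typename... Args>
  Ty getType(Args &&...args) const {
    return Ty::get(context, std::forward<Args>(args)...);
  }

  // Dedicated convenience getter of the kind the proposal would remove.
  Float8E4M3FNType getFloat8E4M3FNType() const {
    return Float8E4M3FNType::get(context);
  }

private:
  Context *context;
};
```

In this model, `b.getFloat8E4M3FNType()`, `b.getType<Float8E4M3FNType>()` and `Float8E4M3FNType::get(b.getContext())` compare equal. Dropping the dedicated getter therefore only changes the spelling at call sites.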
Do we need to cache all floating-point types in `MLIRContext`?
Or only the most frequently used ones? (Same list as above: `BF16`, `F16`, `TF32`, `F32`, `F64`, `F80`, `F128`.)
We currently cache these types for faster lookup:
```cpp
class MLIRContextImpl {
  /// Cached Type Instances.
  Float4E2M1FNType f4E2M1FNTy;
  Float6E2M3FNType f6E2M3FNTy;
  Float6E3M2FNType f6E3M2FNTy;
  Float8E5M2Type f8E5M2Ty;
  Float8E4M3Type f8E4M3Ty;
  Float8E4M3FNType f8E4M3FNTy;
  Float8E5M2FNUZType f8E5M2FNUZTy;
  Float8E4M3FNUZType f8E4M3FNUZTy;
  Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
  Float8E3M4Type f8E3M4Ty;
  Float8E8M0FNUType f8E8M0FNUTy;
  BFloat16Type bf16Ty;
  Float16Type f16Ty;
  FloatTF32Type tf32Ty;
  Float32Type f32Ty;
  Float64Type f64Ty;
  Float80Type f80Ty;
  Float128Type f128Ty;
  IndexType indexTy;
  IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
  NoneType noneType;
  // ...
};
```
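The following toy context sketches the trade-off behind this cache. A cached type is a plain member load, while any other type goes through a hashed lookup in a uniquer. Both paths return the same uniqued instance, so in this model, removing a type from the cache affects only lookup cost, not behavior. All names here (`Context`, `TypeStorage`, `lookup`, `getF32`) are hypothetical. This is not how `MLIRContextImpl` is actually implemented.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical toy model: a type is identified by its name and uniqued per
// context.
struct TypeStorage {
  std::string name;
};

class Context {
public:
  // Eagerly cache a frequently used type, similar in spirit to the cached
  // members of MLIRContextImpl.
  Context() { f32Ty = lookup("f32"); }

  // Slow path: hashed lookup in the uniquer, creating the type on first use.
  const TypeStorage *lookup(const std::string &name) {
    ++numLookups;
    std::unique_ptr<TypeStorage> &slot = uniquer[name];
    if (!slot)
      slot = std::make_unique<TypeStorage>(TypeStorage{name});
    return slot.get();
  }

  // Fast path for a cached type: a plain member load, no hashing.
  const TypeStorage *getF32() const { return f32Ty; }

  int getNumLookups() const { return numLookups; }

private:
  std::unordered_map<std::string, std::unique_ptr<TypeStorage>> uniquer;
  const TypeStorage *f32Ty = nullptr;
  int numLookups = 0;
};
```

Under this model, dropping a low-precision type such as `Float8E3M4Type` from the cache would only move its lookups to the slow path. Whether that matters depends on how often such types are requested.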
Any thoughts?