[RFC] Should the OCP microscaling float scalars be added to APFloat and FloatType?

Background

The Open Compute Project has defined a new standard for block-scaled microfloat (MX) data types.

The standard defines data types that consist of a block of microfloats (such as the 4-bit E2M1 type or the 6-bit E3M2 and E2M3 types) along with a shared scale, which holds 8 exponent bits and no mantissa (the E8M0 format).

The intended usage of these microfloats is as part of a block-scaled type: blocks of K microfloats (with K = 32 required by the specification) along with their shared scale. This is intended to enable efficient matrix multiplication and to allow compressed storage of tensors, most typically the weight tensors in machine learning models.

Since there’s already software emulation for these types in PyTorch, and hardware vendors will likely add support for the MX types in the future (having signed on to the spec), it’ll be useful to have these types representable in MLIR’s core library in the interests of interoperability.

The types

The components of a microscale block are somewhat like floats, in that they have sign, exponent, and mantissa bits. However, due to their small size, the individual components have no infinities or NaNs. When a NaN is needed, it applies to the entire block of microfloats and is encoded by setting the shared scale to 0xff.

The alternatives

Between their very limited range, sub-byte length, lack of special values, and the fact that the OCP specification doesn’t contemplate operations on individual microfloats, it’s not clear that individual microfloat types (like a hypothetical FloatE2M1FN) should be added to MLIR’s FloatType hierarchy (which would require adding them to APFloat first).

However, adding those scalars to FloatType would enable microfloats to be handled with much of MLIR’s existing infrastructure. For example, a block of 4-bit floats could be a vector<32 x fE2M1FN> instead of needing a custom type.
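
To make that concrete, here’s a minimal sketch of what decoding a block could look like with only existing ops, assuming a hypothetical fE2M1FN scalar and assuming the shared scale has already been converted to its f32 value (2^(exponent - 127) per the spec). Note that arith.extf only sees the element bits, so the scale multiply has to happen separately:

// Hypothetical: fE2M1FN does not exist today; the shared scale arrives pre-converted to f32.
func.func @decode_block(%block: vector<32xfE2M1FN>, %scale: f32) -> vector<32xf32> {
  %elems = arith.extf %block : vector<32xfE2M1FN> to vector<32xf32>
  %scale_vec = vector.broadcast %scale : f32 to vector<32xf32>
  %decoded = arith.mulf %elems, %scale_vec : vector<32xf32>
  return %decoded : vector<32xf32>
}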

An alternative approach would be to define a new microscaling dialect with custom types like !microscaling.fe2m1, !microscaling.block<32xfe2m1>, and !microscaling.scale. Combined with a sufficiently general bitcast operation, this would allow defining operations on microscale types, which will, at the end of the day, still lower to an appropriate-width integer (or a vector of them), since, to my knowledge, there are no plans to add these microfloats to LLVM’s or SPIR-V’s type systems.
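
To make the “lowers to integers anyway” point concrete, such a bitcast might look like the following, with a 32-element E2M1 block occupying 32 x 4 = 128 bits and the scale occupying one byte (op and type names are placeholders following the spellings above):

%raw_block = microscaling.bitcast %block : !microscaling.block<32xfe2m1> to vector<4xi32>
%raw_scale = microscaling.bitcast %scale : !microscaling.scale to i8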

However, because these custom types would sit outside the float type hierarchy, they wouldn’t be permissible in many high-level dialects like Tosa, creating substantial friction when generating high-level descriptions of model fragments that take or produce microscaled floats.

I’m not seeing a clear best path forward on this design question, so I’m writing to get the opinions of other people who’ll be hooking up microfloat support in MLIR at some future time.

Thanks for bringing this up. We’ve started to look at microscaling format support for TOSA. It’s still early, so we don’t have a specific proposal yet, but it would be great to have something aligned across dialects.

In terms of the TOSA operator set, our current idea is to only allow the microscaling formats as inputs/outputs to the tosa.cast operator. That way you would cast them to a standard FP type to do your calculations and cast back to a microscaling format afterward. I think that would fit better conceptually with your second approach, but would also like to hear other opinions.
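
As a rough sketch of that flow (the shapes are arbitrary and !mx is only a placeholder for whatever MX representation ends up being chosen):

%a_f32 = tosa.cast %a_mx : (tensor<1x32x64x!mx>) -> tensor<1x32x64xf32>
// ... ordinary TOSA compute on the f32 tensors ...
%c_mx = tosa.cast %c_f32 : (tensor<1x32x64xf32>) -> tensor<1x32x64x!mx>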

@eric-k I’m in agreement with the mechanism in Tosa being casts, since that’s something anyone can pattern-match on if they’ve got a hardware-specific microscaling convolution or matmul. I’d argue that, right now, you’d want to ensure you can cast to/from the 8-bit float types (f8E5M2, f8E4M3FN, and so on) in Tosa, since 8-bit float matrix multiplication acceleration already exists.

Now, that being said, the cast operation for microscaled floats is a bit weird compared to arith.extf and arith.truncf. I’ve got a preliminary sketch of a microscaling dialect bouncing around my machine, but I haven’t posted it because I wanted to solve this design issue first. Relevantly, I’ve been considering cast operations along these lines:

%block, %scale = microscaling.trunc %floats : vector<Kxf32> to !block<fE3M2, K>, !scale
%floats = microscaling.ext %block, %scale : !block<fE2M1, K>, !scale to vector<Kxf32>

(though that f32 might be f16 or something like that)

This is operating on the assumption that the computation is being passed a tensor of blocks and a tensor of scales, which is one solution for the cache locality vs. alignment tradeoff that comes with attaching 8 bits to a 128-256 bit packed vector. In other words, at a tensor level, we’d see

func.func @ext_microscaled_soa(%blocks: tensor<Nx!microscaling.block<fE5M2, 32>>, %scales: tensor<Nx!microscaling.scale>) -> tensor<Nx32xf32> {
  %ret = tosa.cast %blocks, %scales : ...
  return %ret : tensor<Nx32xf32>
}

There’s an alternative storage scheme that looks like

func.func @ext_microscaled_aos(%data : tensor<Nx!microscaling.packed_block_scale<fE5M2, 32>>) -> tensor<Nx32xf32> {
  %ret = tosa.cast %data : ...
  return %ret : tensor<Nx32xf32>
}

Which of these schemes makes more sense is, as far as I can tell, architecture and problem dependent, so we’d want to support both options.

Now, if we’re going this route, it’ll probably be useful to have %packed_block_scales = microscaling.pack %blocks, %scales and %blocks, %scales = microscaling.unpack %packed_block_scales for getting everything to fit into APIs. (Those operations would work on the tensor level, and, ideally, would be folded away during codegen)
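
With the types from the earlier sketches spelled out, that would look something like this (syntax still entirely hypothetical):

%packed = microscaling.pack %blocks, %scales
    : tensor<Nx!microscaling.block<fE5M2, 32>>, tensor<Nx!microscaling.scale>
    to tensor<Nx!microscaling.packed_block_scale<fE5M2, 32>>
%unpacked_blocks, %unpacked_scales = microscaling.unpack %packed
    : tensor<Nx!microscaling.packed_block_scale<fE5M2, 32>>
    to tensor<Nx!microscaling.block<fE5M2, 32>>, tensor<Nx!microscaling.scale>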

There’s also something like

func.func @ext_microscaled_elems(%floats : tensor<Nx32x!microscaling.elem<fE2M1>>, %scales : tensor<Nx!microscaling.scale>) -> tensor<Nx32xf32> {
  %ret = tosa.cast %floats, %scales : ...
  return %ret : tensor<Nx32xf32>
}

but that has the problem that you’ve now got a tensor with sub-byte elements, meaning that you can’t index into it post-bufferization and expect that to work without a bunch of special-casing, so I’m not particularly enthusiastic about this approach.

I hope this is making some sort of sense.

Thanks. One of the problems I’ve been having while parsing the MX specs is that they say nothing about layouts. I’m having a bit of trouble connecting a concrete design to such an abstract description.

But if I were imagining how this might work out and borrowing from other spaces, it kind of has to break down to some form of planar or interleaved/packed formats, and it seems to me that at the level we operate here, we will need to have a way to represent both.

I find it much easier to think of these things from a bottom up perspective, and there are already a ton of priors. Particularly, this looks a lot like some of the structs that llama.cpp uses for its various numeric formats. And that connects to some experiments that were done last year to literally use llvm.struct as a tensor element type. I’m not saying we should literally do that, but the result was that such a representation composed reasonably well.
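
Just to make that analogy concrete (purely illustrative, not a proposal): llama.cpp’s block_q4_0 is an f16 scale plus 32 4-bit quants packed into 16 bytes, and an MX block has the same shape with an 8-bit E8M0 shared scale, so as llvm.struct tensor element types they’d look something like:

// llama.cpp-style 4-bit block: f16 scale + 32 packed 4-bit quants.
tensor<?x!llvm.struct<(f16, array<16 x i8>)>>
// MX-style block: 8-bit E8M0 shared scale + 32 packed E2M1 elements.
tensor<?x!llvm.struct<(i8, array<16 x i8>)>>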

It seems to me there might be 3-ish ingredients:

  1. At the very high level, potentially some form of encoding attributes/types and corresponding constant attributes/ops for representation of logical literals prior to imbuing them with layout.
  2. MX structured types and various pack/unpack ops to transform them between planar and packed forms.
  3. Type interfaces for struct-like types and work to make those operate as element types for tensor and memref. Possibly make vector implement the interface.

Thinking about a generic, non-hardware-accelerated CPU lowering pipeline should inform what pieces are needed.

One issue I’ve seen in the past is for people, when faced with an abstract type like this, to assume that nothing can be assumed about layout and how it decomposes. In reality, the universe of how it decomposes, while platform-specific, is still bounded by the usual concepts of locality and alignment. I.e., arrangements that are nonsensical can be ignored rather than joined over, and unless I miss my guess, there will end up being a manageable number of permutations. The hint I’ve found that you’re in one of these nonsensical areas is that you end up with unaligned sub-byte types that are not blocked in a coherent way. I’ve yet to see that arise in nature in a way that is legal for any problem or platform. The answer there is always to propagate the constraint upward toward the user-level programming model, in my experience.

Reading back over your message, I think we’re on a similar page, but it helps me nonetheless to write out my thoughts.

And to answer the question posed at the top, I don’t think these are related to APFloat or FloatType. I think that, from a code generation perspective and for targeting lower-level hardware intrinsics, they are something new for MLIR and more closely related to packing/blocking. It would be a good opportunity to connect those concepts up properly, as they also come up frequently as we deal with various implementation-defined sub-byte formats.

As for the basic scalar types themselves, I think that is an option we should keep an open mind about. The hint that we should define them in APFloat would be if we end up using them as scalars or as basic element types in constants. I could see such a situation arising naturally as part of the rest of the design, and the advantage of defining them properly is that you get basic parse/print/typing and arithmetic emulation suitable for folding and such. If we find ourselves there, it makes a lot of sense to add them as real types, and I’ve found that to be a much better option than trying to define such things another way.
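
For comparison, the 8-bit float types that are already in APFloat/FloatType give us this for free today, and the MX element types would presumably slot in the same way if we got there:

// f8E5M2 already round-trips through the parser/printer like any other float,
// and APFloat gives folders real arithmetic to work with.
%cst = arith.constant 1.5 : f8E5M2
%ext = arith.extf %cst : f8E5M2 to f32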