Tensors of complex numbers, bufferization and lowering

Hi all,

I am looking at using tensors whose element type is a complex number (from the complex dialect), but I have noticed that many passes (such as vectorization and bufferization) do not support this type, so I want to work on enabling them.

Unlike the current lowering pass of the complex dialect, I would prefer my tensors of shape <MxNxcomplex<T>> to be rewritten as <MxKxT> where K = 2*N, since I wish to represent a complex number simply as two consecutive values of type T (f32 or f64 in my case) rather than lowering to an llvm.struct containing two elements.
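To illustrate the layout I have in mind, here is a small Python sketch (stdlib struct only; the shapes and values are just examples): a tensor of complex<f32> with interleaved re/im parts is byte-identical to a float tensor whose innermost dimension is doubled.

```python
import struct

# A 2x2 tensor of complex<f32>, stored as interleaved (re, im) f32 pairs.
complex_rows = [[(1.0, 2.0), (3.0, 4.0)],
                [(5.0, 6.0), (7.0, 8.0)]]
buf = b"".join(struct.pack("<2f", re, im)
               for row in complex_rows for re, im in row)

# The same bytes reinterpreted as a 2x4 f32 tensor: element (i, j) of the
# complex tensor lands at columns 2j (real) and 2j+1 (imag).
floats = struct.unpack("<8f", buf)
float_rows = [floats[0:4], floats[4:8]]
assert float_rows[1][2] == complex_rows[1][1][0]  # real part of element (1, 1)
assert float_rows[1][3] == complex_rows[1][1][1]  # imag part of element (1, 1)
```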

However, I am struggling to find the best way, and the best "level", to do this. At the moment I have managed to somewhat extend vectorization and the vector type to accept complex element types, but bufferization does not play well with this, and further lowering is still an open problem.

I was wondering whether a good place to transform complex numbers into consecutive floats would be just before bufferization, while things are still at the tensor level. I would thus have a pass that transforms something like:

%zero_complex = complex.constant [0.0 : f32, 0.0 : f32] : complex<f32>
%a = vector.transfer_read %A[%x, %y], %zero_complex {in_bounds = [true, true]} : tensor<16x16xcomplex<f32>>, vector<8x1xcomplex<f32>>

Into this:

%zero_f32 = arith.constant 0.0 : f32
%c2 = arith.constant 2 : index
%y2 = arith.muli %y, %c2 : index
%a = vector.transfer_read %A[%x, %y2], %zero_f32 {in_bounds = [true, true]} : tensor<16x32xf32>, vector<8x2xf32>

And then let the compiler reduce the code further.
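To make the rewrite concrete, here is a toy Python sketch of the shape/index arithmetic such a pass would perform (the helper name is my own invention, not an existing MLIR API): the innermost dimension of both the tensor and the vector type doubles, and the innermost index is scaled by two so it points at the real part of the same element.

```python
def lower_complex_transfer(tensor_shape, vector_shape, indices):
    # Hypothetical helper: rewrite a transfer_read on complex<T> into one
    # on T. Innermost dims double (N -> 2N); innermost index doubles too.
    return (tensor_shape[:-1] + [2 * tensor_shape[-1]],
            vector_shape[:-1] + [2 * vector_shape[-1]],
            indices[:-1] + [2 * indices[-1]])

# tensor<16x16xcomplex<f32>> read as vector<8x1xcomplex<f32>> at (x=3, y=5)
t, v, i = lower_complex_transfer([16, 16], [8, 1], [3, 5])
assert t == [16, 32]   # tensor<16x32xf32>
assert v == [8, 2]     # vector<8x2xf32>
assert i == [3, 10]    # the column index is scaled by 2
```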

In your expert opinions (PSA: I am by no means an MLIR expert myself), would this be a good solution? The other option would be to extend memrefs so that they can hold complex values, but that seems strange to me: memrefs are fairly low-level, so a complex element type feels out of place there, right?

Anyway, I am very unsure of what the best move for all of this is (if there is one), so any suggestion/comment/help would be greatly appreciated…

Many thanks in advance.


I am not fully aware of the history on this topic, but I do not believe that the complex type was ever conceived as a "physical type" with a concrete realization; to be more precise, it has several possible physical representations that folks want in different circumstances.

In IREE, we don’t represent them at all, preferring to transform them at the boundaries and then decompose them internally. But this approach would also work with the layout you are looking for.

We’ve always meant to generalize this for upstream, but there are some interplays with frontend dialects and broadcasting that we never untangled.

When the above was written, the complex dialect didn’t exist. It would be nice if we had patterns+a pass to decompose operations in that dialect.

As with anything in IREE, if you see something general that would be good upstream, just ask. If we don’t have time, we’re happy to authorize it being contributed to the llvm project.

Vector is similar there. The vector dialect/type is meant for mapping directly to device types, so the expectation is that they map directly and without loss to HW, which complex doesn’t (as Stella also points out).

Having this transformation directly at the tensor level (before vectorization) makes the most sense to me. One can do the reshaping there, and then the remainder of the lowerings should work without needing to extend vectorization, bufferization or the like. The biggest issue would probably be the operations on complex that are actually complex (vs. complex just being a tuple); those would probably need dedicated ops that understand how you have mapped them. But I think that would be the most consistent approach and would run into the fewest issues.
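To illustrate the point about "truly complex" ops: once the type is rewritten to interleaved floats, something like a complex multiply mixes the two lanes of each pair, so it cannot be expressed as a plain elementwise float op. A Python sketch using the standard (a+bi)(c+di) expansion (the function name is hypothetical):

```python
def cmul_interleaved(a, b):
    # Multiply two 1-D buffers of interleaved (re, im) f32 pairs.
    # real = ar*br - ai*bi, imag = ar*bi + ai*br -- note that each
    # output lane reads from BOTH input lanes of a pair.
    ar, ai = a[0::2], a[1::2]
    br, bi = b[0::2], b[1::2]
    out = [0.0] * len(a)
    out[0::2] = [x * y - u * v for x, y, u, v in zip(ar, br, ai, bi)]
    out[1::2] = [x * v + u * y for x, y, u, v in zip(ar, br, ai, bi)]
    return out

# (1+2j) * (3+4j) = -5 + 10j
assert cmul_interleaved([1.0, 2.0], [3.0, 4.0]) == [-5.0, 10.0]
```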

Certainly this depends on what you mean by “HW”, and I would disagree, even if you restrict yourself to “common” hardware. What is lost by (for example) representing the 16-bit real and imaginary parts of a complex number that are placed within the same 32-bit register in a GPU warp?
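For instance, a Python sketch using the stdlib struct module's half-float format: two f16 halves round-trip losslessly through a single 32-bit word.

```python
import struct

# A complex<f16> value: real and imaginary halves packed into one
# 32-bit word, as they would sit in a single GPU register.
word_bytes = struct.pack("<2e", 1.5, -0.25)   # '<2e' = two little-endian f16
(word,) = struct.unpack("<I", word_bytes)     # the same bits as one u32

# Unpacking recovers both halves exactly: nothing is lost.
re, im = struct.unpack("<2e", struct.pack("<I", word))
assert (re, im) == (1.5, -0.25)
```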

(For what it’s worth, I think that the Vector dialect should support complex numbers.)

I think that from a purely “deep learning frameworks” perspective, complex neural networks are uncommon, and accelerator libraries (typically) don’t provide complex implementations of all functions. So the frameworks lower to real/imaginary components (or polar coordinates). That may or may not be optimal from a performance perspective, depending on the situation, but at least “it works.” If complex neural networks were to become commonplace, I would think there might be value in, say, retaining the ability to express and optimize complex activation functions at the MLIR level.

It seems to me that complex numbers are outside the scope of what most people are currently using MLIR for, as opposed to not mapping to the hardware, or because it is an aggregate type instead of a primitive (as was suggested here). But reasonable people can disagree :wink: .

I think this is a reasonable viewpoint: it just isn’t how the complex type is defined. If there is a reason for the lower levels to have a notion of specific complex types with concrete layouts, I would be supportive of defining those types (but the specifics of how to do so is a bit outside of my direct experience).

Thanks so much @stellaraccident , @jpienaar and @jfurtek for your replies. If I understand correctly, the best approach would be either to do as is done in IREE and have a pass that gets rid of complex numbers from the start, or to do it before vectorization, at the tensor level, right?

As mentioned by @jpienaar , my only concern is that by doing so I might lose the information that I am working on complex numbers: when post-vectorization I end up with, say, a vector.contract, the compiler would no longer know at that stage that it is a contraction on complex values.

EDIT: As for whether supporting complex in the vector dialect would be fit for purpose, I discussed this internally with @giuseros and his point is that vectors already don’t map directly to hardware, since you can have something like vector<256x8xf32>, which needs to be unrolled/legalized to map to hardware. We’re interested to know other people’s take on this!

Some related work with a previous attempt to add complex types to LLVM: Complex proposal v3 + roundtable agenda

I think the characteristic we are looking for is “primitive type with concrete in-memory representation” vs aggregate. The gating factor for adding them would be exactly what you are saying: the need to preserve the primitive type and supporting ops at the vector level so that it can be mapped to hardware.

I don’t believe we have a trait or anything that defines this characteristic (most of the primitive types preceded type interfaces and were effectively invented at the dawn of time, project-wise), but we possibly should if expanding it.

If there was a strong case to add, say, a complex_real_f16_imag_f16 as a primitive because a machine uses that, I could see putting that in the complex dialect and extending the complex dialect ops to consider it a valid complex type.

This is just one way to approach it. I’m pretty sure that a discussion would ensue if proposed, and that would prompt other options. This stuff is need-driven, and if there is a lower-level programming model that has latched on to some specific layouts, that may rise to the level of a new type.

Yes, the information loss is the tradeoff. As a general purpose compiler, it makes sense for something like IREE to support this as a fallback path, but if there was ever an advantage to preserving the complex types and ops to a lower level (to mate with hardware or a low level programming model), we would be looking at new types.