Should middle-end passes consider special types when doing optimization? Or should the back-end revert the optimization accordingly?

Hi,

We are developing prototypes for the Intel Advanced Matrix Extensions (AMX) [1] programming model in Clang and LLVM [2].

We have met several cases where the new type we added is optimized unexpectedly in the middle-end, e.g. optimizing phi + bitcast + load:

From

%a = load <256 x i32>, <256 x i32>* %mem, align 64
… …
%b = phi <256 x i32> [ %a, %label1 ], [ %someother, %label2 ]
%c = bitcast <256 x i32> %b to x86_amx

To

%a = bitcast <256 x i32>* %mem to x86_amx*
%b = load x86_amx, x86_amx* %a, align 64
… …
%c = phi x86_amx [ %b, %label1 ], [ %someother, %label2 ]

To prevent such unexpected transforms, we added type checks at each point in the optimizations.

Roman pointed out that the changes are not the right direction [3] and thought it is a bug for the backend. While we agree the backend might be able to handle it for functionality, we think it is better to handle it in the middle-end, since these are negative optimizations for AMX.

First, let me put some background here:

@lebedev.ri@gmail.com

Currently we see that “if (Ty.isVectorTy()) {…}” makes sense in the mid-end.

Why can’t “if (Ty.isX86_AMXTy()) {…}” make sense as well?

Just because more targets support VectorTy, while fewer targets (only x86) support AMXTy?

That logic does not make sense to me.

-xiang

I don’t know anything about AMX, but let me give you some pointers (no pun intended).

Regarding pointers, the direction LLVM is taking is to have just 2 pointer types: a data pointer type and a function pointer type. That’s it. That allows us to remove a lot of bitcasts between pointers. You can see that load instructions now carry the loaded type as an explicit argument; for now that is redundant with the pointee type, but it won’t be once typed pointers disappear.

So if you need a special pointer type that can’t be cast to other pointer types, the way to do it in LLVM is with a different address space. Then you can configure how many bits it takes, etc. And more importantly, pointers in that space can’t be cast to another space without using a special instruction (which LLVM optimizers won’t introduce).
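A minimal sketch of that approach (the address space number 7 below is an arbitrary, hypothetical choice): loads through the special-address-space pointer are ordinary IR, but a plain bitcast from a default-address-space pointer into it is not even valid IR; only an explicit addrspacecast can change the address space, and optimizers will not introduce one on their own.

; hypothetical: tile data lives in address space 7
define <256 x i32> @load_tile_data(<256 x i32> addrspace(7)* %p) {
entry:
  %v = load <256 x i32>, <256 x i32> addrspace(7)* %p, align 64
  ret <256 x i32> %v
}

; invalid IR -- bitcast cannot change the address space:
;   %q = bitcast <256 x i32>* %mem to <256 x i32> addrspace(7)*
; only an explicit addrspacecast can do that, and optimizers won't add it.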

FYI by using a different address space, you may lose a few optimizations, because optimizers assume nothing about non-default address spaces. We have discussed an API to let folks express assumptions optimizers could make (e.g., is null == (void*)0?), but nothing has been implemented so far.

Nuno

Hi,

We are developing prototypes for the Intel Advanced Matrix Extensions (AMX) [1] programming model in Clang and LLVM [2].
We have met several cases where the new type we added is optimized unexpectedly in the middle-end, e.g. optimizing phi + bitcast + load:

From
%a = load <256 x i32>, <256 x i32>* %mem, align 64
… …
%b = phi <256 x i32> [ %a, %label1 ], [%someother, %label2]
%c = bitcast <256 x i32> %b to x86_amx
To
%a = bitcast <256 x i32>* %mem to x86_amx*
%b = load x86_amx, x86_amx* %a, align 64
… …
%c = phi x86_amx [ %b, %label1 ], [%someother, %label2]

To prevent such unexpected transforms, we added type checks at each point in the optimizations.
Roman pointed out that the changes are not the right direction [3] and thought it is a bug for the backend. While we agree the backend might be able to handle it for functionality, we think it is better to handle it in the middle-end, since these are negative optimizations for AMX.

First, let me put some background here:
x86_amx* is different from trivial pointers.
The AMX load instruction is very different from other load instructions. It needs not only the memory address but also the shape / stride of the tile register. We did some extra work in the backend to deduce the shape information from the context. We don’t want passes to add new x86_amx-related uses, because that makes the deduction difficult. That said, bitcasting other pointer types to x86_amx* is not as trivial as assumed here.
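(For reference, the intrinsic form of such a tile load, which also appears later in this thread, looks roughly like this; the operand names are illustrative:)

; shape (row/col) and stride accompany the address; a plain load carries none of these
%tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %addr, i64 %stride)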

The problem appears to be that this difference is not modeled or specified in LLVM IR AFAICT. The current LangRef does not appear to specify `x86_amx` to start with. If pointers to `x86_amx` have different semantics than regular LLVM pointer types, using regular LLVM pointer types for pointers to `x86_amx` may not be appropriate. I’ve not followed the previous AMX discussions closely, but it sounds like it may be good to reconsider how x86_amx pointers are modeled in LLVM IR.

Also note that `bitcast` is specified as a `no-op` (https://llvm.org/docs/LangRef.html#id293) (except for pointers with different address spaces), but from what you mentioned above this does not match the semantics of `x86_amx*`. It sounds like this is the underlying problem that should be addressed, because trying to update various middle-end optimizations to try to enforce the special semantics does not seem to be a scalable solution.

As Nuno mentioned, you could try and use a separate address space for `x86_amx` pointers to avoid pointer optimizations.

The physical tile registers have more limitations.
There is no copy instruction between tile registers.
Spilling / reloading a tile register is expensive, given its size is 1024 bytes.
The shapes of tile registers need to be pre-configured before use, and all data in tile registers becomes invalid once they are re-configured. That means we want a single configure instruction to dominate as many tile register uses as possible; otherwise we need to spill and reload the live tile registers whenever we have to re-configure.
The number of tile registers is rather small (only 8), and different shapes cannot be reused.
Based on these limitations, we need to reduce the uses / live ranges of tile registers, but optimizations may increase them. So even though we can handle some combined operations for the AMX type, we still prefer to prevent them from the beginning, unless we can completely roll back the optimization, which is also not a good solution in my opinion.
For more information, please refer to the discussion in [3].
For other optimization points, please refer to [4][5].

I think the main controversy raised by Roman is whether middle-end passes should consider special types when doing optimization. I tend to let the middle-end do the type check on account of the peculiarity of the AMX type, but I’m not sure whether there is precedent for handling a similar issue in other targets. I’m open and glad to do it either way so long as we have an elegant solution.
Any suggestions are welcome.

IIUC the main problem is not whether middle-end passes perform or do not perform optimizations based on certain types. To me it sounds like the actual problem is that pointers to `x86_amx` do not behave like regular LLVM IR pointers, and you are trying to enforce extra restrictions on bitcasts.

Cheers,
Florian

Thanks Florian. I agree with you that pointers to x86_amx have different semantics than regular LLVM pointer types. First, an x86_amx pointer points to a 2D array of a big matrix. The data of each row is contiguous, but data in consecutive rows is not contiguous in memory. (A picture illustrating the x86_amx load semantics was attached.) We need another operand, the stride, to describe the distance between rows. So the semantics of “load <256 x i32>” and “load x86_amx” are different, because “load <256 x i32>” assumes the memory is contiguous and loads a flat vector.
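A small sketch of that difference in IR terms (the operand names below are just placeholders): the flat vector load reads 1024 contiguous bytes, while the tile load carries a shape and a row stride.

; contiguous load of a flat <256 x i32> vector (1024 bytes starting at %mem):
%v = load <256 x i32>, <256 x i32>* %mem, align 64

; strided tile load: a %row x %col tile whose rows are %stride apart in memory:
%t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %base, i64 %stride)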

You also mention that there is no documentation of x86_amx in the LangRef. I’d like to add x86_amx to the document. Is there a process for documenting a type?

Thanks

Yuanke

Err…are you saying this is the expected semantics of a “load x86_amx” operation today? That doesn’t make much sense…Surely a strided-load operation should be spelled llvm.matrix.column.major.load in the IR, not load?

I mean transforming from “load <256 x i32>” to “load x86_amx” is not valid, because x86_amx represents a tile and “load x86_amx*” doesn’t express its semantics without a stride. Now it looks to me like “load x86_amx*” is invalid.

Since the x86_amx type has a fixed size of 1024 bytes, I would expect %v = load x86_amx, x86_amx* %ptr to load 1024 bytes of contiguous memory starting at %ptr – I don’t see why this should be invalid?

But x86_amx represents a tile. The semantics of the hardware instruction tileloadd are something like ‘llvm.matrix.row.major.load’. How do we lower %v = load x86_amx, x86_amx* %ptr to tileloadd?

Yes. Since all operations on AMX data can only use AMX instructions, we use the x86_amx type in the mid-end/back-end to separate them from normal LLVM IR instructions.

So let me come back to the beginning:

I think it is OK to use the “x86_amx type” in the mid-end.


Why is that harder than lowering a load <256 x i32> and then bitcast to x86_amx?

E.g., I see there is in llvm/lib/Target/X86/X86LowerAMXType.cpp a transform from:

%src = load <256 x i32>, <256 x i32>* %addr, align 64
%2 = bitcast <256 x i32> %src to x86_amx

into:

%2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %addr, i64 %stride64)

Isn’t it equivalent, then, to transform:

%2 = load x86_amx, x86_amx* %addr, align 64

into:

%2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %addr, i64 %stride64)

Hi James,

Thank you for taking the time to dive deep into the issue. It is very constructive. I agree we can transform “load x86_amx*” to the AMX load intrinsic, but it seems we need more effort to do that transform than to prevent generating “load x86_amx*” in the first place. I can support transforming “load x86_amx*” to the AMX load intrinsic if people like this approach.

I also think Florian raised a good question: what are the semantics of “load x86_amx*”? Are they different from loads through regular LLVM pointer types? What’s your opinion on it?

Thanks

Yuanke

Yes, that is equivalent, but in the front end we don’t have an existing type to express the AMX type.

The “AMX type” in the C/C++ language is implied by the following structure:

typedef int tile1024i __attribute__((vector_size(1024), aligned(64)));

typedef struct __tile1024i_str {
  const unsigned short row;
  const unsigned short col;
  tile1024i tile;
} __tile1024i;

So we handle

%src = load <256 x i32>, <256 x i32>* %addr, align 64
%2 = bitcast <256 x i32> %src to x86_amx

not

%2 = load x86_amx, x86_amx* %addr, align 64


I wrote a patch (https://reviews.llvm.org/D93788) to transform load/store of x86_amx* into AMX intrinsics. The effort is much larger than disabling the bitcast from load/store of <256 x i32>* to load/store of x86_amx*.
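Roughly, the rewrite that patch has to perform looks like this (the %row, %col and %stride operands are placeholders; recovering them from the surrounding code is exactly the extra shape-deduction effort mentioned above):

; before:
%v = load x86_amx, x86_amx* %addr, align 64

; after (shape and stride must be deduced from context):
%p = bitcast x86_amx* %addr to i8*
%v = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %p, i64 %stride)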

I also think the pointee type shouldn’t matter; my impression was that ty* and ty’* should be treated equivalently and bitcasting between these should not have any side effects.

But when it is used by a load, which receives a type for interpreting the loaded value, I don’t think it’s safe in general to convert a load of ty into a load of ty’ with the same bit width.
A relevant bug in gcc: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=58416 , the transformation is also happening in LLVM: https://bugs.llvm.org/show_bug.cgi?id=45152
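To make that concrete, the transformation those reports are about has roughly this shape (the i32/float types are just an illustrative stand-in):

; before: load as one type, reinterpret the bits afterwards
%i = load i32, i32* %p, align 4
%f = bitcast i32 %i to float

; after: load directly as the other same-width type
%q = bitcast i32* %p to float*
%f = load float, float* %q, align 4

Both read the same four bytes, but the value is now produced by a load of a different type, which the reports above show is not always behavior-preserving.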

From the points earlier, it sounds like you’d need to change `load` semantics for `x86_amx` to load blocks of data with gaps in between them? I am not sure that’s a good idea, as there are plenty of places in LLVM that make use of the contiguity assumption, I think (e.g. the code reasoning about memory locations). I’d expect lots of places would need updating, and until everything is updated there will be plenty of places that get this subtly wrong. This doesn’t sound scalable.

Cheers,
Florian

Are the bitcasts introduced by the frontend? If you need different semantics for loading from an `x86_amx` pointer, could the frontend generate a call to an intrinsic instead?

Cheers,
Florian

Yes, the bitcasts are introduced by the frontend to call AMX intrinsics. We use a vector to represent a 2D AMX tile in the C language; on the other hand, we don’t want to mix our AMX tiles with other vector operations, so x86_amx was introduced to isolate AMX intrinsics from normal vector operations. The bitcast indicates that a normal vector is passed to an AMX intrinsic. In the example below, we need to transform the bitcast into a vector store plus an AMX load intrinsic. The x86_amx* is unexpected at the beginning, but the InstCombine pass in the middle-end generates the x86_amx pointer.

define dso_local void @test_src_add(<256 x i32> %x, <256 x i32> %y, i16 %r, i16 %c, i8* %buf, i64 %s) {
; CHECK-LABEL: @test_src_add(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
; CHECK-NEXT:    [[ADD:%.*]] = add <256 x i32> [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
; CHECK-NEXT:    store <256 x i32> [[ADD]], <256 x i32>* [[TMP0]], align 1024
; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[TMP1]], i64 64)
; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP2]])
; CHECK-NEXT:    ret void
;
entry:
  %add = add <256 x i32> %y, %x
  %t = bitcast <256 x i32> %add to x86_amx
  call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %buf, i64 %s, x86_amx %t)
  ret void
}

Thanks

Yuanke

Hi Florian,

Sorry, I didn’t understand your question.

If we can’t prevent load x86_amx* from being generated, we need to transform load x86_amx* into llvm.x86.tileloadd64.internal() with shape propagation. The load/store instructions are generated by the front-end because in the C language we define our tile type as “typedef int tile1024i __attribute__((vector_size(1024), aligned(64)));”, which is a vector type. Actually, we hope all operations on AMX tiles go through AMX intrinsics, and that data exchange with other operations goes through memory. But the front-end generates load/store <256 x i32>* instructions instead of llvm.x86.tileloadd64.internal() or llvm.x86.tilestored64.internal().

What assumption does LLVM make use of?

Thanks

Yuanke

Ok I think I understand the issue better now. IIUC you use `bitcast` in the frontend to convert between regular vector and the AMX values?

This doesn’t really match the way `bitcast` is defined (as discussed earlier) and this mismatch seems to be the source of the issues. I don’t think you should use `bitcast`s that way and instead adjust the frontend to emit different code for the conversion between vector and amx values (e.g. use an intrinsic to convert between vector and amx values; the intrinsic can be directly lowered to the conversion code).

I think there are at least two ways forward:

1. Avoid using bitcasts for the conversion in the frontend.
2. Try & define the semantics of bitcast/load for AMX types, such that the transformations you want to exclude in instcombine are illegal.

If you decide to go with 2., you probably will have to make a convincing argument why this is the right thing to do and why other alternatives do not work, because it means that certain general transformations that are legal at the moment become illegal for certain types (which is illustrated by the instcombine patches you mentioned).
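As a minimal sketch of option 1 (the conversion intrinsic name below is hypothetical and would have to be added as a new target intrinsic), the frontend would emit a conversion call instead of a bitcast, leaving nothing for instcombine to fold through:

declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>)
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)

define void @test_src_add(<256 x i32> %x, <256 x i32> %y, i16 %r, i16 %c, i8* %buf, i64 %s) {
entry:
  %add = add <256 x i32> %y, %x
  ; conversion via an intrinsic instead of "bitcast <256 x i32> %add to x86_amx"
  %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %add)
  call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %buf, i64 %s, x86_amx %t)
  ret void
}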

Cheers.
Florian