(1) The AMX dialect in MLIR itself is very concise: it consists of just a few new vector operations.
All of these operations work on concepts already familiar within the MLIR framework: 2-D vectors and memrefs. For example, clearing a tile within an enveloping buffer can be done as follows.
The idiomatic TMUL operations look as follows.
The dialect verifies that the types and shapes are actually supported by AMX.
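The following minimal Python sketch illustrates the kind of checks involved. It assumes the AMX palette 1 limits (at most 16 rows of at most 64 bytes per tile) and the pairwise bf16 operand layout used by `tdpbf16ps`. The functions `fits_tile` and `verify_mulf` are hypothetical. This is not the dialect's actual verifier, which operates on MLIR vector types.

```python
# Hypothetical sketch of AMX shape checks; NOT the actual MLIR verifier.
MAX_ROWS = 16        # assumed palette 1 limit: at most 16 rows per tile
MAX_ROW_BYTES = 64   # assumed palette 1 limit: at most 64 bytes per row

ELEM_BYTES = {"bf16": 2, "f32": 4, "i8": 1, "i32": 4}


def fits_tile(shape, elem):
    """Return True if a 2-D vector shape fits into a single AMX tile."""
    rows, cols = shape
    return 0 < rows <= MAX_ROWS and 0 < cols * ELEM_BYTES[elem] <= MAX_ROW_BYTES


def verify_mulf(a_shape, b_shape, c_shape):
    """Check a bf16 x bf16 -> f32 tile multiply.

    A is M x 2K bf16, B is K x 2N bf16 (N pairs per row), C is M x N f32.
    """
    if not (fits_tile(a_shape, "bf16") and fits_tile(b_shape, "bf16")
            and fits_tile(c_shape, "f32")):
        return False
    m, k2 = a_shape
    k, n2 = b_shape
    mc, n = c_shape
    # bf16 values are consumed in adjacent pairs, so A has twice as many
    # columns as B has rows, and B has twice as many columns as C.
    return m == mc and k2 == 2 * k and n2 == 2 * n
```

For instance, `verify_mulf((16, 32), (16, 32), (16, 16))` accepts a full-size bf16 multiply under these assumptions, while widening the accumulator to 32 f32 columns is rejected because each row would need 128 bytes.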
(2) The AMX dialect is lowered into an LLVMAMX dialect, which is closer to the compiler-oriented “internal” intrinsics of LLVM IR. The lowering takes care of some tedious details, such as providing tile parameters, computing strides, and selecting instructions. For example, “amx.multf” is lowered to the ‘tdpbf16ps’ intrinsic, and “amx.multi” to one of the ‘tdpbssd’, ‘tdpbsud’, ‘tdpbusd’, or ‘tdpbuud’ intrinsics, depending on the signedness of its operands. All of this is completely transparent to the higher levels, though.
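The following tiny Python sketch shows how instruction selection for integer multiplication works. It relies only on the naming convention of the four intrinsics: the two letters after `tdpb` give the signedness of the left-hand (A) and right-hand (B) i8 operands. `select_int_intrinsic` is a hypothetical helper and is not part of the actual lowering code.

```python
def select_int_intrinsic(lhs_signed: bool, rhs_signed: bool) -> str:
    """Pick the AMX i8 dot-product instruction (hypothetical helper).

    's' marks a signed operand and 'u' an unsigned one; the first letter
    refers to the left-hand operand, the second to the right-hand operand.
    The trailing 'd' stands for dword (i32) accumulation.
    """
    a = "s" if lhs_signed else "u"
    b = "s" if rhs_signed else "u"
    return f"tdpb{a}{b}d"
```

For example, a signed left-hand operand combined with an unsigned right-hand operand selects `tdpbsud`.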
The example from step (1) lowers to the following.
(3) This dialect is further lowered into the LLVM IR dialect. At this point, the x86_amx type comes into play.
(4) The LLVM IR dialect is eventually handed off to LLVM. The backend converts the intrinsics into instructions and adds the necessary tile configuration instructions.
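The `ldtilecfg` instruction loads the tile configuration from a 64-byte memory block. The Python sketch below builds such a block. It assumes the layout documented for palette 1: byte 0 holds the palette id, byte 1 the start row, bytes 16 to 47 the bytes-per-row of each tile as 16-bit values, and bytes 48 to 63 the row count of each tile. The helpers `tile_params` and `tile_config` are hypothetical. They only illustrate what the backend derives from the vector shapes.

```python
import struct


def tile_params(rows, cols, elem_bytes):
    """Map a rows x cols vector with elem_bytes-sized elements to (rows, colsb)."""
    return rows, cols * elem_bytes


def tile_config(tiles, palette_id=1):
    """Build a 64-byte ldtilecfg memory operand (hypothetical helper).

    tiles: list of (rows, colsb) pairs for tmm0, tmm1, ...
    Unused tiles are left at zero rows and zero bytes per row.
    """
    if len(tiles) > 8:
        raise ValueError("palette 1 provides only 8 tiles")
    colsb = [0] * 16
    rows = [0] * 16
    for i, (r, cb) in enumerate(tiles):
        if not (0 < r <= 16 and 0 < cb <= 64):
            raise ValueError(f"tile {i}: {r} rows x {cb} bytes exceeds 16 x 64")
        rows[i], colsb[i] = r, cb
    # Byte 0: palette id, byte 1: start row, bytes 2..15: reserved (zero).
    header = struct.pack("<BB14x", palette_id, 0)
    # Bytes 16..47: colsb per tile (little-endian u16), bytes 48..63: rows (u8).
    return header + struct.pack("<16H", *colsb) + struct.pack("<16B", *rows)


# C (16x16 f32), A (16x32 bf16), and B (16x32 bf16) each use a full 16 x 64 B tile.
cfg = tile_config([tile_params(16, 16, 4), tile_params(16, 32, 2), tile_params(16, 32, 2)])
```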
A sample configuration setup and the actual kernel instructions are illustrated below.
I added several integration tests to make sure everything works as expected (on the emulator).
For example, storing a 16x16 vector into a 19x19 buffer correctly accounts for the stride of the enveloping buffer:
( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
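The stride handling behind this output is a simple address computation. This minimal Python sketch assumes a flat row-major buffer and uses a hypothetical `store_tile` helper. It stores a 16x16 zero tile at offset (1, 1) of a 19x19 buffer of ones and prints the same pattern as above.

```python
def store_tile(buf, buf_cols, tile, row0, col0):
    """Store a 2-D tile into a flat row-major buffer that is buf_cols wide.

    Tile element (i, j) lands at linear offset (row0 + i) * buf_cols + (col0 + j),
    so the row stride is the width of the enveloping buffer, not of the tile.
    """
    for i, row in enumerate(tile):
        base = (row0 + i) * buf_cols + col0
        buf[base:base + len(row)] = row


N = 19
buf = [1] * (N * N)                        # 19x19 buffer filled with ones
zero_tile = [[0] * 16 for _ in range(16)]  # a cleared 16x16 tile
store_tile(buf, N, zero_tile, 1, 1)
for r in range(N):
    print("( " + ", ".join(str(x) for x in buf[r * N:(r + 1) * N]) + " )")
```

Using the tile width (16) instead of the buffer width (19) as the stride would shift every row after the first, which is exactly the kind of error the integration test guards against.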
Likewise, a matrix multiplication of pairwise bf16 elements accumulated into f32 works as expected.
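The semantics of this pairwise bf16 multiplication can be written as a reference loop. This sketch follows the operand layout of `tdpbf16ps`: A is M x 2K, B holds K rows of N bf16 pairs, and C is M x N. It uses Python floats for all values, so bf16 and f32 rounding is deliberately not modeled. `tdpbf16ps_ref` and `vnni_unpack` are hypothetical names.

```python
def tdpbf16ps_ref(a, b, c):
    """Reference semantics of a pairwise bf16 dot product into f32.

    a: M x 2K values, consumed in adjacent pairs along each row.
    b: K x 2N values, each row holding N adjacent pairs.
    c: M x N accumulator, updated in place and returned.
    Python floats stand in for bf16/f32, so rounding is not modeled.
    """
    m, k2 = len(a), len(a[0])
    k, n2 = len(b), len(b[0])
    n = len(c[0])
    if k2 != 2 * k or n2 != 2 * n or m != len(c):
        raise ValueError("incompatible operand shapes")
    for i in range(m):
        for j in range(n):
            acc = c[i][j]
            for p in range(k):
                acc += a[i][2 * p] * b[p][2 * j] + a[i][2 * p + 1] * b[p][2 * j + 1]
            c[i][j] = acc
    return c


def vnni_unpack(b):
    """Turn a K x 2N pair-packed matrix back into a plain 2K x N matrix."""
    k, n2 = len(b), len(b[0])
    return [[b[r // 2][2 * j + r % 2] for j in range(n2 // 2)] for r in range(2 * k)]
```

`vnni_unpack` shows that the pair-packed B is an ordinary 2K x N matrix in a different layout. The result therefore equals a plain matrix product of A with the unpacked B.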
Please let me know if you have any feedback.