[RFC] Implementing a matrix multiply & vector operation fusion optimization

Perhaps the easiest way to explain the goal of this discussion is by providing a motivating example:

// We have three matrices, A, B, and C.

// At a higher level of abstraction, the pseudocode would look like this:
A_scaled = eltwise_mul(A, 0.5f);
C = matmul(A_scaled, B);
C = relu(C);

With the advent of AI, there are many architectures dedicated to accelerating dense matrix multiplication, such as Intel AMX, Arm SME, Huawei Ascend, and many other accelerators. They often come with a vector ISA attached as well, so it is frequently better to fuse the elementwise operations on the inputs and outputs of the matrix multiply (pre- and post-processing) into the same loop nest:

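// Innermost loops (general idea):
for (tiles of C...)
	%acc = 0.0 : f32
	for (tiled_k...)
		%a = load %A[offset_a(tiled_k...)] : vector<8xf32>
		// pre-processing, fused into the k-loop (%half is a splat of 0.5)
		%scaled_a = arith.mulf %a, %half : vector<8xf32>
		%b = load %B[offset_b(tiled_k...)] : vector<8xf32>
		// dot (or outer product, or matrix multiply, etc.) accumulated over k
		%acc = %acc + dot(%scaled_a, %b) : f32
	// post-processing, applied once the reduction over k is complete
	%result = relu(%acc) : f32
	store %result, %C[offset_c(...)]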
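
The loop nest above can be sanity-checked with a minimal pure-Python sketch. It assumes plain lists of lists for A (M×K) and B (K×N), a scale factor of 0.5, and a hypothetical k-tile size TK standing in for the vector<8xf32> width; the function names are illustrative only. The unfused version materializes every intermediate, while the fused version scales each A tile as it is loaded and applies ReLU only after the full reduction over k.

```python
# Sketch only: lists stand in for memrefs; TK is a hypothetical k-tile size.
def unfused(A, B, scale=0.5):
    """Reference: three separate passes with materialized intermediates."""
    M, K, N = len(A), len(B), len(B[0])
    A_scaled = [[scale * A[i][k] for k in range(K)] for i in range(M)]
    C = [[sum(A_scaled[i][k] * B[k][j] for k in range(K)) for j in range(N)]
         for i in range(M)]
    return [[max(c, 0.0) for c in row] for row in C]


def fused(A, B, scale=0.5, TK=4):
    """Fused loop nest: scale each A tile as it is loaded, ReLU once per output."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k0 in range(0, K, TK):                    # tiled_k loop
                a_tile = A[i][k0:k0 + TK]                 # load %A tile
                scaled = [scale * a for a in a_tile]      # pre-processing, fused
                b_tile = [B[k][j] for k in range(k0, min(k0 + TK, K))]  # load %B tile
                acc += sum(x * y for x, y in zip(scaled, b_tile))       # partial dot
            C[i][j] = max(acc, 0.0)                       # post-processing after reduction
    return C


A = [[1.0, -2.0, 3.0, -4.0, 5.0], [-1.0, 2.0, -3.0, 4.0, -5.0]]
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
print(fused(A, B, TK=2))  # [[0.0, 3.75], [0.75, 0.0]]
```

The placement of the ReLU matters: for the first output above, the TK=2 partial sums are 0.5, -2.5 and 1.25, so relu(0.5 - 2.5 + 1.25) = 0.0, whereas clamping each partial sum inside the k-loop would give 1.75.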

The difficult part with the existing infrastructure is that once the operations are fused and lowered into (affine) loops, it becomes very hard to recognize what kind of operation the loop nest implements, and therefore to vectorize it using both the matrix and the vector accelerators. We think it would make sense to have an intermediate step/operation in which a matrix multiply op can be grouped with its elementwise pre- and post-processing operations.

Our initial idea is to express this in the linalg dialect. We have not settled on a specific syntax yet, but it would be something similar to this:

linalg(?).fused_matmul ins(%A, %B) outs(%C) {
	// pre-processing
	linalg.generic(%A){...}
	linalg.generic(%B){...}
	// post-processing
	linalg.generic(%C){...}
}

Hence this RFC: we want to start a discussion with the community around this topic. How can we enforce that the internal operations are elementwise? Do we want to restrict ourselves to elementwise operations for pre- and post-processing? Is this easier using linalg.generic? Or would some other interface be more appropriate?
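
One way to approach the first question is a purely structural check on each pre/post-processing op. The sketch below is a simplified Python model, not an MLIR API: GenericOpModel and is_elementwise are hypothetical names, and an indexing map is written as the tuple of loop dimensions it uses, so (d0, d1) -> (d1, d0) becomes (1, 0). Under these assumptions an op counts as elementwise when every iterator is parallel, every input map is a projected permutation (which still allows broadcasts such as a bias vector), and every output map is a full permutation. A real check would also have to inspect the region body, which this model ignores.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class GenericOpModel:
    """Hypothetical, simplified stand-in for a linalg.generic's structure."""
    iterator_types: List[str]            # e.g. ["parallel", "parallel"]
    input_maps: List[Tuple[int, ...]]    # loop dims each input is indexed by
    output_maps: List[Tuple[int, ...]]   # loop dims each output is indexed by


def is_elementwise(op: GenericOpModel) -> bool:
    n = len(op.iterator_types)
    # 1. No reductions: every loop must be parallel.
    if any(t != "parallel" for t in op.iterator_types):
        return False
    # 2. Inputs may drop or permute dims (broadcast/transpose) but not repeat them.
    for m in op.input_maps:
        if len(set(m)) != len(m) or any(d < 0 or d >= n for d in m):
            return False
    # 3. Each output must use every loop dim exactly once (a full permutation),
    #    so every iteration writes a distinct output element.
    for m in op.output_maps:
        if sorted(m) != list(range(n)):
            return False
    return True


scale = GenericOpModel(["parallel", "parallel"], [(0, 1)], [(0, 1)])
matmul = GenericOpModel(["parallel", "parallel", "reduction"], [(0, 2), (2, 1)], [(0, 1)])
print(is_elementwise(scale), is_elementwise(matmul))  # True False
```

Accepting projected permutations on the inputs is a design choice: it admits common post-processing such as a broadcast bias add, while the output rule still rejects anything that reduces or writes the same element twice.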

We can probably leave the detailed analysis and the lowering procedure for a later discussion; here we just want feedback on the initial idea to see whether we are starting off on the right foot. Thanks.

Just use three linalg ops along with the tiling and fusion transformations: they preserve the high-level nature of the linalg ops after fusion into a containing scf.for or scf.foreach_thread.
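
The shape of the result of tile-and-fuse can be illustrated with a pure-Python sketch. It only mimics the structure of the transformed IR and calls no MLIR API; scale_op, matmul_op, relu_op, tiled_and_fused and the tile sizes TM/TN are hypothetical names chosen for illustration. The outer loops play the role of the containing scf.for or scf.foreach_thread, and inside each iteration the three ops stay whole ops applied to slices, which is what keeps the matmul recognizable after fusion.

```python
# Each "op" works on whole (sub)matrices, like a linalg op on a tensor slice.
def scale_op(X, s):
    return [[s * x for x in row] for row in X]


def matmul_op(X, Y):
    K, N = len(Y), len(Y[0])
    return [[sum(row[k] * Y[k][j] for k in range(K)) for j in range(N)] for row in X]


def relu_op(X):
    return [[max(x, 0.0) for x in row] for row in X]


def tiled_and_fused(A, B, s=0.5, TM=1, TN=1):
    M, N = len(A), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    # Outer loops over output tiles: analogue of the containing scf.for / scf.foreach_thread.
    for i0 in range(0, M, TM):
        for j0 in range(0, N, TN):
            a_slice = A[i0:i0 + TM]                    # rows of A this tile needs
            b_slice = [row[j0:j0 + TN] for row in B]   # columns of B this tile needs
            # The three ops, fused into the tile: producers computed on slices only.
            tile = relu_op(matmul_op(scale_op(a_slice, s), b_slice))
            for di, row in enumerate(tile):            # insert the tile into C
                C[i0 + di][j0:j0 + len(row)] = row
    return C


A = [[1.0, -2.0, 3.0, -4.0, 5.0], [-1.0, 2.0, -3.0, 4.0, -5.0]]
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
print(tiled_and_fused(A, B) == relu_op(matmul_op(scale_op(A, 0.5), B)))  # True
```

One trade-off is visible here: the fused scaling of A is recomputed for every column tile that reads the same rows, which is the usual cost of fusing a producer into a tiled consumer.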

This part of the system is designed precisely for the use case you describe.

I can elaborate more next week if needed.

Please do!