We’re currently looking into linalg.batch_reduce_matmul that can have transposed shapes (i.e. a non-leading reduction dimension), due to how shapes are packed differently for weights and inputs.
My first attempt was to add an attribute specifying which dimension is the batch dimension. However, after looking at other examples in opdsl, I realised it could instead be a transpose argument, which would also remove the need for the other matmul variations.
So, here’s the simplification plan:
Merge matmul and matmul_unsigned
It seems the current matmul already takes the cast type as an attribute, so front-ends can lower an unsigned matmul simply by setting it to cast_unsigned.
Why do we have the unsigned version?
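For reference, this is roughly what the current opdsl definition looks like (a sketch from memory of core_named_ops.py, so details may differ slightly); the point is that cast is already an attribute with a signed default:

```python
@linalg_structured_op
def matmul(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    domain(D.m, D.n, D.k)
    implements(ContractionOpInterface)
    # Building this op with cast=TypeFn.cast_unsigned reproduces matmul_unsigned.
    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```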
Add transpose attribute to matmul and remove matmul_transpose_(a|b)
We could create a new FnAttrDef for other Linalg ops (like transpose), defaulting to Identity, and then add attributes for transposing A and/or B.
Implementation would be something like:

```python
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
t_a=PermutationAttrDef(default=[0, 1]),
t_b=PermutationAttrDef(default=[0, 1]),
...
domain(D.m, D.n, D.k)
implements(ContractionOpInterface)
# Note `cast` and `t_(a|b)` are defined above
C[D.m, D.n] += cast(U, t_a(A[D.m, D.k])) * cast(U, t_b(B[D.k, D.n]))
```
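For comparison, this is roughly the existing transposed variant that the attribute would subsume (again a sketch from memory of core_named_ops.py; details may differ):

```python
# Existing variant, which becomes redundant under the proposal:
@linalg_structured_op
def matmul_transpose_a(
    A=TensorDef(T1, S.K, S.M),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    domain(D.m, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])

# With the proposal, the same computation is plain matmul with t_a = [1, 0]
# (and matmul_transpose_b is plain matmul with t_b = [1, 0]).
```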
Moving on to batch_reduce_matmul
If the above is reasonable, then we can also add transpose support to batch_reduce_matmul and support shapes like (m, b, k) x (b, k, n) -> (m, n) in addition to the current (b, m, k) x (b, k, n) -> (m, n), by transposing A as [1, 0, 2].
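A minimal sketch of what that could look like, assuming batch_reduce_matmul also gains the same cast attribute as matmul and the hypothetical PermutationAttrDef from above:

```python
@linalg_structured_op
def batch_reduce_matmul(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
    t_a=PermutationAttrDef(default=[0, 1, 2]),
    t_b=PermutationAttrDef(default=[0, 1, 2]),
):
    domain(D.b, D.m, D.n, D.k)
    implements(ContractionOpInterface)
    # t_a = [1, 0, 2] would read A as (m, b, k) instead of the default (b, m, k).
    C[D.m, D.n] += cast(U, t_a(A[D.b, D.m, D.k])) * cast(U, t_b(B[D.b, D.k, D.n]))
```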
Rationale
On CPUs, for TLB-friendly access, we pack the weight matrix with a block transpose ([M][N] -> [BM][bm][BN][bn] -> [BM][BN][bm][bn]), but then we also have to pack the input and unpack the result to return the output.
If we could have a more flexible contraction, we could avoid packing other tensors at run time.
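To illustrate the block transpose above, here is a small NumPy sketch (the tile sizes bm and bn and the helper name are just for illustration):

```python
import numpy as np

def pack_weights(W, bm, bn):
    """Block-transpose packing: [M][N] -> [BM][bm][BN][bn] -> [BM][BN][bm][bn]."""
    M, N = W.shape
    BM, BN = M // bm, N // bn  # assumes M and N are multiples of bm and bn
    packed = W.reshape(BM, bm, BN, bn).transpose(0, 2, 1, 3)
    return np.ascontiguousarray(packed)
```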
Alternatives
New op
We could add a new linalg.batch_reduce_matmul_transpose_a, but that would be going in the wrong direction: the point of this proposal is to reduce the number of near-duplicate named variations, not add more.
Operand specification
Our proposal last year was to not rely on attributes but to have linalg operations annotate each other. That is still my long-term goal, but for now I’m content to de-duplicate the implementations of the existing matmul operations and get the batched version working.
Furthermore, to have a true DAG representation, we’d need a DAG-structured matcher API to replace the current practice of matching the named op variations and checking their attributes, which is a non-trivial effort.
We want to get there, but via a route that works before the full solution is in place.
Next steps
If this is a good direction, then we should discuss how we implement the transpose in opdsl, but in parallel, we can also start de-duplicating the existing matmul operations.
@ftynse @nicolasvasilache @MaheshRavishankar @mehdi_amini @asiemien @banach-space