When trying to remove the transpose_(a|b) variants from matmul and batch_matmul, I have hit a snag on the indexing. I really don’t want to create a whole new DSL just for this case. I think we can still get there and remove those ops (by replacing them with identical representations from their parent ops) before we do that.
Problem description
So, here’s the rub: TensorDef creates dims as named params, and IndexAttrDef creates indexes as named params, but they do not represent the actual value even when constant (kinda like SSA), because what that DSL function is doing is building the function, so it needs to be described as a computation.
This is an example of what I’m trying.
def matmul(...
    map_a=IndexAttrDef(S.A0, S.A1, default=[0, 1]),
    ...):
    ...
    dims_a = (D.m, D.k)
    # Fails with: tuple indices must be integers or slices, not SymbolDef
    l_a = (dims_a[S.A0], dims_a[S.A1])
    ...
I can’t use something like int(S.A0) because that’s just the name of the dimension (i.e. the literal "Symbol(A0)"). I need to create an expression that extracts a tuple index from a symbol definition.
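The failure can be reproduced outside OpDSL with a stand-in symbol class. This is only a minimal sketch: FakeSymbol and try_index are hypothetical names, not the real SymbolDef, and the only assumption is that a symbol is a named object with no integer value.

```python
class FakeSymbol:
    """Hypothetical stand-in for a DSL symbol: it has a name, not a value."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"Symbol({self.name})"


def try_index(dims, index):
    """Index a tuple and return the error text instead of raising."""
    try:
        return dims[index]
    except TypeError as err:
        return f"TypeError: {err}"


dims_a = ("m", "k")
a0 = FakeSymbol("A0")

# A plain integer index works; a symbol object is rejected by the tuple.
ok = try_index(dims_a, 0)
failed = try_index(dims_a, a0)
print(ok)
print(failed)
```

The symbol never carries the default value 0 at this point, so there is nothing for the tuple to index with.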
Solution idea: Constant list attribute
I’m not well versed in OpDSL (or Python much), but the actual code isn’t that extensive. This should be something like ScalarDef but as a list, and the generator should be able to interpret it as (generate code for) picking a different SymbolDef out of the tuple and updating the equation accordingly.
I’m trying to avoid having to change the code that transforms symbol naming into affine maps, so the selection should resolve into plain symbols (in whatever order), which would be identical to the current implementation of the transposed versions today.
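As a rough illustration of the idea, the sketch below resolves a constant list at build time, before anything is turned into an affine map. Everything here is hypothetical: ConstListAttr and select_dims are made-up names, dims are plain strings rather than OpDSL dimension objects, and the real generator may need a different hook.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class ConstListAttr:
    """Hypothetical attribute whose values are known when the op is built."""
    name: str
    default: Tuple[int, ...]


def select_dims(dims: Sequence[str], attr: ConstListAttr,
                values: Optional[Sequence[int]] = None) -> Tuple[str, ...]:
    """Pick dims in the order given by the attribute's concrete values."""
    order = tuple(attr.default if values is None else values)
    # Reject anything that is not a permutation of the dim positions.
    if sorted(order) != list(range(len(dims))):
        raise ValueError(f"{attr.name} must be a permutation of the dim positions")
    return tuple(dims[i] for i in order)


map_a = ConstListAttr("map_a", default=(0, 1))

plain = select_dims(("m", "k"), map_a)               # default order
transposed = select_dims(("m", "k"), map_a, (1, 0))  # swapped order
print(plain, transposed)
```

With the values (1, 0), A would be accessed as [k, m], which should match the indexing that matmul_transpose_a spells out by hand today.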
Alternative 1: Permute function type
I tried implementing a TransformFn.permute that can be used as:
C[D.m, D.n] += cast(U, permute(A[D.m, D.k], perm_a)) * cast(U, permute(B[D.k, D.n], perm_b))
But that leads to having to implement the function itself, which I don’t know how to do. Perhaps this would be simpler, but I’m a bit out of my depth here.
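One way to avoid writing a function body might be to treat permute as a rewrite of the access indices rather than as a computation. A minimal sketch, assuming a hypothetical Access record in place of the real tensor-use expression and a perm that is already a concrete list:

```python
from typing import NamedTuple, Sequence, Tuple


class Access(NamedTuple):
    """Hypothetical tensor access: a tensor name plus its index names."""
    tensor: str
    indices: Tuple[str, ...]


def permute(access: Access, perm: Sequence[int]) -> Access:
    """Return the same access with its indices reordered by perm."""
    if sorted(perm) != list(range(len(access.indices))):
        raise ValueError("perm must be a permutation of the index positions")
    return Access(access.tensor, tuple(access.indices[p] for p in perm))


a = permute(Access("A", ("m", "k")), [1, 0])  # swapped
b = permute(Access("B", ("k", "n")), [0, 1])  # identity
print(a, b)
```

This only works if perm is a concrete list when the op is built, so it runs into the same constraint as the tuple indexing in the problem description.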
Alternative 2: More named ops
We already have matmul_transpose_(a|b) and batch_matmul_transpose_(a|b). I can add more variations for batch_reduce_matmul_transpose_(a|b), but there are more than 2 here, and I don’t want to do this.
Any better ideas?