Now that sparse tensor types are first-class citizens and the sparse compiler is really taking shape (see, for example, this previous thread and the Sparse Compiler pre-print paper), it is time to make sure that all other compiler optimizations compose well with sparse tensors, since proper composition is one of the major themes of the Composable and Modular Code Generation in MLIR pre-print paper.
For the most part, this composition should be completely transparent (i.e., kernels on dense or sparse tensors take exactly the same path through the compiler). Revision D119971 demonstrates such trivial compositions by placing a set of Linalg rewriting rules upstream of the sparse compiler with only minimal changes.
In some cases, however, an optimization only makes sense in the context of sparse tensors, even though the rewriting itself would apply to dense and sparse tensors alike. Revision D120429 is a first example of such an optimization: it fuses two consecutive kernels by distributing a multiplication over additions (relying on arithmetic equalities that may not always hold for floating-point computations). This rewriting would be undesirable in the dense case, but can make a lot of sense for sparse tensor operands, since the resulting kernel potentially has a lower asymptotic complexity.
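To illustrate why such a rewriting relies on equalities that may not hold in floating point, here is a small Python check of the distributive law (a + b) * c == a*c + b*c. It is only an illustration of the arithmetic caveat, not part of D120429 or the sparse compiler; the function names are hypothetical. The law is exact for Python integers, but for IEEE doubles the two forms can differ, most visibly when an intermediate product overflows (rounding can also make them differ in the last bits).

```python
import math

def factored(a, b, c):
    # Factored form: (a + b) * c
    return (a + b) * c

def distributed(a, b, c):
    # Multiplication distributed over the addition: a*c + b*c
    return a * c + b * c

# For Python ints (arbitrary precision) the two forms always agree.
assert factored(3, 4, 5) == distributed(3, 4, 5) == 35

# For IEEE doubles they can differ: a*c and b*c overflow to +inf and -inf,
# whose sum is NaN, while a + b is exactly 0.0 before the multiplication.
a, b, c = 1e308, -1e308, 10.0
assert factored(a, b, c) == 0.0
assert math.isnan(distributed(a, b, c))
```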
For example, consider expressing SDDMM (sampled dense-dense matrix multiplication) in PyTACO as follows, where S and X are sparse, and A and B are dense:
X[i,j] = S[i,j] * A[i,k] * B[k,j]
yields a single Linalg op that computes the expression sum(k, S[i,j] * A[i,k] * B[k,j]). However, when expressed as follows
X[i,j] = A[i,k] * B[k,j] * S[i,j]
the semantics dictate the expression sum(k, A[i,k] * B[k,j]) * S[i,j], where the reduction over k applies only to the product of A and B, and this is expressed using two Linalg ops:
T[i,j] = A[i,k] * B[k,j]
X[i,j] = T[i,j] * S[i,j]
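The two kernels can be sketched as plain Python loop nests. This is a minimal sketch under stated assumptions: A and B are dense lists of lists, and S is a hypothetical dictionary mapping (i, j) coordinates to its nonzero values, a stand-in for a real sparse format rather than the sparse compiler's storage scheme. The multiplication counter makes the cost visible: the first kernel performs M*N*K multiplications no matter how few nonzeros S has.

```python
def unfused_sddmm(S, A, B):
    """Two-kernel evaluation: T = A * B (dense), then X = T * S (sampled).

    S maps (i, j) -> value for the nonzeros only; A is M x K, B is K x N.
    Returns (X as a dict of nonzeros, number of multiplications performed).
    """
    M, K, N = len(A), len(B), len(B[0])
    mults = 0
    # Kernel 1: T[i,j] = sum(k, A[i,k] * B[k,j]) over the full dense M x N space.
    T = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                acc += A[i][k] * B[k][j]
                mults += 1
            T[i][j] = acc
    # Kernel 2: X[i,j] = T[i,j] * S[i,j]; only the nonzeros of S survive.
    X = {}
    for (i, j), s in S.items():
        X[(i, j)] = T[i][j] * s
        mults += 1
    return X, mults

A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]          # 3 x 2, dense
B = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]  # 2 x 4, dense
S = {(0, 1): 2.0, (2, 3): -1.0}                   # 3 x 4, two nonzeros
X, mults = unfused_sddmm(S, A, B)
print(X)      # {(0, 1): 4.0, (2, 3): -23.0}
print(mults)  # 26: 3*4*2 for the dense product, plus 2 for the sampling
```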
Without the fusion described above, the first kernel would compute the full dense matrix-matrix multiplication before performing the element-wise sampling, which has the wrong asymptotic complexity (on the order of M*N*K operations for an M x K matrix A and a K x N matrix B, no matter how sparse S is). By fusing the two ops, we again obtain a kernel that only computes the dot products that are not nullified by zeros in S.
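A minimal sketch of the fused kernel, again in plain Python with a hypothetical dictionary of (i, j) coordinates standing in for the sparse matrix S. The loop is driven by the nonzeros of S, so only nnz(S) dot products of length K are computed, plus one scaling per nonzero, instead of the M*N*K multiplications of the dense product. Since S[i,j] is constant in k, the sketch applies it once after each dot product. This illustrates the idea only; it is not the code the sparse compiler generates.

```python
def fused_sddmm(S, A, B):
    """Fused evaluation: X[i,j] = S[i,j] * sum(k, A[i,k] * B[k,j]) for nonzero S[i,j] only.

    S maps (i, j) -> value for the nonzeros only; A is M x K, B is K x N.
    Returns (X as a dict of nonzeros, number of multiplications performed).
    """
    K = len(B)
    mults = 0
    X = {}
    # Iterate over the sparsity pattern of S; positions where S is zero are
    # never visited, so their dot products are never computed.
    for (i, j), s in S.items():
        acc = 0.0
        for k in range(K):
            acc += A[i][k] * B[k][j]
            mults += 1
        # S[i,j] does not depend on k, so it is applied once per dot product.
        X[(i, j)] = s * acc
        mults += 1
    return X, mults

A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]          # 3 x 2, dense
B = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]  # 2 x 4, dense
S = {(0, 1): 2.0, (2, 3): -1.0}                   # 3 x 4, two nonzeros
X, mults = fused_sddmm(S, A, B)
print(X)      # {(0, 1): 4.0, (2, 3): -23.0}
print(mults)  # 6: 2 nonzeros * (2 for the dot product + 1 for the scaling)
```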