Tiling on Linalg on tensors uses `scf.for` with destructive updates to represent the tiled computation. For example, a simple 1D operation
```mlir
#map = affine_map<(d0) -> (d0)>
%1 = linalg.init_tensor [..] : tensor<?xf32>
%2 = linalg.generic {
    indexing_maps = [#map, #map], iterator_types = ["parallel"]}
    ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xf32>) {
  ^bb0(%arg0 : f32, %arg1 : f32):
    %3 = math.sqrt %arg0 : f32
    linalg.yield %3 : f32
  } -> tensor<?xf32>
```
when tiled using tile size `%t` is represented as follows:
```mlir
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
%1 = linalg.init_tensor [..] : tensor<?xf32>
%2 = tensor.dim %1, %c0 : tensor<?xf32>
%3 = scf.for %iv = %c0 to %2 step %t iter_args(%arg0 = %1) -> (tensor<?xf32>) {
  %4 = affine.min #map1(%iv)[%t, %2]
  %5 = tensor.extract_slice %0[%iv] [%4] [1] : tensor<?xf32> to tensor<?xf32>
  %6 = tensor.extract_slice %arg0[%iv] [%4] [1] : tensor<?xf32> to tensor<?xf32>
  %7 = linalg.generic {
      indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
      ins(%5 : tensor<?xf32>) outs(%6 : tensor<?xf32>) {
    ^bb0(%arg1 : f32, %arg2 : f32):
      %9 = math.sqrt %arg1 : f32
      linalg.yield %9 : f32
    } -> tensor<?xf32>
  %8 = tensor.insert_slice %7 into %arg0[%iv] [%4] [1] : tensor<?xf32> into tensor<?xf32>
  scf.yield %8 : tensor<?xf32>
}
```
The `tensor.extract_slice` of `%arg0` and the `tensor.insert_slice %... into %arg0` represent the destructive update pattern, i.e. each iteration of the loop computes a tile of the computation, inserts it into the current value of the result tensor (`%arg0`), and yields the entire tensor. As represented, this computation loses the information that the `scf.for` is actually parallel. The serialization is not inherent to the computation; it is purely due to the inability to represent parallel computation using `scf.for` with `tensor`-typed operands and results.
The problem becomes even more pronounced with distribution. Currently in Linalg, distribution is part of the tiling transformation, i.e. for a loop that is parallel, block-cyclic distribution is used during tiling to distribute work across threads on a grid. For example, tile + distribute of the original op results in the following IR.
```mlir
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
#map2 = affine_map<()[s0, s1] -> (s0 * s1)>
%1 = linalg.init_tensor [..] : tensor<?xf32>
%2 = tensor.dim %1, %c0 : tensor<?xf32>
%id = "proc_id"() : () -> index
%np = "nprocs"() : () -> index
%offset = affine.apply #map2()[%id, %t]
%step = affine.apply #map2()[%np, %t]
%3 = scf.for %iv = %offset to %2 step %step iter_args(%arg0 = %1) -> (tensor<?xf32>) {
  ....
}
```
This leads to a whole host of problems. The main one is that when all shapes are statically known, all tile sizes are statically known, and the tile sizes divide the problem sizes, the code within the tiled loop can be canonicalized to use static shapes as well, which then vectorizes easily (without requiring peeling, padding, etc., given these constraints). Unfortunately, materializing the distributed loop makes canonicalizing away the `affine.min` operation very involved, and without that the sizes are not propagated as static shapes.
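To make this concrete, here is a sketch with made-up static sizes (problem size 128, tile size 32; the constants and names are illustrative, not taken from the examples above) of why the `affine.min` folds away easily before distribution but not after:

```mlir
// Undistributed: %iv only takes the values 0, 32, 64, 96, so
// min(32, 128 - %iv) is always 32. The affine.min folds to %c32, the
// slices become tensor<32xf32>, and the body vectorizes directly.
scf.for %iv = %c0 to %c128 step %c32 iter_args(%arg0 = %1) -> (tensor<?xf32>) {
  %4 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv)[%c32, %c128]
  ...
}

// Distributed: the lower bound (%id * 32) and step (%np * 32) are now
// dynamic. Folding the affine.min requires proving that %iv stays a
// multiple of 32 below 128, which the canonicalization cannot easily
// derive from dynamic bounds, so the slice sizes stay dynamic.
scf.for %iv = %offset to %c128 step %step iter_args(%arg0 = %1) -> (tensor<?xf32>) {
  %4 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv)[%c32, %c128]
  ...
}
```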
There are many potential solutions, each with its own trade-offs. I am listing them here in order of my preference, but I am really looking for an end-to-end solution (that we can use effectively in IREE). The proposal below is the cleanest IMO, but I am not tied to it; there are other potential solutions as well.
Solution 1: Extend `scf.for` operations.
The distribution that is done during tiling is a really simple transformation: it just changes the lower-bound and step values using the SSA values for the number of processors and the processor ID. There is no need to actually materialize the distributed loop. Instead, one could extend the `scf.for` semantics to take these SSA values as two additional operands, so the distributed loop above would be represented as
```mlir
%3 = scf.for %iv = %c0 to %ub step %c1 iter_args(...)
       distributed(id = %id, np = %np) {
  ...
} -> tensor<?xf32>
```
Here the `scf.for` accepts two additional optional operands, `%id` and `%np`, that specify the processor ID and the number of processors. The presence of these operands also indicates that the loop is parallel and distributable. This allows easy analysis of the loop while maintaining the semantics that the loop is distributable. For example, for CUDA codegen in IREE, there is a fairly heavy analysis to figure out whether a loop can be removed because it is a zero-trip or one-trip loop. A lot of that complication arises from the loop already being distributed. The form above makes this analysis fairly trivial. The actual distribution can happen quite late, just before lowering to the LLVM (or SPIR-V) dialect.
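As an illustration of how mechanical that late step could be, here is a rough sketch (assuming the `distributed(...)` syntax proposed above; the maps and value names are purely illustrative) of the materialization, which just folds `%id` and `%np` back into the bounds the same way tiling-time distribution does today:

```mlir
// Before materialization (proposed extended form):
%3 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %1)
       distributed(id = %id, np = %np) -> (tensor<?xf32>) {
  ...
}

// After materialization: a plain scf.for with block-cyclic bounds,
// i.e. new_lb = lb + id * step and new_step = np * step.
%new_lb   = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>()[%lb, %id, %step]
%new_step = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%np, %step]
%3 = scf.for %iv = %new_lb to %ub step %new_step iter_args(%arg0 = %1) -> (tensor<?xf32>) {
  ...
}
```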
Solution 2: Making `scf.parallel` work on tensors.
AFAIK `scf.parallel` only works on `memref` types. It could be extended to work on tensors while still maintaining the parallel semantics. The downside I see is that it does not have a way to carry the SSA values to use for distribution as represented above. The semantics of the `scf.parallel` loop could also be extended to use the `id` and `np` values, but I am less convinced of the actual value of `scf.parallel` over `scf.for` if it is only used to say that some loops are parallel. The same could be achieved by using an attribute on `scf.for`, for example.
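For completeness, a sketch of that attribute alternative (the attribute name here is made up for illustration, not an existing one):

```mlir
// Parallel-ness recorded as a discardable attribute instead of a
// separate op; the loop itself stays an ordinary scf.for on tensors.
%3 = scf.for %iv = %c0 to %ub step %t iter_args(%arg0 = %1) -> (tensor<?xf32>) {
  ...
  scf.yield %8 : tensor<?xf32>
} {iterator_type = "parallel"}
```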
Solution 3: `linalg.tiled_loop`
For the most part, `linalg.tiled_loop` already has some of the requirements I am looking for. The way distribution is represented needs to be tweaked, though: instead of using attributes to describe how the loop is distributed (which is very CUDA-specific and an abstraction leak IMO), it should just use optional SSA values. Having followed the `linalg.tiled_loop` evolution and after discussions with a few folks, I am still not convinced that `linalg.tiled_loop` is paying for the abstraction. I think effective extensions to `scf.for` give us everything we need to represent parallel semantics.
Looking for comments on how to proceed here. Happy to take up any tasks that help towards a better representation. I really hope we can go with the `scf.for` extensions; from experience using these in IREE, that would be the easiest to make work end-to-end.
cc @nicolasvasilache @albertcohen @ftynse @herhut