Note that the linalg.inplaceable attribute was introduced to experiment with different strategies for dealing with this issue. For example, I ran into the following problem in the sparse compiler when lowering the kernel below, which adds a sparse vector into a dense vector, to sparse, bufferized code.
#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ]
}>

func @add(%arga: tensor<32xf32, #SparseVector>,
          %argx: tensor<32xf32>) -> tensor<32xf32> {
  %0 = linalg.generic #trait1
    ins(%arga: tensor<32xf32, #SparseVector>)
    outs(%argx: tensor<32xf32>) {
    ^bb(%a: f32, %x: f32):
      %1 = addf %x, %a : f32
      linalg.yield %1 : f32
  } -> tensor<32xf32>
  return %0 : tensor<32xf32>
}
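The #trait1 alias referenced by the linalg.generic is not spelled out in the snippet above; a plausible definition for this element-wise kernel (the exact indexing maps and doc string here are my own reconstruction, not taken from the original) would be:

#trait1 = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a (in)
    affine_map<(i) -> (i)>   // x (in/out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) += a(i)"
}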
Without any further information on the in/out tensor %argx, the following code must be generated.
%0 = sparse_tensor.pointers %arg0, %c0 : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
%1 = sparse_tensor.indices %arg0, %c0 : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
%2 = sparse_tensor.values %arg0 : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
%3 = memref.buffer_cast %arg1 : memref<32xf32>
%4 = memref.alloc() : memref<32xf32>
linalg.copy(%3, %4) : memref<32xf32>, memref<32xf32>
%5 = memref.load %0[%c0] : memref<?xindex>
%6 = memref.load %0[%c1] : memref<?xindex>
scf.for %arg2 = %5 to %6 step %c1 {
  %8 = memref.load %1[%arg2] : memref<?xindex>
  %9 = memref.load %4[%8] : memref<32xf32>
  %10 = memref.load %2[%arg2] : memref<?xf32>
  %11 = addf %9, %10 : f32
  memref.store %11, %4[%8] : memref<32xf32>
}
%7 = memref.tensor_load %4 : memref<32xf32>
return %7 : tensor<32xf32>
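To see where the complexity comes from, recall the compressed storage scheme behind the three sparse_tensor buffers. For a hypothetical instance of the 32-vector with nonzeros at positions 1, 9, and 17 (values picked purely for illustration), the generated code would see:

// %0 (pointers) = [ 0, 3 ]           -> scf.for runs from 0 to 3, i.e. nnz = 3 iterations
// %1 (indices)  = [ 1, 9, 17 ]       -> positions touched in the dense output
// %2 (values)   = [ 2.0, 3.0, 4.0 ]  -> nonzero values being added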
Here, the linalg.copy introduces an O(N) operation into a kernel that should only take O(nnz) time. By adding the annotation:
func @add(%arga: tensor<32xf32, #SparseVector>,
          %argx: tensor<32xf32> {linalg.inplaceable = true}) -> tensor<32xf32> {
  ...
}
we get the following code, which keeps the update of the dense vector O(nnz).
%0 = sparse_tensor.pointers %arg0, %c0 : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
%1 = sparse_tensor.indices %arg0, %c0 : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
%2 = sparse_tensor.values %arg0 : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
%3 = memref.buffer_cast %arg1 : memref<32xf32>
%4 = memref.load %0[%c0] : memref<?xindex>
%5 = memref.load %0[%c1] : memref<?xindex>
scf.for %arg2 = %4 to %5 step %c1 {
  %7 = memref.load %1[%arg2] : memref<?xindex>
  %8 = memref.load %3[%7] : memref<32xf32>
  %9 = memref.load %2[%arg2] : memref<?xf32>
  %10 = addf %8, %9 : f32
  memref.store %10, %3[%7] : memref<32xf32>
}
%6 = memref.tensor_load %3 : memref<32xf32>
return %6 : tensor<32xf32>
Eventually, the fully bufferized version looks something like this:
func @add(%arg0: !llvm.ptr<i8>, %arg1: memref<32xf32> {linalg.inplaceable = true}) -> memref<32xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = call @sparsePointers(%arg0, %c0) : (!llvm.ptr<i8>, index) -> memref<?xindex>
  %1 = call @sparseIndices(%arg0, %c0) : (!llvm.ptr<i8>, index) -> memref<?xindex>
  %2 = call @sparseValuesF32(%arg0) : (!llvm.ptr<i8>) -> memref<?xf32>
  %3 = memref.load %0[%c0] : memref<?xindex>
  %4 = memref.load %0[%c1] : memref<?xindex>
  br ^bb1(%3 : index)
^bb1(%5: index):  // 2 preds: ^bb0, ^bb2
  %6 = cmpi slt, %5, %4 : index
  cond_br %6, ^bb2, ^bb3
^bb2:  // pred: ^bb1
  %7 = memref.load %1[%5] : memref<?xindex>
  %8 = memref.load %arg1[%7] : memref<32xf32>
  %9 = memref.load %2[%5] : memref<?xf32>
  %10 = addf %8, %9 : f32
  memref.store %10, %arg1[%7] : memref<32xf32>
  %11 = addi %5, %c1 : index
  br ^bb1(%11 : index)
^bb3:  // pred: ^bb1
  return %arg1 : memref<32xf32>
}
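One detail worth noting: for this module to verify as shown, the calls into the sparse runtime support library need matching external declarations. Their signatures follow directly from the call sites above, for example:

func private @sparsePointers(!llvm.ptr<i8>, index) -> memref<?xindex>
func private @sparseIndices(!llvm.ptr<i8>, index) -> memref<?xindex>
func private @sparseValuesF32(!llvm.ptr<i8>) -> memref<?xf32>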