I would like to propose adding a concatenate op to Linalg. I’ve been prototyping concatenate op support in IREE for a while. When adapting the patterns to the tensor world, I hit the limits of Linalg again. There are some workarounds, but the most promising solution to me is to add a concatenate op. There was a discussion about this a long time ago; I’ve included some context from it and added more thoughts below.
Option 1 – lowering the op to an indexed_generic op
This is the first prototype I had. The idea is to lower a concatenate op to an indexed_generic op and use the indices to yield the correct value. With this approach, extra dimensions/loops are needed because of loop bound inference. The main problem is that you have two tensors, (2, 2) and (3, 2), from which the generic op needs to infer the bounds for d0 and d1.
The first tensor will tell you that (d0, d1) is in the range [0, 2) x [0, 2).
The second tensor will tell you that (d0, d1) is in the range [2, 5) x [0, 2).
According to the first property of linalg generic, these ranges don’t match, so this is undefined behavior (see 'linalg' Dialect - MLIR).
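For illustration, here is a rough sketch of what such an indexed_generic lowering would have to look like (indexing maps elided; this is hypothetical and not valid IR, precisely because the two operands imply conflicting ranges for d0):

```mlir
// Hypothetical sketch only: the two input shapes imply conflicting
// ranges for d0, so no single indexed_generic like this is legal.
%result = linalg.indexed_generic {...}
    ins(%lhs, %rhs : tensor<2x2xi32>, tensor<3x2xi32>) {
  ^bb0(%d0: index, %d1: index, %a: i32, %b: i32):
    %c2 = constant 2 : index
    // Read from the first operand for d0 < 2, from the second otherwise.
    %in_lhs = cmpi "slt", %d0, %c2 : index
    %v = select %in_lhs, %a, %b : i32
    linalg.yield %v : i32
} -> tensor<5x2xi32>
```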
A snippet of the lowering of a concatenate of memref<2x2xi32> and memref<3x2xi32> into memref<5x2xi32>, at the scf level:
scf.for %arg3 = %c0 to %c2 step %c1 {
  scf.for %arg4 = %c0 to %c5 step %c1 {
    scf.for %arg5 = %c0 to %c2 step %c1 {
      scf.for %arg6 = %c0 to %c3 step %c1 {
        %0 = load %arg1[%arg5, %arg3] : memref<2x2xi32>
        %1 = load %arg2[%arg6, %arg3] : memref<3x2xi32>
        %2 = load %arg0[%arg4, %arg3] : memref<5x2xi32>
        %3 = subi %arg4, %c0 : index
        %4 = cmpi "eq", %3, %arg5 : index
        %5 = cmpi "sge", %arg4, %c0 : index
        %6 = cmpi "slt", %arg4, %c2 : index
        %7 = and %5, %6 : i1
        %8 = select %7, %0, %0 : i32
        %9 = subi %arg4, %c2 : index
        %10 = cmpi "eq", %9, %arg6 : index
        %11 = or %4, %10 : i1
        %12 = cmpi "sge", %arg4, %c2 : index
        %13 = cmpi "slt", %arg4, %c5 : index
        %14 = and %12, %13 : i1
        %15 = select %14, %1, %8 : i32
        %16 = select %11, %15, %2 : i32
        store %16, %arg0[%arg4, %arg3] : memref<5x2xi32>
      }
    }
  }
}
This is really just a workaround to express concat in Linalg, and it is not the way to go: it generates many inefficient loops.
Option 2 – lowering the op to a bunch of subviews/subtensors + copy/subtensor_insert
The current solution in IREE is to lower the op to subviews + copies in the buffer world. I have a prototype with similar behavior in the tensor world using subtensor and subtensor_insert. The idea is to create a linalg.fill op and operate on it.
E.g.,
Input:
func @concatenate(%0: tensor<2x2xi32>, %1: tensor<2x3xi32>) -> tensor<2x5xi32> {
  %2 = "mhlo.concatenate"(%0, %1) {
    dimension = 1
  } : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
  return %2 : tensor<2x5xi32>
}
Output:
func @concatenate(%0: tensor<2x2xi32>, %1: tensor<2x3xi32>) -> tensor<2x5xi32> {
  %init_tensor = linalg.init_tensor [2, 5] : tensor<2x5xi32>
  %filled_tensor = linalg.fill(%init_tensor, 0) : tensor<2x5xi32>
  %sub1 = subtensor_insert %0 into %filled_tensor[0, 0] [2, 2] [1, 1] : tensor<2x2xi32> into tensor<2x5xi32>
  %sub2 = subtensor_insert %1 into %filled_tensor[0, 2] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<2x5xi32>
  return %filled_tensor : tensor<2x5xi32>
}
The issue here is that there are no uses of %sub1 and %sub2, so they will be killed by DCE. So far, I don’t have a workable prototype with this option.
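For reference, one direction would be to chain the inserts through their SSA results and return the last one, which keeps them live (and makes the fill unnecessary, since the two inserts cover the whole result). This is just a sketch; I haven’t verified it end-to-end:

```mlir
func @concatenate(%0: tensor<2x2xi32>, %1: tensor<2x3xi32>) -> tensor<2x5xi32> {
  %init = linalg.init_tensor [2, 5] : tensor<2x5xi32>
  // Each insert consumes the previous result, so nothing is dead.
  %sub1 = subtensor_insert %0 into %init[0, 0] [2, 2] [1, 1] : tensor<2x2xi32> into tensor<2x5xi32>
  %sub2 = subtensor_insert %1 into %sub1[0, 2] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<2x5xi32>
  return %sub2 : tensor<2x5xi32>
}
```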
Option 3 – adding a linalg.concat op
Adding a special Linalg op looks most promising to me.
Input in tensor world:
%1 = add (...) : tensor< a x ...>
%2 = add (...) : tensor< b x ...>
%3 = linalg.concat(%1, %2) : tensor<(a + b) x ...>
Proposed lowering of a special linalg.concat op that is not an indexed_generic:
%3 = alloc (...) : memref<(a + b) x ...>
%1 = subview %3[][][] : memref<a x ...>
%2 = subview %3[][][] : memref<b x ...>
add(..., %1)
add(..., %2)
The issue is with fusion: the split between %1 and %2 is propagated into anything the adds fuse into, until fusion stops. Also, the fusion is not a real fusion: we end up with two “fused columns”: everything that depends on %1 and everything that depends on %2.
However, having a linalg.concat operation makes sense to me because it fills a missing part of Linalg. There are two versions of the linalg.concat op in my mind.
Simple linalg.concat op
A linalg.concat op takes a list of tensors and a dimension index, then produces a concatenated tensor.
E.g.,
%0 = linalg ... : tensor<2x2xi32>
%1 = linalg ... : tensor<2x3xi32>
%c1 = constant 1 : index
%2 = linalg.concat %0, %1 along dim %c1 : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
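The intended semantics match np.concatenate: all operands agree on every dimension except the concat dimension, whose sizes add up. A quick sanity check in numpy (used purely for illustration):

```python
import numpy as np

# Two operands that agree on all dims except dim 1 (the concat dimension).
a = np.zeros((2, 2), dtype=np.int32)
b = np.ones((2, 3), dtype=np.int32)

# Concatenating along dim 1 yields a 2x5 result: 2 + 3 along that axis.
c = np.concatenate([a, b], axis=1)
print(c.shape)  # (2, 5)
```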
In bufferization, you can allocate a buffer for the result and copy the values from the operands into it.
%0 = linalg ... : memref<2x2xi32>
%1 = linalg ... : memref<2x3xi32>
%buf = alloc (...) : memref<2x5xi32>
%sub1 = subview %buf [0, 0] [2, 2] [1, 1] : memref<...>
linalg.copy %0, %sub1 : memref<...>
%sub2 = subview %buf [0, 2] [2, 3] [1, 1] : memref<...>
linalg.copy %1, %sub2 : memref<...>
Non-simple linalg.concat op
Essentially, a linalg.concat operation is a “collection” of N linalg operations, where N is the number of operands to the concat, each producing a single result tensor. In its simplest form, each of the concatenated operations is just a trivial op that returns its input.
%result = linalg.concat [%0], [%1], [%2], .... [%n] {
%r0 = linalg... %0
%r1 = linalg... %1
...
%rn = linalg... %n
} : tensor<....>
The square brackets around each operand in the top-level linalg.concat operation indicate the arguments that are to be “forwarded” to each of the individual linalg operations. So the concat operation takes N lists of values.
Fusing this concat operation with its producers is easy. Let’s say %n is produced by another linalg operation as follows:
%n = linalg %a, %b
%result = linalg.concat ... [%n] {
...
%rn = linalg... %n
}
After fusion you get:
%result = linalg.concat .... [%a, %b] {
...
%rn = linalg... %a, %b
}
The new linalg operation producing %rn is just obtained by fusing the operation producing %n with the old operation producing %rn.
When converting to buffers, you can allocate a buffer for the result of the outer linalg.concat operation, and then split it based on the concat specification to get the result buffers for the inner linalg operations. So you get a sequence of operations that computes the linalg.concat “in-place”. Effectively, this operation would not exist in the buffer world.
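Concretely, bufferizing the region-based form of the earlier 2x2/2x3 example could look roughly like this (a sketch; subview result layout maps elided):

```mlir
%buf = alloc() : memref<2x5xi32>
// Split the result buffer according to the concat specification.
%sub0 = subview %buf[0, 0] [2, 2] [1, 1] : memref<2x5xi32> to memref<2x2xi32, #map0>
%sub1 = subview %buf[0, 2] [2, 3] [1, 1] : memref<2x5xi32> to memref<2x3xi32, #map1>
// Each inner linalg op writes directly into its slice of the result,
// so no separate concat op remains after bufferization.
linalg... %sub0   // computes %r0 in-place
linalg... %sub1   // computes %r1 in-place
```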
(@nicolasvasilache @MaheshRavishankar @asaadaldien @antiagainst @ThomasRaoux for visibility)