I noticed that commit 34d65e81e831e10499928dc1700151aa678e1894 changed the behavior of the isMemoryWrite function defined in mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp:626 (since renamed “defaultResultBufferizesToMemoryWrite”) by adding a third case in which the function checks whether any operand bufferizes to a memory write along the reverse use-def chain. Because of this, a lot of code now contains memref.copy operations after bufferization, when it did not before.
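For reference, here is how I read the new logic, as a simplified pseudocode sketch (based on my reading of BufferizableOpInterface.cpp; the helper names and exact signatures may not match the actual implementation):

// Simplified sketch of defaultResultBufferizesToMemoryWrite after 34d65e81e,
// as I understand it; not the verbatim upstream code.
bool defaultResultBufferizesToMemoryWrite(OpResult opResult,
                                          const AnalysisState &state) {
  // Case 1: results with no aliasing OpOperand are assumed to bufferize to
  // a memory write (e.g., newly allocated buffers).
  // Case 2: an aliasing OpOperand itself bufferizes to a memory write.
  // Case 3 (the new one): follow each aliasing OpOperand backwards through
  // the use-def chain and report a write if any value on that chain
  // bufferizes to one.
  for (OpOperand *operand : state.getAliasingOpOperands(opResult)) {
    SetVector<Value> writes = state.findValueInReverseUseDefChain(
        operand->get(),
        [&](Value v) { return state.bufferizesToMemoryWrite(v); });
    if (!writes.empty())
      return true;
  }
  return false;
}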
Here is a contrived example that triggers the issue:
module {
  func.func @test(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?x1x8xf32> {bufferization.writable = true}) -> tensor<?x?x1x8xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c9 = arith.constant 9 : index
    %c256 = arith.constant 256 : index
    %0 = tensor.empty() : tensor<256x256x1x8xf32>
    %cast = tensor.cast %0 : tensor<256x256x1x8xf32> to tensor<?x?x1x8xf32>
    %1 = scf.for %arg2 = %c0 to %c9 step %c1 iter_args(%arg3 = %arg1) -> (tensor<?x?x1x8xf32>) {
      %2 = scf.for %arg4 = %c0 to %c9 step %c1 iter_args(%arg5 = %cast) -> (tensor<?x?x1x8xf32>) {
        %3 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, false]} : tensor<1x?xf32>, vector<1x8xf32>
        %4 = vector.transfer_write %3, %arg5[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x8xf32>, tensor<?x?x1x8xf32>
        scf.yield %4 : tensor<?x?x1x8xf32>
      }
      scf.yield %2 : tensor<?x?x1x8xf32>
    }
    return %1 : tensor<?x?x1x8xf32>
  }
}
This used to bufferize to clean code with no memref.copy, as there are no write conflicts. But now, running mlir-opt --empty-tensor-to-alloc-tensor -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" produces the following:
module {
  func.func @test(%arg0: memref<1x?xf32, strided<[?, ?], offset: ?>>, %arg1: memref<?x?x1x8xf32, strided<[?, ?, ?, ?], offset: ?>>) -> memref<?x?x1x8xf32, strided<[?, ?, ?, ?], offset: ?>> {
    %c9 = arith.constant 9 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<256x256x1x8xf32>
    %0 = scf.for %arg2 = %c0 to %c9 step %c1 iter_args(%arg3 = %arg1) -> (memref<?x?x1x8xf32, strided<[?, ?, ?, ?], offset: ?>>) {
      %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<256x256x1x8xf32>
      memref.copy %alloc, %alloc_0 : memref<256x256x1x8xf32> to memref<256x256x1x8xf32>
      scf.for %arg4 = %c0 to %c9 step %c1 {
        %1 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, false]} : memref<1x?xf32, strided<[?, ?], offset: ?>>, vector<1x8xf32>
        vector.transfer_write %1, %alloc_0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x8xf32>, memref<256x256x1x8xf32>
      }
      %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x256x1x8xf32>
      memref.copy %alloc_0, %alloc_1 : memref<256x256x1x8xf32> to memref<256x256x1x8xf32>
      %cast = memref.cast %alloc_1 : memref<256x256x1x8xf32> to memref<?x?x1x8xf32, strided<[?, ?, ?, ?], offset: ?>>
      scf.yield %cast : memref<?x?x1x8xf32, strided<[?, ?, ?, ?], offset: ?>>
    }
    memref.dealloc %alloc : memref<256x256x1x8xf32>
    return %0 : memref<?x?x1x8xf32, strided<[?, ?, ?, ?], offset: ?>>
  }
}
The presence of these memref.copy operations seems to be due to the bufferization analysis (wrongly, I think) concluding that the tensor.cast operation bufferizes to a memory write, which creates a RaW conflict.

I believe this happens because the call to findValueInReverseUseDefChain in defaultResultBufferizesToMemoryWrite (mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp:681) implicitly leaves alwaysIncludeLeaves at true (the default value for this parameter). The traversal therefore always returns the operand of the tensor.cast (i.e. the result of the tensor.empty(), even though that op does not bufferize to a memory write), which makes defaultResultBufferizesToMemoryWrite return true for the tensor.cast operation.
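Concretely, if the call at line 681 looks the way I think it does, the problem is just that the trailing default is in effect. Paraphrased, with condition standing in for whatever predicate the real code passes (parameter names and order are from my reading of BufferizableOpInterface.h and may be slightly off):

// Paraphrase of the call at BufferizableOpInterface.cpp:681. Because
// alwaysIncludeLeaves defaults to true, the walk returns the leaf of the
// chain (here: the tensor.empty result) even when no value on the chain
// satisfies the predicate, so the returned set is never empty and the
// tensor.cast is classified as a memory write.
SetVector<Value> maybeWrites =
    state.findValueInReverseUseDefChain(opOperand->get(), condition);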
In 1fdf06d6d79ea0ced79d680b7fcd622ef63fb9a5, the call to findValueInReverseUseDefChain was changed to explicitly set alwaysIncludeLeaves to false, and I think the same change should be made here, i.e. something along these lines (same caveats as above):
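// Suggested change, mirroring 1fdf06d6d: pass alwaysIncludeLeaves=false
// explicitly, so a leaf that does not itself bufferize to a memory write
// (such as the tensor.empty here) is no longer returned.
SetVector<Value> maybeWrites = state.findValueInReverseUseDefChain(
    opOperand->get(), condition,
    /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);

Could someone knowledgeable about bufferization confirm my assumptions?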
Thanks!