Right, this plays into the tensor vs. memref discussion, and it's why I mention bufferization, as you note in your example.
I think of copy here more as an adjective than a verb. That's fine with value semantics.
%copy = linalg.copy ins(%a) outs(%b)
- Yes: %b "is a copy of" %a
- No: copy %b "into" %a
That's my option (3) above. I'm happy with either keeping (1) or renaming (3), but not with having multiple of them (2).
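To make that concrete, here is the current spelling next to a renamed one; linalg.copy_of is purely a hypothetical name to illustrate the "is a copy of" reading, not an actual proposal:

// (1) Keep the current name; with value semantics, %b "is a copy of" %a.
%copy = linalg.copy ins(%a : tensor<32x64xf32>) outs(%b : tensor<32x64xf32>) -> tensor<32x64xf32>
// (3) A renamed spelling (hypothetical) that spells out that same reading.
%copy2 = linalg.copy_of ins(%a : tensor<32x64xf32>) outs(%b : tensor<32x64xf32>) -> tensor<32x64xf32>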
The bufferization interplay isn't just about bufferization itself, but about what happens when we "fuse" the cast into a lower-level op.
Silly example (not practical):
// Down-cast %arg0 to f16
%0 = tensor.empty() : tensor<32x64xf16>
%arg0_cast = linalg.cast ins(%arg0) outs(%0) : tensor<32x64xf32> to tensor<32x64xf16>
// %arg1 is a constant
%arg1 = arith.constant dense<1.23456e1> : tensor<32x64xf16>
// Perform the add
%1 = tensor.empty() : tensor<32x64xf16>
%add = linalg.add ins(%arg1, %arg0_cast : tensor<32x64xf16>, tensor<32x64xf16>) outs(%1 : tensor<32x64xf16>) -> tensor<32x64xf16>
// But accumulate on f32
%2 = tensor.empty() : tensor<32x64xf32>
%add_cast = linalg.cast ins(%add) outs(%2) : tensor<32x64xf16> to tensor<32x64xf32>
return %add_cast : tensor<32x64xf32>
Problem #1: There may be graph patterns that need to be lowered in a particular, convoluted way (as above), and that shape can be destroyed by intermediate passes like bufferization.
Solution #1: We can always revert to generics with precise semantics (e.g. accumulation type), but then we're also back to matching too many things. I'd like a design that is robust to both cases without reverting to generics.
Note that the comment says "accumulate on f32", but the implementation accumulates on f16 and then casts to f32. Breaking the pattern may mean I can't match it to a fast op down the road.
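For contrast, the "precise semantics" generic alluded to in Solution #1 could express the intended f16 + f32 -> f32 add directly, keeping the accumulation in f32. This is only a rough sketch (the names %acc and #map are mine):

#map = affine_map<(d0, d1) -> (d0, d1)>
%acc = tensor.empty() : tensor<32x64xf32>
%add_f32 = linalg.generic
    {indexing_maps = [#map, #map, #map],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg1, %arg0 : tensor<32x64xf16>, tensor<32x64xf32>)
    outs(%acc : tensor<32x64xf32>) {
^bb0(%a: f16, %b: f32, %out: f32):
  // Widen the f16 element and add in f32, so the accumulation really is f32.
  %a_ext = arith.extf %a : f16 to f32
  %sum = arith.addf %a_ext, %b : f32
  linalg.yield %sum : f32
} -> tensor<32x64xf32>

This is exactly the "too generic to match" shape the rest of the post is trying to avoid.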
The named-op code above could bufferize to:
// (On memrefs, the linalg ops write into their outs buffers and have no results.)
// Down-cast %arg0 to f16
%0 = memref.alloc() : memref<32x64xf16>
linalg.cast ins(%arg0) outs(%0) : memref<32x64xf32> to memref<32x64xf16>
// %arg1 is a constant
%arg1 = arith.constant dense<1.23456e1> : memref<32x64xf16>
// Perform the add
%1 = memref.alloc() : memref<32x64xf16>
linalg.add ins(%arg1, %0 : memref<32x64xf16>, memref<32x64xf16>) outs(%1 : memref<32x64xf16>)
// But accumulate on f32
%2 = memref.alloc() : memref<32x64xf32>
linalg.cast ins(%1) outs(%2) : memref<32x64xf16> to memref<32x64xf32>
return %2 : memref<32x64xf32>
Problem #2: Bufferization can try to be smart about this and reuse some of those buffers, but then it would create variations in the number of allocations within a pattern.
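For instance, an in-place bufferization could let the add overwrite its dead f16 input buffer, leaving two allocations instead of three. This is a hypothetical outcome, sketched with the same ops as above:

%0 = memref.alloc() : memref<32x64xf16>
linalg.cast ins(%arg0) outs(%0) : memref<32x64xf32> to memref<32x64xf16>
// The add writes over %0, so the %1 allocation disappears.
linalg.add ins(%arg1, %0 : memref<32x64xf16>, memref<32x64xf16>) outs(%0 : memref<32x64xf16>)
%2 = memref.alloc() : memref<32x64xf32>
linalg.cast ins(%0) outs(%2) : memref<32x64xf16> to memref<32x64xf32>
return %2 : memref<32x64xf32>

A matcher written against the three-alloc shape would miss this variant.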
Solution #2: What we do today is completely ignore the allocations in the middle and, if we match, clean up the remaining ones; if not, we just lower pessimistically.
What I really wanted is:
// Some lowering with a wider accumulation type
%arg1 = arith.constant dense<1.23456e1> : memref<32x64xf16>
%out = memref.alloc() : memref<32x64xf32>
%add = foobar.add_from_f16_and_f32_into_f32 ins(%arg1, %arg0 : memref<32x64xf16>, memref<32x64xf32>) outs(%out : memref<32x64xf32>) : memref<32x64xf32>
return %add : memref<32x64xf32>
Note how all the intermediate allocs disappear, "as if they weren't necessary after all"; only the final output buffer remains.
The four main ways of doing this are:
1. Going straight from graph to hardware, i.e. not having a compiler at all.
2. A generic implementation with no complex structures (i.e. only perfectly nested loops).
3. A named op for everything, with the resulting combinatorial explosion.
4. Chains of named ops, with their complexity and interplay with other passes.
We have enough anti-patterns from 1-3, but we should be careful with 4 and avoid the obvious anti-patterns.
Our current approach is to mix them all, but that also comes with a bit of all of their complexities.