Linalg on Tensors Update and Comprehensive Bufferization RFC.
This post is a follow-up to the previous Linalg-on-tensors posts.
Status Update
In the past few months, some of us have been exploring raising the level of abstraction of Linalg-based transformations. In a nutshell, this mostly amounts to performing more transformations in tensor-land and taking advantage of destructive update patterns on tensors, conceptually similar to LLVM’s extract/insertelement and extract/insertvalue.
In this context, tiling produces `scf.for` + tensor yields, and SSA use-def chains remain available. This makes a number of transformations very natural to express (see the sketch after the list below):
- Padding and packing transformations become quite natural to write with a simple backward-slice based algorithm.
- Fusion + vectorization extend quite naturally to the tensor domain, and the subsequent load/store forwardings become simple canonicalizations + DCE thanks to SSA use-def chains. Similarly, RAW/WAR/WAW optimizations are also quite simple to canonicalize on n-D vectors.
- Bufferization can now happen very late (i.e. after parallelization, fusion, tiling and vectorization have taken place). This gives extra phase-ordering flexibility and opens new tradeoffs as well as simplifies “in-place” reuse of buffers.
- Consequently, it became relatively simple to build a Python-driven sandbox, run it end-to-end, and start connecting it to parallel GPU and CPU execution.
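To make the first two points concrete, below is a hand-written sketch (not compiler output; shapes and tile sizes are made up for illustration) of what tiling a `linalg.matmul` on tensors roughly produces: an `scf.for` carrying the result as an `iter_args` tensor, with the destructive update expressed via `subtensor` / `subtensor_insert`.

```mlir
// Hand-written sketch of a tiled linalg.matmul on tensors (illustrative
// shapes and tile size). The "destructive update" of %C is expressed with
// subtensor / subtensor_insert, and the running result is threaded through
// the loop via iter_args, so SSA use-def chains remain fully available.
func @tiled_matmul(%A: tensor<128x128xf32>, %B: tensor<128x128xf32>,
                   %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %c0 = constant 0 : index
  %c32 = constant 32 : index
  %c128 = constant 128 : index
  %res = scf.for %i = %c0 to %c128 step %c32 iter_args(%iterC = %C)
      -> (tensor<128x128xf32>) {
    // Extract the tiles read and updated by this iteration.
    %sA = subtensor %A[%i, 0] [32, 128] [1, 1]
        : tensor<128x128xf32> to tensor<32x128xf32>
    %sC = subtensor %iterC[%i, 0] [32, 128] [1, 1]
        : tensor<128x128xf32> to tensor<32x128xf32>
    %sD = linalg.matmul ins(%sA, %B : tensor<32x128xf32>, tensor<128x128xf32>)
                       outs(%sC : tensor<32x128xf32>) -> tensor<32x128xf32>
    // Destructive-update pattern: insert the computed tile back and yield
    // the updated tensor.
    %next = subtensor_insert %sD into %iterC[%i, 0] [32, 128] [1, 1]
        : tensor<32x128xf32> into tensor<128x128xf32>
    scf.yield %next : tensor<128x128xf32>
  }
  return %res : tensor<128x128xf32>
}
```

Because the result is carried purely through SSA values, padding, packing, fusion and vectorization can reason about this loop via use-def chains alone, before any buffers exist.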
The third point above (late bufferization) is the one of interest in this post. After some initial experiments, it became quite clear that the bufferization needs of Linalg on tensors and HPC codegen are quite separable from higher-level considerations related to generality of control flow, an open set of ops, composability of dialect-specific bufferizations, runtime reference counting, etc. On the other hand, in-place bufferization guarantees are highly desirable.
After going back to the drawing board and experimenting, the outcome is a comprehensive bufferization pass centered around “linalg ops + `scf.for` with tensor yields + function calls”. This is a Module pass that takes advantage of SSA use-def chains to determine which results are “inplaceable” (i.e. may end up reusing an operand/argument buffer to write a result).
It consists of the following steps.
Step 1: Inter-procedural CallOp analysis
First, perform a `funcArgumentsInPlaceAnalysis`, which traverses all `CallOp`s and determines whether any tensor operand could potentially bufferize to a buffer that can be updated in place (i.e. an in-out buffer).
Such operands are ones whose value is not read by any subsequent op at the caller site.
As a result of this analysis, `CallOp` operands are marked with the `kInPlaceResultsAttrName` attribute.
The “meet” of the `kInPlaceResultsAttrName` attributes over all `CallOp`s to a given `FuncOp` determines the `kInPlaceResultsAttrName` for that `FuncOp`.
In the current implementation, a topological sort of `CallOp`s and `FuncOp`s is performed and recursion is disallowed.
The `kInPlaceResultsAttrName` attribute is also the mechanism (ab?)used at the compiler/runtime interface to allow in-place interop with e.g. numpy and pytorch, as follows (the snippet is templated by the Python-driven sandbox, hence the `{M}`-style size placeholders and doubled braces):
```mlir
func @main(%A : tensor<{M}x{K}xf32>, %B : tensor<{K}x{N}xf32>,
           %C : tensor<{M}x{N}xf32>, %iters : index) -> tensor<{M}x{N}xf32>
  attributes {{ __writeable_func_buffer_args_attr__ = ["false", "false", "true"] }}
{{
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%iterC = %C)
      -> (tensor<{M}x{N}xf32>) {{
    %r = call @matmul_on_tensors(%A, %B, %iterC)
      : (tensor<{M}x{K}xf32>, tensor<{K}x{N}xf32>, tensor<{M}x{N}xf32>) -> (tensor<{M}x{N}xf32>)
    scf.yield %r : tensor<{M}x{N}xf32>
  }}
  return %res : tensor<{M}x{N}xf32>
}}
```
Argument `%C` is marked as “inplaceable” and the compiler may use `%bufferC` to write the result, assuming intra-function analysis allows it.
At the moment, “may use `%bufferC`” is implemented as “must use `%bufferC`”; this will be relaxed on a per-need basis.
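For illustration, here is a hand-written sketch (not actual pass output; dynamic shapes are used in place of the templated sizes, and the exact bufferized signature is an assumption) of what the entry point could bufferize to once `%C` is deemed inplaceable: the callee writes its result directly into `%C`'s buffer, so no allocation or copy is needed for the function result.

```mlir
// Hand-written sketch of the bufferized form when %C is inplaceable.
// The result tensor disappears: the callee updates %C's buffer in place.
func @main(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
           %C: memref<?x?xf32>, %iters: index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.for %arg0 = %c0 to %iters step %c1 {
    // @matmul_on_tensors has itself been bufferized to write into %C in place.
    call @matmul_on_tensors(%A, %B, %C)
      : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  }
  return
}
```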
Step 2: Intra-procedural bufferization
Next, traverse each `FuncOp` and perform bufferization within the function boundaries. Bufferization occurs by:
- performing an inPlace analysis, `inPlaceAnalysisFuncOpInternals`, which marks each operation within the function with the `kInPlaceResultsAttrName` attribute;
- traversing each operation in the function, rewriting it in buffer form and keeping a `BlockAndValueMapping` of the rewrites. New allocations are introduced during this step.
The analysis uses special per-op knowledge of which operands/results may be inplaceable. At the moment a hardcoded enumeration is performed; in the future it seems reasonable to introduce an interface to encode this in a less ad hoc fashion.
`scf.for` and `CallOp` + `FuncOp` are special, as we additionally need to analyze how operand/argument `#n` flows into result `#n` to ensure proper in-place behavior. These operations also have special “inplaceable” semantics in combination with `subtensor`/`subtensor_insert` and `vector.transfer_read`/`vector.transfer_write`. Analyses of these patterns take advantage of SSA use-def chains (point 2 in the Status Update section); a before/after sketch is shown below.
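The following before/after sketch (hand-written, with made-up shapes; not actual pass output) illustrates the kind of pattern these analyses target: a `subtensor_insert` into an `scf.for` iter_arg that is immediately yielded can bufferize to writes through a `memref.subview` of the same buffer, with no temporary allocation or copy.

```mlir
#map = affine_map<(d0)[s0] -> (d0 + s0)>

// Before bufferization: %t flows from iter_arg #0 into result #0 and the
// inserted tile is its only update, so the write is a candidate for inPlace.
func @update_in_loop(%T: tensor<128xf32>) -> tensor<128xf32> {
  %c0 = constant 0 : index
  %c32 = constant 32 : index
  %c128 = constant 128 : index
  %f0 = constant 0.0 : f32
  %r = scf.for %i = %c0 to %c128 step %c32 iter_args(%t = %T)
      -> (tensor<128xf32>) {
    %tile = subtensor %t[%i] [32] [1] : tensor<128xf32> to tensor<32xf32>
    %v = vector.transfer_read %tile[%c0], %f0 : tensor<32xf32>, vector<32xf32>
    %v2 = addf %v, %v : vector<32xf32>
    %tile2 = vector.transfer_write %v2, %tile[%c0]
        : vector<32xf32>, tensor<32xf32>
    %t2 = subtensor_insert %tile2 into %t[%i] [32] [1]
        : tensor<32xf32> into tensor<128xf32>
    scf.yield %t2 : tensor<128xf32>
  }
  return %r : tensor<128xf32>
}

// After in-place bufferization (same function, sketched in buffer form): the
// insert becomes a write through a subview of the same buffer; no temporary
// allocation or copy is required.
func @update_in_loop_bufferized(%buf: memref<128xf32>) {
  %c0 = constant 0 : index
  %c32 = constant 32 : index
  %c128 = constant 128 : index
  %f0 = constant 0.0 : f32
  scf.for %i = %c0 to %c128 step %c32 {
    %sv = memref.subview %buf[%i] [32] [1]
        : memref<128xf32> to memref<32xf32, #map>
    %v = vector.transfer_read %sv[%c0], %f0
        : memref<32xf32, #map>, vector<32xf32>
    %v2 = addf %v, %v : vector<32xf32>
    vector.transfer_write %v2, %sv[%c0] : vector<32xf32>, memref<32xf32, #map>
  }
  return
}
```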
Step 3: Function boundary and call-site bufferization
Lastly, once bufferization within function boundaries is done, the next step runs `bufferizeFunctionsAndCalls`, which involves:
- detecting `function_arg -> memref.buffer_cast -> memref.tensor_load -> return` patterns for each `FuncOp`, which determines the `tiedResultMap` between function args and results (in the future these ops will disappear, as their semantics is very brittle);
- rewriting function arguments and returns in buffer form, skipping the tensors that appear in the `tiedResultMap`;
- bufferizing the `CallOp`s using the callee’s `tiedResultMap`.
This last step is purely mechanical.
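For illustration, here is a hand-written sketch (made-up shapes; not actual pass output) of the pattern detected in the first bullet above and its boundary rewrite:

```mlir
// After Step 2, a result that bufferized in place into argument %A typically
// shows up as the following chain (sketch):
func @f(%A: tensor<64xf32>) -> tensor<64xf32> {
  %buf = memref.buffer_cast %A : memref<64xf32>
  // ... ops writing into %buf in place ...
  %res = memref.tensor_load %buf : memref<64xf32>
  return %res : tensor<64xf32>
}

// Step 3 records result #0 as tied to argument #0 in the tiedResultMap, then
// rewrites the function boundary (and all call sites) to buffer form,
// dropping the tied result:
func @f_bufferized(%A: memref<64xf32>) {
  // ... ops writing into %A in place ...
  return
}
```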
Proposal
After experimenting with end-to-end paths for a few weeks and starting to connect this to parallel GPU and CPU execution, a few of us are relatively confident that this should be upstreamed.
As IREE and XLA continue moving forward relying on this end-to-end codegen-on-tensors path, landing such a bufferization in core will increase reuse and reduce risk.
From a purely technical perspective, this could be landed as a separate `-linalg-comprehensive-bufferize` pass that can be used independently from the existing core bufferization and provides end-to-end batteries to Linalg on tensors.
Some have expressed that IREE and XLA really want to handle allocations as well as what happens at function boundaries, and to keep composability with the existing bufferization (XLA-only). So far this has been a non-goal, because the decisions made by `-linalg-comprehensive-bufferize` seem to be the complete opposite of what bufferization currently does in core.
I believe this can be sliced at the function boundary without delays by temporarily relying on the `kInPlaceResultsAttrName` attribute, as is done in the current implementation. Still, decoupling inter- and intra-function bufferization should be an implementation detail and should not preclude also having a Linalg-specific Module pass that brings end-to-end execution capabilities in core.
Once others have gotten their feet wet with this new approach, I believe we will finish connecting the pieces together so that XLA and IREE may use as much (or as little) of this new strategy as wanted.
@pifon2a @ThomasRaoux @MaheshRavishankar @_sean_silva @herhut @benvanik @stellaraccident @shabalin
Thanks for reading, please discuss!