Folks, I have been studying the routes available for interfacing MLIR/PyTorch, and some things are becoming clear. I am going to attempt to enunciate them and set some terminology. Then I am going to conclude with recommendations on the path forward.
First, referencing the meeting we had on Monday, most of us on the call had implemented some level of PyTorch program extraction in the past. It is my belief that a naïve converter, focused primarily on either the device level (such as the ATen/ACAP work) or the forward-only Torch IR level (such as what TVM has), is something that an appropriately motivated engineer can bring to the point where a handful of models can be extracted and run in O(weeks). Beyond the initial mechanics, the point is quickly reached where lowerings for the various ops dominate the work.
However, I think that the gap between these “60%” implementations and an actual, high-fidelity program extraction is quite large, requiring investment in design concepts which need to be carefully modeled and are quite a bit more than “just ops”. As noted on the call, taking such things the rest of the way is not within the realm of investment that “small companies”/efforts can take on and expect to deliver robustly (i.e. when trying to create compilers for their hardware platform). Doing the high-fidelity work as part of the MLIR ecosystem seems to have value, even if it takes time to emerge completely, as it will unlock a lot of cases that are currently being cobbled together, and hopefully we can approach an outcome where the fidelity is high enough that a lot of workloads (both training and inference) just work for people.
Design Gaps
I’ve consistently observed while looking into this that a lot of PyTorch converters tend to paper over some of the following design points (a short eager-mode snippet illustrating a few of them follows the list):
- ATen tensors are mutable and have layout (i.e. they are much closer to a Numpy ndarray vs the tensor type of MLIR)
- Torch tensors carry gradients (i.e. stateful operations are executed to instruct them to materialize gradients on specific tensors, and the IR presumes this structure for tensors)
- Torch ops are responsible for gradients (i.e. gradient ops are not materialized in the IR but are the responsibility of the forward ops themselves, sometimes delegating to specific backward ops as needed)
- Torch values follow Python semantics and generally do not translate directly to existing primitives in the compiler framework (in full generality)
- A vast majority of the ops have variants that can alias or assign into arbitrary tensors (and conversely, most have forms that are completely immutable, but judgment needs to be applied when laying these out for a compiler)
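To make a couple of these concrete, here is a purely illustrative eager-mode snippet (nothing npcomp-specific) showing the stateful gradient request and the value/out=/in-place variants of a single op:

import torch

# Gradient tracking is requested statefully on specific tensors.
w = torch.ones(3, requires_grad=True)
loss = (w * 2.0).sum()
loss.backward()                  # gradients materialize as state on w
print(w.grad)                    # tensor([2., 2., 2.])

# Most ops come in a value-semantic form plus aliasing variants.
x = torch.zeros(3)
y = torch.add(x, 1.0)            # value semantics: returns a fresh tensor
torch.add(x, 1.0, out=y)         # assigns the result into an existing tensor
x.add_(1.0)                      # the "underscore" form mutates x in place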
In my opinion, it is the hallmark of the previous generation of solutions to ignore some or all of the above and still find some subset of really important cases (mostly inference-only or single basic-block programs). I think that settling for such a design would be a shame, and it leaves out a lot of things that I think are important for future ML workloads (i.e. no distinction between training and serving, algorithmic solutions, combinations of models that carry state, etc). I don’t think it is necessary to model all of the above in one step, and there is plenty of room to be pragmatic about prioritizing things that get us to runnable sooner, but we should consider designs that scale to letting us solve the above and achieve high-fidelity program extraction from this really important/loved platform.
Comparison to historic npcomp
I’ll note that many of the above (notably excepting gradients) apply to a faithful modeling of python/numpy in a compiler as well, and some of the path I had started down for solving that problem seems directly applicable here. In fact, starting with PyTorch instead of raw python seems to be quite a bit easier, because that frontend already does enough preliminary python compilation to let us just focus on the numeric parts, leaving the construction of a general purpose Python compiler to future work (i.e. type inference, AST extraction, and a lot of the general purpose wrangling is already nicely modeled in the Torch IR). For reference, much of this layer exists in the Basicpy dialect and associated machinery.
The Numpy dialect and ops are more relevant, as they started to bridge the gap from general use of mutable arrays to lower-level value-semantic tensors, and they handled other things such as type/shape-erased arrays, dtype inference, etc. Admittedly, much of what is here is primitive, but care was taken to model it in a way that generalizes.
Value-semantic core ops
Both numpy and PyTorch define most of their core operations in terms of value semantics, even though, during eager execution, everything runs with reference semantics. When modeling their ops in MLIR, it is possible to perform trivial expansions and wrappings that preserve these value semantics while also directly supporting the forms that operate with reference semantics.
For PyTorch examples, consider the very common patterns of value-semantic ops that take an out= parameter, or the in-place “underscore” ops (numpy has both of these forms as well, just with different syntax). These distinctions are modeled in the IR but are trivially decomposed into value-semantic core ops plus manipulator ops that re-introduce reference semantics.
Example:
a = torch.tensor(...)
b = torch.add(x, y, out=a)
print(a)
Could be:
%1 = numpy.create_ndarray(...)
%2 = numpy.copy_to_tensor %1
%3 = aten.add %x, %y, %one
numpy.store_tensor %1, %3
Similar decompositions exist for in-place ops. The advantage to keeping the core linear algebra op-library in the value domain is that we maintain a significant amount of guarantees about aliasing and let the compiler optimize the memory patterns further down-pipeline. More practically, this lines up with how most of the infra is built. It is trivial to elide the copies/stores for most forward-only programs and fairly well-known algorithms exist to further de-reference the less-trivial cases. In the worst case, the program does preserve some explicit copies and mutable reference types and we can form them into islands and still treat the insides with value semantics.
There are a number of ops that are inherently reference based and will remain so, requiring dedicated analysis downstream to do the right thing (i.e. as_strided, expand, view, etc). By working in the ndarray/reference domain at the top level, these are easily captured, and it remains the work of the compiler to legalize them as needed (and many trivial cases exist that don’t alias, where this poses no complexity/penalty in practice).
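For a concrete picture of the aliasing these ops introduce, here is an illustrative eager-mode snippet (plain PyTorch, nothing npcomp-specific):

import torch

a = torch.zeros(2, 3)
v = a.view(6)                   # v aliases the storage of a; no copy is made
v[0] = 7.0                      # a write through the view...
print(a[0, 0])                  # ...is visible through a: tensor(7.)

s = a.as_strided((3,), (1,))    # another aliasing window over the same storage
s[2] = 9.0
print(a[0, 2])                  # tensor(9.)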
I’ve implemented this scheme for importing Python/Numpy programs, and the importer becomes a really simple state machine which keeps an identity map between producers in the source IR and values on the MLIR side (in other words, these type manipulations are not burdensome to paper over at import time, and they allow us to start with a more semantics-laden IR once it hits MLIR).
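As a rough sketch of the shape that state machine takes (the builder object below is a hypothetical stand-in for the MLIR-side construction API, not the actual importer code):

def import_graph(ts_graph, builder):
    """Walk a torch.jit Graph, mapping each TorchScript Value to an MLIR value."""
    value_map = {}
    # Graph inputs become block arguments.
    for ts_input in ts_graph.inputs():
        value_map[ts_input] = builder.add_block_argument(ts_input.type())
    # Each node becomes one op, keyed by its kind (e.g. "aten::exp", "prim::Constant").
    for node in ts_graph.nodes():
        operands = [value_map[v] for v in node.inputs()]
        results = builder.create_op(node.kind(), operands, node)
        for ts_value, mlir_value in zip(node.outputs(), results):
            value_map[ts_value] = mlir_value
    # Graph outputs become the terminator's operands.
    builder.create_return([value_map[v] for v in ts_graph.outputs()])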
Gradients
Consider the program:
import torch

def fn(z):
    y = z.exp()
    y.backward(create_graph=True)
    return z.grad

z = torch.tensor(1., requires_grad=True)
fn_jit = torch.jit.script(fn)
print(fn_jit.graph)
Which produces:
graph(%z.1 : Tensor):
  %5 : None = prim::Constant()
  %4 : bool = prim::Constant[value=1]()
  %y.1 : Tensor = aten::exp(%z.1)
   = aten::backward(%y.1, %5, %5, %4)
  %9 : Tensor = prim::grad(%z.1)
  return (%9)
Implied in the very modeling is the knowledge that tensors capture graph structures that can be force-evaluated into gradients on demand. This is different from TensorFlow (for example), where the engine just directly emits the gradient ops that perform the evaluation into its IR (GraphDef). How to achieve this same end is a common question on the PyTorch forums; the only advice (so far) is that the internal symbolic autodiff facility can do this, but it is incomplete, unstable, etc.
There is no getting around the fact that any truly complete compiler for PyTorch will need to bring some of its own auto-differentiation support, hopefully getting some help from the framework where possible/necessary (i.e. the above API is the primary/only place to get expansions for custom/script-defined forward/backward pairs).
I suspect that we will be going through quite a few rounds of getting this part right. However, the one invariant is that for each node imported, if PyTorch can symbolically differentiate it, we have to capture that (i.e. isDifferentiable(Node *)) in our IR, or else we will be missing information. I propose starting with a simple mechanism such as:
%output, %grad = torch.symbolic_grad for {
  %1 = aten.native_batch_norm ...
  return %1
} with gradients {
  %2 = aten.native_batch_norm_backward ...
  return %2
}
This minimal form will be necessary for the construction of any real MaterializeGradients pass.
Comparison to the existing ATen dialect
The current ATen dialect, contributed by Xilinx, is (intended to be) isomorphic with ATen library calls. This is necessary but not sufficient for full program extraction. The more I look at this, the more I see this dialect as one layer in a stack of dialects:
- Basicpy: High-level Python types and operations.
- Numpy: NdArray type, facilities for managing dtype, transforms for dtype inference.
- Torch: Import dialect for full-fidelity Torch IR, with ops built in for the prim namespace, structural ops for managing ref/value conversions, wrapper ops for gradients, etc. Torch → Torch transforms exist to simplify this IR in a variety of ways (i.e. expand gradients, drop gradients, simplify across static python constructs, infer ranks, eliminate unnecessary references, etc).
- ATen: Isomorphic with ATen library calls and op names in the aten:: namespace of Torch IR. We maintain facilities for record and playback of ops in this dialect to the ATen library, and we maintain conversion patterns that can lower to LLVM IR that invokes the corresponding ATen libraries. When ATen library calls represent different data-flows based on arg presence/absence (i.e. store to another tensor, operate in-place, etc), these should be expanded to different ops with different signatures (it seems like these overloads can always be differentiated based on operand arity; see the snippet after this list).
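As an illustration of the arity point above, the registered overloads for a single op name already separate the value-semantic and out= forms (note: the underscored JIT introspection call is PyTorch-internal and may change between versions):

import torch

# Print every registered overload of aten::add.
for schema in torch._C._jit_get_schemas_for_operator("aten::add"):
    print(schema)

# Abridged typical output; note the extra trailing out operand on the out= overload:
#   aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
# The in-place form is registered under its own name (aten::add_).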
Next steps:
- I’d like to land my current work to generate the ATen builtin ops from the python side (i.e. the scripts that are there now are highly version-locked to an obsolete version of PyTorch; see the sketch after this list for one possible schema-driven approach). A few simplifying questions:
  - [Q]: Can we simplify our lives in the short term by dropping the StatisticsOpInterface that these all implement? I’d like to hear more from Xilinx about what utility this is providing at this level (I understand the utility in general – it just seems like we would actually want this more at a lower level).
  - [Q]: The existing modeling of the reference ops (i.e. as_strided and inplace) isn’t right. Was there a reason these were represented purely in the value domain?
- Start work on the Torch dialect and build out the boilerplate for an importer.
- Implement the importer for the aten:: op namespace in Torch IR. This should be largely systematic/derived from data structures created when we generated the ODS.
- Pick the simplest possible case (maybe an MLP) and work it through the Torch IR import path, implementing enough transformations to legalize it to TCF (which Sean is working up towards).
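For the first item, a hedged sketch of what a schema-driven generator could look like; emit_ods and its output format are placeholders rather than the actual npcomp scripts, and the underscored registry API is PyTorch-internal:

import torch

def emit_ods(schema):
    # Turn a registered schema into a crude ODS-like stub (placeholder format).
    op_name = schema.name.replace("aten::", "")
    args = ", ".join(arg.name for arg in schema.arguments)
    return f'def aten_{op_name} : ATen_Op<"{op_name}"> {{ /* args: {args} */ }}'

# Enumerate the live op registry instead of relying on a version-locked script.
for schema in torch._C._jit_get_all_schemas():
    if schema.name == "aten::conv2d":
        print(emit_ods(schema))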
I realize this is a lot of different directions of thought. I’m happy to talk about any of it more but I felt it important to at least checkpoint my interim design notes.