Torch dialects, gradients (and bears, oh my!)

Folks, I have been studying the routes available for interfacing MLIR/PyTorch, and some things are becoming clear. I am going to attempt to enunciate them and set some terminology. Then I am going to conclude with recommendations on the path forward.

First, referencing the meeting we had on Monday, most of us on the call had implemented some level of PyTorch program extraction in the past. It is my belief that a naïve converter, focused primarily on either the device level (such as the ATen/ACAP work) or the forward-only TorchIR level (such as what TVM has), is within reach of an appropriately motivated engineer, who can get to the point where a handful of models can be extracted and run in O(weeks). Beyond the initial mechanics, you quickly hit the point where lowerings for various ops dominate the work.

However, I think that the gap between these “60%” implementations and an actual, high-fidelity program extraction is quite large, requiring investment in design concepts which need to be carefully modeled and are quite a bit more than “just ops”. As noted on the call, taking such things the rest of the way is not within the realm of investment that “small companies”/efforts can take on and expect to deliver robustly (i.e. when trying to create compilers for their hardware platform). Doing the high-fidelity work as part of the MLIR ecosystem seems to have value, even if it takes time to emerge completely, as it will unlock a lot of cases that are currently being cobbled together, and hopefully, we can approach an outcome where the fidelity is high enough that a lot of workloads (both training and inference) just work for people.

Design Gaps

I’ve consistently observed while looking into this that a lot of PyTorch converters tend to paper over some of the following design points:

  • ATen tensors are mutable and have layout (i.e. they are much closer to a Numpy ndarray vs the tensor type of MLIR)
  • Torch tensors carry gradients (i.e. stateful operations are executed to instruct them to materialize gradients on specific tensors and the IR presumes this structure for tensors)
  • Torch ops are responsible for gradient (i.e. gradient ops are not materialized in the IR but are the responsibility of the forward ops themselves, sometimes delegating to specific backward ops as needed)
  • Torch values follow python semantics and generally do not translate directly to existing primitives in the compiler framework (in full generality)
  • A vast majority of the ops have variants that can alias or assign into arbitrary tensors (and conversely, most have forms that are completely immutable, but judgment needs to be applied when laying these out for a compiler)

In my opinion, it is the hallmark of the previous generation of solutions to ignore some or all of the above and still find some subset of really important cases (mostly inference-only or single basic-block programs). I think that settling for such a design would be a shame, and it leaves out a lot of things that I think are important for future ML workloads (i.e. no distinction between training and serving, algorithmic solutions, combinations of models that carry state, etc). I don’t think it is necessary to model all of the above in one step, and there is plenty of room to be pragmatic about prioritizing things that get us to runnable sooner, but we should consider designs that scale to letting us solve the above and achieve high-fidelity program extraction from this really important/loved platform.

Comparison to historic npcomp

I’ll note that many of the above (notably excepting gradients) apply to a faithful modeling of python/numpy in a compiler as well, and some of the path I had started down for solving that problem seems directly applicable here. In fact, starting with PyTorch instead of raw python seems considerably easier because that frontend already does enough preliminary python compilation to let us just focus on the numeric parts, leaving the construction of a general purpose Python compiler to future work (i.e. type inference, AST extraction, and a lot of the general purpose wrangling is already nicely modeled in the Torch IR). For reference, much of this layer exists in the Basicpy dialect and associated machinery.

The Numpy dialect and ops are more relevant as they started to bridge the gap from general use of mutable arrays to lower-level value-semantic tensors and handled other things such as type/shape erased arrays, dtype inference, etc. Admittedly, much of what is here is primitive, but care was taken to model it to generalize.

Value-semantic core ops

Both numpy and PyTorch define most of their core operations semantically in terms of value semantics, even though, during eager execution, everything is reference semantics. When modeling their ops in MLIR, it is possible to perform trivial expansions and wrappings to preserve these semantics while also directly supporting the forms that operate with reference semantics.

For PyTorch examples, consider the very common patterns of value-semantic ops that take an out= parameter or inplace “underscore” ops (numpy has both of these forms as well but with different syntax). These distinctions are modeled in the IR but are trivially decomposed to value semantic core ops and manipulator ops that re-introduce reference semantics.


a = tensor(...)
b = torch.add(x, y, out=a)

Could be:

%1 = numpy.create_ndarray(...)
%2 = numpy.copy_to_tensor %1
%3 = aten.add %x, %y, %one
numpy.store_tensor %1, %3

Similar decompositions exist for in-place ops. The advantage to keeping the core linear algebra op-library in the value domain is that we maintain a significant amount of guarantees about aliasing and let the compiler optimize the memory patterns further down-pipeline. More practically, this lines up with how most of the infra is built. It is trivial to elide the copies/stores for most forward-only programs and fairly well-known algorithms exist to further de-reference the less-trivial cases. In the worst case, the program does preserve some explicit copies and mutable reference types and we can form them into islands and still treat the insides with value semantics.
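To make the decomposition concrete, here is a minimal plain-Python sketch (not PyTorch or npcomp API). The names `NdArray`, `copy_to_tensor`, `store_tensor`, `add_out` and `add_` are stand-ins I am assuming for illustration, loosely mirroring the IR above: the core `add` is purely value-semantic, and both the out= and in-place forms reduce to that same op plus a store.

```python
from dataclasses import dataclass
from typing import Tuple

# Immutable, value-semantic stand-in for an MLIR tensor (assumption for the sketch).
Tensor = Tuple[float, ...]


@dataclass
class NdArray:
    """Mutable, reference-semantic box holding the current tensor value."""
    value: Tensor


def copy_to_tensor(arr: NdArray) -> Tensor:
    # Snapshot the array's current contents as an immutable value.
    return arr.value


def store_tensor(arr: NdArray, t: Tensor) -> None:
    # Overwrite the array's contents; every holder of `arr` observes this.
    arr.value = t


def add(x: Tensor, y: Tensor, alpha: float = 1.0) -> Tensor:
    # Value-semantic core op: never mutates its operands.
    return tuple(xi + alpha * yi for xi, yi in zip(x, y))


def add_out(x: NdArray, y: NdArray, out: NdArray) -> NdArray:
    # out= form: value-semantic add, then a store into `out`.
    store_tensor(out, add(copy_to_tensor(x), copy_to_tensor(y)))
    return out


def add_(target: NdArray, other: NdArray) -> NdArray:
    # In-place "underscore" form: same core op, stored back into the receiver.
    store_tensor(target, add(copy_to_tensor(target), copy_to_tensor(other)))
    return target


x = NdArray((1.0, 2.0))
y = NdArray((10.0, 20.0))
a = NdArray((0.0, 0.0))
b = add_out(x, y, out=a)   # b is the same box as a
add_(x, y)                 # x is updated, y is untouched
```

In this toy form, eliding the copies and stores for a forward-only program amounts to noticing that the box is written once and read only afterwards.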

There are a number of ops that are inherently reference based and will remain so, requiring dedicated analysis downstream to do the right thing (i.e. as_strided, expand, view, etc). By working in the ndarray/reference domain at the top-level, these are easily captured and it remains the work of the compiler to legalize them as needed (and many trivial cases exist that don’t alias where this poses no complexity/penalty in practice).
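As a rough illustration of why these ops are inherently reference based, the following plain-Python toy (a hypothetical `StridedView` class, not the ATen implementation) models a 2-D window onto shared flat storage. A transposed view and an expand-like view with a zero stride both alias the same storage, so a write through one is visible through the others.

```python
from typing import List, Sequence


class StridedView:
    """A 2-D window onto shared flat storage (hypothetical toy type)."""

    def __init__(self, storage: List[float], shape: Sequence[int],
                 strides: Sequence[int], offset: int = 0) -> None:
        self.storage = storage          # shared with other views, never copied
        self.shape = tuple(shape)
        self.strides = tuple(strides)
        self.offset = offset

    def _index(self, i: int, j: int) -> int:
        # Map a logical (row, column) position to a flat storage index.
        return self.offset + i * self.strides[0] + j * self.strides[1]

    def get(self, i: int, j: int) -> float:
        return self.storage[self._index(i, j)]

    def set(self, i: int, j: int, v: float) -> None:
        self.storage[self._index(i, j)] = v

    def to_rows(self) -> List[List[float]]:
        return [[self.get(i, j) for j in range(self.shape[1])]
                for i in range(self.shape[0])]


storage = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
base = StridedView(storage, shape=(2, 3), strides=(3, 1))
transposed = StridedView(storage, shape=(3, 2), strides=(1, 3))  # transpose-like
expanded = StridedView(storage, shape=(2, 3), strides=(0, 1))    # row 0 repeated

# Writes storage[2]; the change shows up through `base` and `expanded` too.
transposed.set(2, 0, 20.0)
```

The trivial cases mentioned above correspond to programs where no such write ever happens through an aliasing view, in which case each view can be replaced by a value-semantic gather.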

I’ve implemented this scheme for importing Python/Numpy programs, and the importer becomes a really simple state machine which keeps some identity map between producers in the source IR and values in the MLIR side (in other words, these type manipulations are not burdensome to paper over at import time and allow us to start with a more semantic-laden IR once it hits MLIR).
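For a sense of how small that state machine can be, here is a hedged sketch in plain Python. The `Importer` class, the tuple-based source node format and the emitted op spellings are all assumptions made for illustration; this is not the actual npcomp importer.

```python
from typing import Dict, List, Tuple

# (result name, op name, operand names) in a simplified source IR.
SourceNode = Tuple[str, str, List[str]]


class Importer:
    """Walks source nodes in order, mapping source values to emitted SSA names."""

    def __init__(self) -> None:
        self.value_map: Dict[str, str] = {}   # source value -> MLIR value
        self.lines: List[str] = []            # emitted (textual) ops
        self._num_args = 0
        self._next_id = 0

    def import_arg(self, source_name: str) -> None:
        # Function arguments get their own naming scheme.
        self.value_map[source_name] = f"%arg{self._num_args}"
        self._num_args += 1

    def import_node(self, node: SourceNode) -> None:
        result, op, operands = node
        # Producers are visited before users, so every operand is already mapped.
        mapped = [self.value_map[o] for o in operands]
        mlir_result = f"%{self._next_id}"
        self._next_id += 1
        self.value_map[result] = mlir_result
        self.lines.append(f"{mlir_result} = {op} {', '.join(mapped)}")


graph: List[SourceNode] = [
    ("y.1", "aten.exp", ["z.1"]),
    ("w.1", "aten.mul", ["y.1", "z.1"]),
]

importer = Importer()
importer.import_arg("z.1")
for node in graph:
    importer.import_node(node)
```

Any ref/value wrapping would be emitted inside `import_node` at the point where the operand is looked up, which is why it stays cheap to do at import time.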


Gradients

Consider the program:

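import torch

def fn(z):
  y = z.exp()
  y.backward()  # populates z.grad; this is the aten::backward call in the graph
  return z.grad

z = torch.tensor(1., requires_grad=True)
fn_jit = torch.jit.script(fn)
print(fn_jit.graph)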

Which produces:

graph(%z.1 : Tensor):
  %5 : None = prim::Constant()
  %4 : bool = prim::Constant[value=1]()
  %y.1 : Tensor = aten::exp(%z.1)
   = aten::backward(%y.1, %5, %5, %4)
  %9 : Tensor = prim::grad(%z.1)
  return (%9)

Implied in the very modeling is the knowledge that tensors capture graph structures that can be forced-evaluated into gradients on demand. This is different from TensorFlow (for example), where the engine just directly emits the gradient ops that perform the evaluation into its IR (GraphDef). It is a common question on the PyTorch forums regarding how to achieve this same end. The only advice (so far) is that the internal symbolic autodiff facility can do this, but it is incomplete, unstable, etc.
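The "tensors capture graph structure" point can be shown with a tiny scalar-only toy in plain Python. This is a sketch under my own assumptions (a hypothetical `Tensor` class with a naive recursive `backward`), not how PyTorch autograd is implemented, but it mirrors the program above: the forward op records how to propagate a gradient, and `backward()` is a stateful call that fills in `.grad` on the leaf.

```python
import math
from typing import Callable, List, Optional, Tuple


class Tensor:
    """Scalar stand-in for a tensor that records how it was produced."""

    def __init__(
        self,
        value: float,
        requires_grad: bool = False,
        parents: Optional[List[Tuple["Tensor", Callable[[float], float]]]] = None,
    ) -> None:
        self.value = value
        self.requires_grad = requires_grad
        self.grad: Optional[float] = None
        # Each entry: (input tensor, function mapping output grad -> input grad).
        self.parents = parents or []

    def exp(self) -> "Tensor":
        out = math.exp(self.value)
        # The forward op itself is responsible for describing its gradient.
        return Tensor(out, self.requires_grad, [(self, lambda g: g * out)])

    def backward(self, grad: float = 1.0) -> None:
        # Stateful side effect: accumulate into .grad on leaf tensors.
        if not self.parents:
            if self.requires_grad:
                self.grad = (self.grad or 0.0) + grad
            return
        for parent, vjp in self.parents:
            parent.backward(vjp(grad))


def fn(z: Tensor) -> Optional[float]:
    y = z.exp()
    y.backward()
    return z.grad


z = Tensor(1.0, requires_grad=True)
result = fn(z)   # d/dz exp(z) at z = 1, i.e. e
```

Nothing resembling a gradient op exists until `backward()` runs, which is the structural difference from the GraphDef approach described above.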

There is no getting around that any truly complete compiler for PyTorch will need to bring some of its own auto differentiation support, hopefully, getting some help from the framework where possible/necessary (i.e. the above API is the primary/only place to get expansions for custom/script-defined forward/backward pairs).

I suspect that we will be going through quite a few rounds of getting this part right. However, the one invariant is that for each node imported, if PyTorch can symbolically differentiate it, we have to capture that (i.e. isDifferentiable(Node *)) in our IR, or else we will be missing information. I propose starting with a simple mechanism such as:

%output, %grad = torch.symbolic_grad for {
  %1 = aten.native_batch_norm ...
  return %1
} with gradients {
  %2 = aten.native_batch_norm_backward ...
  return %2
}

This minimal form will be necessary for the construction of any real MaterializeGradients pass.
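To show how such forward/backward pairings could feed a later pass, here is a deliberately simplified plain-Python sketch. Everything in it is hypothetical: the `BACKWARD_OF` table stands in for the paired regions, the `toy.*_backward` op names are invented, and only straight-line programs of unary ops are handled. It walks the forward ops in reverse and appends one backward op per forward op.

```python
from typing import Dict, List, Tuple

# (result, op, operands) for a straight-line forward program.
Op = Tuple[str, str, List[str]]

# Hypothetical pairing table: forward op -> backward op name.
BACKWARD_OF: Dict[str, str] = {
    "aten.exp": "toy.exp_backward",
    "aten.tanh": "toy.tanh_backward",
}


def is_differentiable(op: Op) -> bool:
    # Stand-in for the information captured at import time.
    return op[1] in BACKWARD_OF


def materialize_gradients(forward: List[Op], seed: str) -> List[Op]:
    """Append one backward op per (unary) forward op, in reverse order."""
    grad_of = {forward[-1][0]: seed}   # gradient flowing into the final result
    backward: List[Op] = []
    for op in reversed(forward):
        result, name, operands = op
        if not is_differentiable(op):
            raise ValueError(f"no backward pairing recorded for {name}")
        grad_in = f"%d_{operands[0].lstrip('%')}"
        # The backward op sees the incoming gradient plus forward input/output.
        backward.append((grad_in, BACKWARD_OF[name],
                         [grad_of[result], operands[0], result]))
        grad_of[operands[0]] = grad_in
    return forward + backward


forward_ops: List[Op] = [
    ("%1", "aten.exp", ["%x"]),
    ("%2", "aten.tanh", ["%1"]),
]
program = materialize_gradients(forward_ops, seed="%seed")
```

The `ValueError` branch is the failure mode the invariant above is meant to prevent: if the pairing was not captured at import, the pass has nothing to expand.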

Comparison to the existing ATen dialect

The current ATen dialect, contributed by Xilinx, is (intended to be) isomorphic with ATen library calls. This is necessary but not sufficient for full program extraction. The more I look at this, the more I see this dialect as one layer in a stack of dialects:

  • Basicpy: High-level Python types and operations.
  • Numpy: NdArray type, facilities for managing dtype, transforms for dtype inference.
  • Torch: Import dialect for full fidelity Torch IR with ops built-in for the prim namespace, structural ops for managing ref/value conversions, wrapper ops for gradients, etc. Torch->Torch transforms exist to simplify this IR in a variety of ways (i.e. expand gradients, drop gradients, simplify across static python constructs, infer ranks, eliminate unnecessary references, etc).
  • ATen: Isomorphic with ATen library calls and op names in the aten:: namespace of TorchIR. We maintain facilities for record and playback of ops in this dialect to the ATen library, and we maintain conversion patterns that can lower to LLVM IR that invokes the corresponding ATen libraries. When ATen library calls represent different data-flows based on arg presence/absence (i.e. store to another tensor, operate in-place, etc), these should be expanded to different ops with different signatures (it seems like these overloads can always be differentiated based on operand arity).

Next steps:

  • I’d like to land my current work to generate the ATen builtin ops from the python side (i.e. the scripts that are there now are highly version locked to an obsolete version of PyTorch). A few simplifying questions:

    • [Q]: Can we simplify our lives in the short term by dropping the StatisticsOpInterface that these all implement? I’d like to hear more from Xilinx about what utility this is providing at this level (I understand the utility in general – it just seems like we would actually want this more at a lower level).
    • [Q]: The existing modeling of the reference ops (i.e. as_strided and inplace) isn’t right. Was there a reason these were represented purely in the value domain?
  • Start work on the Torch dialect and build out the boiler-plate for an importer.

  • Implement the importer for the aten:: op namespace in Torch IR. This should be largely systematic/derived from data structures created when we generated the ODS.

  • Pick the simplest possible case (maybe a MLP) and work it through the Torch IR import path, implementing enough transformations to legalize it to TCF (that Sean is working up towards).

I realize this is a lot of different directions of thought. I’m happy to talk about any of it more but I felt it important to at least checkpoint my interim design notes.


Thanks Stella for this great write-up! Your suggested representations seem reasonable, and +1 on getting the semantic modeling right, even if our analyses are not sophisticated enough for handling more complicated cases yet (it’s always easier to write a pass/analysis after “look at this repro; it fails to compile because we aren’t smart enough to deduce X; we need to make pass Y smarter”).

The next steps seem reasonable to me. Looking forward to connecting TCF to this!

Thanks for this, Stella… I’d really like to figure out a reasonable path forward here.

Handling of the backward path with gradients is one of the big advantages of the ‘device-based’ approach, rather than the ‘direct compilation’ approach. Using the device-based approach, the backward path is realized by PyTorch. I think that other aspects, such as shape propagation, have also been taken care of by PyTorch, so we always have ranked and sized tensors. Obviously, there are advantages and disadvantages to this: there is no question of having to implement autograd ourselves, for instance. On the other hand, we can’t compile a program ahead of time without concrete arguments, since shapes have to be provided. I suspect that there is not a ‘one right answer’ here, but different approaches with different tradeoffs.

The second major question seems to be around value-based vs reference-based data structures. This arises in at least 2 cases: 1) in-place operations with underscores 2) view/as_strided operations. My thinking has been that in-place operations would be represented, but possibly less optimizable than value-based operations, but I suspect that the ownership model for how this is handled today has not been rigorously thought out. Our focus has been on models with value-based operations.

Agreed - the more I look at this, the more I think that the ‘device-based’ approach you all have taken is a really useful part of the equation to get right. That gives us the ATen-level op-by-op playback and record capabilities that we can both use directly in some/a lot of cases and we can build on that to have ATen fallback of AOT programs (which is likely needed for full generality, even if to be avoided).

Thanks for bearing with me – a lot of what I’ve done/written so far on this falls into the category of really grokking the total scope, and I think I’m nearly at the point to start coding for real. Getting what you all contributed expanded and a bit more future proof is likely going to pay dividends. I’m personally motivated by the higher level compiler integration, but I think we need to invest more in what you all brought before going too far in that direction (i.e. everything bottoms out on the device-based setup in full generality).

About the gradient: I think taking the gradient of the original graph is suboptimal.
In the past, this is what was done in Theano, but it made the optimization passes more complicated than needed in some cases. What I think would be better is:

  • Add gradient nodes to the graph where a gradient is needed, without generating the full gradient graph.
  • Optimize the forward graph somewhat. Mostly graph clean-up passes like constant folding, CSE, canonicalization, numerical stability optimization (if that is done)…
  • Then the compiler generates the gradient graph from the cleaner graph.
  • “Restart” the compilation pipeline to clean up the new backward graph and finish the compilation.

This will generate a simpler gradient graph that will be easier to optimize later.
So implementing the gradient generation at the MLIR level instead of the PyTorch level has some advantages. It probably also has drawbacks, like needing a full implementation rather than reusing existing code.
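A toy sketch of that ordering, in plain Python: the op tuples, the `grad` placeholder and the opaque `*_backward` ops are all assumptions for illustration, and no derivatives are actually computed. It only shows that expanding the placeholder after CSE produces a smaller graph than expanding it on the original one.

```python
from typing import Dict, List, Tuple

Op = Tuple[str, str, Tuple[str, ...]]  # (result, op name, operands)


def cse(ops: List[Op]) -> List[Op]:
    """Common-subexpression elimination over a straight-line, pure program."""
    seen: Dict[Tuple[str, Tuple[str, ...]], str] = {}
    rename: Dict[str, str] = {}
    out: List[Op] = []
    for result, name, operands in ops:
        operands = tuple(rename.get(o, o) for o in operands)
        key = (name, operands)
        if key in seen:
            rename[result] = seen[key]   # reuse the earlier identical op
        else:
            seen[key] = result
            out.append((result, name, operands))
    return out


def expand_gradients(ops: List[Op]) -> List[Op]:
    """Replace the `grad` placeholder with one opaque backward op per forward op."""
    forward = [op for op in ops if op[1] != "grad"]
    backward = [(f"d_{r}", f"{n}_backward", (r,)) for r, n, _ in reversed(forward)]
    return forward + backward


program: List[Op] = [
    ("a", "exp", ("x",)),
    ("b", "exp", ("x",)),        # duplicate of `a`
    ("c", "mul", ("a", "b")),
    ("g", "grad", ("c", "x")),   # placeholder: gradient of c w.r.t. x is needed
]

naive = expand_gradients(program)        # gradient taken on the original graph
staged = expand_gradients(cse(program))  # clean up first, then expand
final = cse(staged)                      # "restart" the clean-up on the result
```

Here the naive path carries six ops into later passes while the staged path carries four, because the duplicate `exp` never gets a backward op of its own.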

Thanks, this has been my instinct as well, but I’ve been trying to figure out whether there is a way that doesn’t involve reimplementing. Knowing what Theano found is helpful to avoid the same pitfalls.

This would be similar to how we lazily materialize shapes based on something like shape.get_shape inserted as needed. The trick is doing the materialization at a point when enough information exists/the computation is in a form that can still have that done. I don’t have that intuition for gradient generation.

Thanks for the feedback!