Torch dialects, gradients (and bears, oh my!)

Folks, I have been studying the routes available for interfacing MLIR/PyTorch, and some things are becoming clear. I am going to attempt to enunciate them and set some terminology. Then I am going to conclude with recommendations on the path forward.

First, referencing the meeting we had on Monday, most of us on the call had implemented some level of PyTorch program extraction in the past. It is my belief that a naïve converter, focused primarily on either the device level (such as the ATen/ACAP work) or the forward-only TorchIR level (such as what TVM has), is something an appropriately motivated engineer could get to the point of extracting and running a handful of models in O(weeks). Beyond the initial mechanics, you quickly hit the point where lowerings for the various ops dominate the work.

However, I think that the gap between these “60%” implementations and actual, high-fidelity program extraction is quite large, requiring investment in design concepts that need to be carefully modeled and are quite a bit more than “just ops”. As noted on the call, taking such things the rest of the way is not within the realm of investment that “small companies”/efforts can take on and expect to deliver robustly (i.e. when trying to create compilers for their hardware platform). Doing the high-fidelity work as part of the MLIR ecosystem seems to have value, even if it takes time to emerge completely: it will unlock a lot of cases that are currently being cobbled together, and hopefully we can approach an outcome where the fidelity is high enough that a lot of workloads (both training and inference) just work for people.

Design Gaps

I’ve consistently observed while looking into this that a lot of PyTorch converters tend to paper over some of the following design points:

  • ATen tensors are mutable and have layout (i.e. they are much closer to a Numpy ndarray vs the tensor type of MLIR)
  • Torch tensors carry gradients (i.e. stateful operations are executed to instruct them to materialize gradients on specific tensors, and the IR presumes this structure for tensors; see the short eager-mode example after this list)
  • Torch ops are responsible for gradient (i.e. gradient ops are not materialized in the IR but are the responsibility of the forward ops themselves, sometimes delegating to specific backward ops as needed)
  • Torch values follow python semantics and generally do not translate directly to existing primitives in the compiler framework (in full generality)
  • A vast majority of the ops have variants that can alias or assign into arbitrary tensors (and conversely, most have forms that are completely immutable, but judgment needs to be applied when laying these out for a compiler)
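To make the gradient-related bullets concrete, here is roughly what they look like in eager PyTorch (a minimal illustration only, not tied to any particular importer):

import torch

# requires_grad is a stateful instruction to track this tensor; backward()
# materializes .grad on specific tensors, and no gradient ops ever appear
# in the captured IR -- the ops themselves own their backward behavior.
t = torch.ones(3, requires_grad=True)
(t * t).sum().backward()
print(t.grad)    # d/dt sum(t*t) = 2t -> tensor([2., 2., 2.])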

In my opinion, it is the hallmark of the previous generation of solutions to ignore some or all of the above and still find some subset of really important cases (mostly inference-only or single basic-block programs). I think that settling for such a design would be a shame, and it leaves out a lot of things that I think are important for future ML workloads (i.e. no distinction between training and serving, algorithmic solutions, combinations of models that carry state, etc). I don’t think it is necessary to model all of the above in one step, and there is plenty of room to be pragmatic about prioritizing things that get us to runnable sooner, but we should consider designs that scale to letting us solve the above and achieve high-fidelity program extraction from this really important/loved platform.

Comparison to historic npcomp

I’ll note that many of the above (notably excepting gradients) apply to a faithful modeling of python/numpy in a compiler as well, and some of the path I had started down for solving that problem seems directly applicable here. In fact, starting with PyTorch instead of raw python seems to be quite a bit easier, because that frontend already does enough preliminary python compilation to let us just focus on the numeric parts, leaving the construction of a general purpose Python compiler to future work (i.e. type inference, AST extraction, and a lot of the general purpose wrangling is already nicely modeled in the Torch IR). For reference, much of this layer exists in the Basicpy dialect and associated machinery.

The Numpy dialect and ops are more relevant as they started to bridge the gap from general use of mutable arrays to lower-level value-semantic tensors and handled other things such as type/shape erased arrays, dtype inference, etc. Admittedly, much of what is here is primitive, but care was taken to model it to generalize.

Value-semantic core ops

Both numpy and PyTorch define most of their core operations in terms of value semantics, even though, during eager execution, everything has reference semantics. When modeling their ops in MLIR, it is possible to perform trivial expansions and wrappings to preserve these semantics while also directly supporting the forms that operate with reference semantics.

For PyTorch examples, consider the very common patterns of value-semantic ops that take an out= parameter or inplace “underscore” ops (numpy has both of these forms as well but with different syntax). These distinctions are modeled in the IR but are trivially decomposed to value semantic core ops and manipulator ops that re-introduce reference semantics.

Example:

a = tensor(...)
b = torch.add(x, y, out=a)
print(a)

Could be:

%1 = numpy.create_ndarray(...)
%2 = numpy.copy_to_tensor %1
%3 = aten.add %x, %y, %one
numpy.store_tensor %1, %3

Similar decompositions exist for in-place ops. The advantage to keeping the core linear algebra op-library in the value domain is that we maintain a significant amount of guarantees about aliasing and let the compiler optimize the memory patterns further down-pipeline. More practically, this lines up with how most of the infra is built. It is trivial to elide the copies/stores for most forward-only programs and fairly well-known algorithms exist to further de-reference the less-trivial cases. In the worst case, the program does preserve some explicit copies and mutable reference types and we can form them into islands and still treat the insides with value semantics.
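In eager Python terms, the in-place decomposition is analogous to the out= rewrite above (a sketch only; the MLIR form would go through the same numpy/aten ops shown earlier):

import torch

x1 = torch.ones(2); y = torch.ones(2)
x1.add_(y)              # in-place "underscore" form: mutates x1

x2 = torch.ones(2)
tmp = torch.add(x2, y)  # value-semantic core op: produces a fresh tensor
x2.copy_(tmp)           # explicit store re-introduces the reference semantics

assert torch.equal(x1, x2)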

There are a number of ops that are inherently reference based and will remain so, requiring dedicated analysis downstream to do the right thing (i.e. as_strided, expand, view, etc). By working in the ndarray/reference domain at the top-level, these are easily captured and it remains the work of the compiler to legalize them as needed (and many trivial cases exist that don’t alias where this poses no complexity/penalty in practice).
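For concreteness, this is the kind of aliasing those ops introduce in eager mode (a small PyTorch example of a view sharing storage with its base):

import torch

base = torch.zeros(4)
alias = base.view(2, 2)   # view: shares storage with base, no copy
alias[0, 0] = 1.0         # writing through the alias...
print(base)               # ...is visible in the original: tensor([1., 0., 0., 0.])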

I’ve implemented this scheme for importing Python/Numpy programs, and the importer becomes a really simple state machine which keeps an identity map between producers in the source IR and values on the MLIR side (in other words, these type manipulations are not burdensome to paper over at import time and allow us to start with a more semantic-laden IR once it hits MLIR).
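For illustration, the shape of such an importer loop looks roughly like the following (hypothetical names throughout; this is not the actual npcomp importer API):

def import_graph(src_graph, builder):
  # Identity map from producers (values) in the source IR to MLIR values.
  value_map = {}
  for node in src_graph.nodes:
    operands = [value_map[v] for v in node.inputs]
    results = builder.create_op(node.kind, operands)
    for src_value, mlir_value in zip(node.outputs, results):
      value_map[src_value] = mlir_value
  return value_map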

Gradients

Consider the program:

import torch

def fn(z):
  y = z.exp()
  y.backward(create_graph=True)
  return z.grad


z = torch.tensor(1., requires_grad=True)
fn_jit = torch.jit.script(fn)
print(fn_jit.graph)

Which produces:

graph(%z.1 : Tensor):
  %5 : None = prim::Constant()
  %4 : bool = prim::Constant[value=1]()
  %y.1 : Tensor = aten::exp(%z.1)
   = aten::backward(%y.1, %5, %5, %4)
  %9 : Tensor = prim::grad(%z.1)
  return (%9)

Implied in the very modeling is the knowledge that tensors capture graph structures that can be force-evaluated into gradients on demand. This is different from TensorFlow (for example), where the engine just directly emits the gradient ops that perform the evaluation into its IR (GraphDef). How to achieve the same end comes up as a common question on the PyTorch forums. The only advice (so far) is that the internal symbolic autodiff facility can do this, but it is incomplete, unstable, etc.

There is no getting around the fact that any truly complete compiler for PyTorch will need to bring some of its own auto-differentiation support, hopefully getting some help from the framework where possible/necessary (i.e. the above API is the primary/only place to get expansions for custom/script-defined forward/backward pairs).
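For context, on the PyTorch side a “forward/backward pair” is what a user-defined torch.autograd.Function expresses; any IR we import needs to capture the equivalent information. A small reference example (standard PyTorch API, shown only to pin down the terminology):

import torch

class MyExp(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x):
    y = torch.exp(x)
    ctx.save_for_backward(y)
    return y

  @staticmethod
  def backward(ctx, grad_output):
    (y,) = ctx.saved_tensors
    return grad_output * y

x = torch.tensor(1., requires_grad=True)
MyExp.apply(x).backward()
print(x.grad)   # exp(1.)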

I suspect that we will be going through quite a few rounds of getting this part right. However, the one invariant is that for each node imported, if PyTorch can symbolically differentiate it, we have to capture that (i.e. isDifferentiable(Node *)) in our IR, or else we will be missing information. I propose starting with a simple mechanism such as:

%output, %grad = torch.symbolic_grad for {
  %1 = aten.native_batch_norm ...
  return %1
} with gradients {
  %2 = aten.native_batch_norm_backward ...
  return %2
}

This minimal form will be necessary for the construction of any real MaterializeGradients pass.

Comparison to the existing ATen dialect

The current ATen dialect, contributed by Xilinx, is (intended to be) isomorphic with ATen library calls. This is necessary but not sufficient for full program extraction. The more I look at this, the more I see this dialect in a stack of dialects:

  • Basicpy: High-level Python types and operations.
  • Numpy: NdArray type, facilities for managing dtype, transforms for dtype inference.
  • Torch: Import dialect for full fidelity Torch IR with ops built-in for the prim namespace, structural ops for managing ref/value conversions, wrapper ops for gradients, etc. Torch transforms exist to simplify this IR in a variety of ways (i.e. expand gradients, drop gradients, simplify across static python constructs, infer ranks, eliminate unnecessary references, etc).
  • ATen: Isomorphic with ATen library calls and op names in the aten:: namespace of TorchIR. We maintain facilities for record and playback of ops in this dialect to the ATen library, and we maintain conversion patterns that can lower to LLVM IR that invokes the corresponding ATen libraries. When ATen library calls represent different data-flows based on arg presence/absence (i.e. store to another tensor, operate in-place, etc), these should be expanded to different ops with different signatures (it seems like these overloads can always be differentiated based on operand arity).

Next steps:

  • I’d like to land my current work to generate the ATen builtin ops from the python side (i.e. the scripts that are there now are highly version-locked to an obsolete version of PyTorch). A few simplifying questions:

    • [Q]: Can we simplify our lives in the short term by dropping the StatisticsOpInterface that these all implement? I’d like to hear more from Xilinx about what utility this is providing at this level (I understand the utility in general – it just seems like we would actually want this more at a lower level).
    • [Q]: The existing modeling of the reference ops (i.e. as_strided and inplace) isn’t right. Was there a reason these were represented purely in the value domain?
  • Start work on the Torch dialect and build out the boiler-plate for an importer.

  • Implement the importer for the aten:: op namespace in Torch IR. This should be largely systematic/derived from data structures created when we generated the ODS.

  • Pick the simplest possible case (maybe an MLP) and work it through the Torch IR import path, implementing enough transformations to legalize it to TCF (which Sean is working up towards).

I realize this is a lot of different directions of thought. I’m happy to talk about any of it more but I felt it important to at least checkpoint my interim design notes.

2 Likes

Thanks Stella for this great write-up! Your suggested representations seem reasonable, and +1 on getting the semantic modeling right, even if our analyses are not sophisticated enough for handling more complicated cases yet (it’s always easier to write a pass/analysis after “look at this repro; it fails to compile because we aren’t smart enough to deduce X; we need to make pass Y smarter”).

The next steps seem reasonable to me. Looking forward to connecting TCF to this!

1 Like

Thanks for this, Stella… I’d really like to figure out a reasonable path forward here.

Handling of the backward path with gradients is one of the big advantages of the ‘device-based’ approach, rather than the ‘direct compilation’ approach. Using the device-based approach, the backward path is realized by PyTorch. I think that other aspects, such as shape propagation, have also been taken care of by PyTorch, so we always have ranked and sized tensors. Obviously, there are advantages and disadvantages to this: there is no question of having to implement autograd ourselves, for instance. On the other hand, we can’t compile a program ahead of time without concrete arguments, since shapes have to be provided. I suspect that there is not a ‘one right answer’ here, but different approaches with different tradeoffs.

The second major question seems to be around value-based vs reference-based data structures. This arises in at least two cases: 1) in-place operations with underscores, and 2) view/as_strided operations. My thinking has been that in-place operations would be represented, but possibly be less optimizable than value-based operations; I suspect that the ownership model for how this is handled today has not been rigorously thought out. Our focus has been on models with value-based operations.

Agreed - the more I look at this, the more I think that the ‘device-based’ approach you all have taken is a really useful part of the equation to get right. That gives us the ATen-level op-by-op playback and record capabilities that we can both use directly in some/a lot of cases and we can build on that to have ATen fallback of AOT programs (which is likely needed for full generality, even if to be avoided).

Thanks for bearing with me – a lot of what I’ve done/written so far on this falls into the category of really grokking the total scope, and I think I’m nearly at the point of starting to code for real. Getting what you all contributed expanded and a bit more future-proof is likely going to pay dividends. I’m personally motivated by the higher level compiler integration, but I think we need to invest more in what you all brought before going too far in that direction (i.e. everything bottoms out on the device-based setup in full generality).

1 Like

About gradients: I think taking the gradient of the original graph is suboptimal.
In the past, this is what Theano did, but it made the optimization passes more complicated than needed in some cases. What I think would be better is:

  • Add gradient nodes to the graph where gradients are needed, without generating the full gradient graph.
  • Optimize the forward graph somewhat: mostly graph clean-up passes like constant folding, CSE, canonicalization, numerical-stability optimization (if that is done)…
  • Then have the compiler generate the gradient graph from the cleaned-up graph.
  • “Restart” the compilation pipeline to clean up the new backward graph and finish the compilation.

This generates a simpler gradient graph that is easier to optimize later.
So implementing gradient generation at the MLIR level instead of the PyTorch level has some advantages. It probably also has drawbacks, like needing a full implementation vs. reusing existing code.

1 Like

Thanks, this has been my instinct as well, but I’ve been trying to figure out whether there is a way that doesn’t involve reimplementing. Knowing what Theano found is helpful to avoid the same pitfalls.

This would be similar to how we lazily materialize shapes based on something like shape.get_shape inserted as needed. The trick is doing the materialization at a point where enough information exists and the computation is still in a form that allows it. I don’t have that intuition for gradient generation yet.

Thanks for the feedback!

Reviving this older thread: I am curious what the current philosophy on supporting gradients in torch-mlir is. Has the focus shifted back to inference-only compilation, or are there best practices for computing tensor gradients during the forward pass with an MLIR pipeline?

We should sync with the pytorch side. I’ve been watching patches and they are getting close to having this and the optimizers supported in stock fx. We’ve got some workarounds we do and I’ve been mostly waiting for torch to do it right.

With LLM fine tuning workloads, they finally have a product driver for it on the Meta side, so I think it’ll happen.

1 Like

cc @ftynse and thanks @mehdi_amini for the bump to the thread.

So the Enzyme-MLIR project for generic automatic differentiation (AD) of MLIR has been progressing quite a lot recently, bringing all of the LLVM AD from Enzyme (and more) to MLIR.

We’ve gotten generic forward and reverse mode AD implemented as opt-in interfaces for MLIR that successfully produce derivatives for arith/math/scf/affine/stablehlo/whatever dialect you want to opt into (we haven’t added everything for all of these dialects yet, but at least the core ones we’ve hit – ofc contributions welcome).

A couple of examples here from tests (which we do need to write more of):

Most ops can have their derivatives defined in tablegen (see here for an example Enzyme/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td at main · EnzymeAD/Enzyme · GitHub or here for HLO Enzyme-JAX/src/enzyme_ad/jax/Implementations/HLODerivatives.td at main · EnzymeAD/Enzyme-JAX · GitHub), with as much shared code as possible (e.g. auto-defining both forward and reverse derivatives when legal), but with sufficient flexibility for your favorite custom op to have its semantics added (e.g. Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp at c3cd279e4fab29d742b7ce7b7187b9e9c9668f5b · EnzymeAD/Enzyme-JAX · GitHub ).

You’re welcome to use (and especially help contribute to) the effort if that’s helpful and we have some good early results for use in ML frameworks.

We’re planning on opening an RFC for Enzyme to become an llvm subproject in the relatively near future (hopefully moving both LLVM and MLIR AD, which do share infra), maybe early summerish?

Before then, there’s a couple of things we’re iterating on designs for before upstreaming, just to make development velocity easier (simplifying interprocedural reverse-mode MLIR, upstreaming our augmentations to LLVM scalar evolution: change contents of ScalarEvolution from private to protected by skewballfox · Pull Request #83052 · llvm/llvm-project · GitHub, etc).

6 Likes

Thanks for all the replies! It looks like all the technology is there, but batteries are not included yet :wink:

Can somebody educate me a bit more on how we eventually want to make this available to torch-mlir users as the “high-fidelity program extraction” plug-and-play support Stella is alluding to above?

Suppose we have the following net:

    class BikNet(torch.nn.Module):
        def __init__(self):
            super(BikNet, self).__init__()

        def forward(self, a, b):
            # q = 3 a^3 - b^2 + 10
            #   with dq/da = 9a^2
            #   with dq/db = -2b
            return 3 * a * a * a - b * b + 10

Which we call as follows

    net = BikNet()
    a = torch.tensor([2., 3.], requires_grad=True)
    b = torch.tensor([6., 4.], requires_grad=True)
    q = net(a, b)    # forward pass
    q.sum().backward()

The current FX importer yields

module {
  func.func @main(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
    %int3 = torch.constant.int 3
    %0 = torch.aten.mul.Scalar %arg0, %int3 : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
    %1 = torch.aten.mul.Tensor %0, %arg0 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32>
    %2 = torch.aten.mul.Tensor %1, %arg0 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32>
    %3 = torch.aten.mul.Tensor %arg1, %arg1 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32>
    %int1 = torch.constant.int 1
    %4 = torch.aten.sub.Tensor %2, %3, %int1 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
    %int10 = torch.constant.int 10
    %int1_0 = torch.constant.int 1
    %5 = torch.aten.add.Scalar %4, %int10, %int1_0 : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32>
    return %5 : !torch.vtensor<[2],f32>
  }

}

which clearly enables the MLIR compiler to implement a JIT compiler for executing the forward pass. But what mechanism will we eventually provide that ensures that, if the compiler yields the results back to the PyTorch system before calling backward(), evaluating a.grad and b.grad gives the desired answers (and, preferably, only for arguments with requires_grad=True set)?
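To be concrete, the desired answers are what eager PyTorch produces for this example, matching the derivative comments in BikNet.forward:

import torch

a = torch.tensor([2., 3.], requires_grad=True)
b = torch.tensor([6., 4.], requires_grad=True)
q = 3 * a * a * a - b * b + 10
q.sum().backward()
print(a.grad)   # dq/da = 9a^2 -> tensor([36., 81.])
print(b.grad)   # dq/db = -2b  -> tensor([-12., -8.])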

And I (think I) understand that Enzyme-MLIR provides the right tool set for computing gradients, but do we have plans to automate the torch-mlir support to make use of such a system in a completely transparent manner?

Idk. There is a missing connection with aot autograd and possibly with some torch.export level metadata that ties it together. I haven’t looked at it in six months… And my conclusion then was to check back in in six months.

1 Like

Here is a very raw example of how this can be done: mnist_train.py · GitHub

As is depressingly common in PyTorch, the pieces have existed for some time but are only now being pulled together into something cohesive. I expect there will be a better actual API for this at some point, but in the meantime, we are using glue to:

  1. Use functorch to trace a combined forward-backward graph of a training step.
  2. Do a source-level manipulation to turn this “training” graph into a normal inference forward-only module that just happens to compute gradients, reading/updating buffers that mirror the original parameters.
  3. torch.export.export this transformed module, compile it and have a hermetic training loop.

If you just want gradients instead of full optimizers, parameter updates, etc, you could use a variant of the same procedure but not including the training step.

It’s still a bit raw but we use variants of this for fine tuning and such. As I expect that it will be a while before Torch gets there with a similarly usable solution, we’ll probably commit these tools to Turbine.

The key is that the functorch tracer knows all about gradients, and modules traced that way can then be exported with all of the other goodies.

(this does rely on “de-functionalization” and mutability; afaik, IREE is currently the only thing that implements programs in this form)
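As a rough illustration of step 1 above, here is a minimal sketch of tracing forward and backward graphs via functorch’s AOT Autograd; the exact entry points have moved between releases, so treat the imports and callbacks as approximate rather than what the gist actually does:

import torch
from functorch.compile import aot_function

def training_step(x, w):
  return (x @ w).sum()

def show_graph(gm, example_inputs):
  print(gm.graph)   # the traced forward (or backward) FX graph
  return gm         # GraphModule is callable, so it serves as the "compiled" artifact

compiled = aot_function(training_step, fw_compiler=show_graph, bw_compiler=show_graph)
x = torch.randn(4, 3)
w = torch.randn(3, 2, requires_grad=True)
compiled(x, w).backward()   # triggers tracing of the forward and backward graphs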

1 Like