[npcomp] Torch Dialect round-tripping

@stellaraccident: Continuing our discussion from Discord with respect to next steps for round-tripping from Torch -> ATen Dialect after the ODM today…

Here’s the original message:
Thanks for the great presentation! Based on the presentation slides, it seems like we may be able to help out on the torch round-trip project. One approach may be to just hook up enough aten runtime ops to round-trip something like Resnet18 (which is mostly conv2d, batch norm, relu, maxpool, add, avgpool, and a final linear layer). I think there’s already a starting point for round-tripping through aten checked in as well:


Resnet18 might serve as a good “dummy” model for us to flesh out what quantization support might look like as well (since a pre-trained quantized model can be downloaded using torchvision), as an extension of basic f32 round-trip support when we get to that point.

Where do we think would be a good place for us to begin PoC’ing a round-trip network? Another question is whether to approach this from the “device capture” side or through TorchScript for the initial flesh-out…


Can you please edit the message and copy in the entire original post instead of linking? Not everyone has access to Discord, and for archival purposes it seems preferable to me to have all the needed context here.

The discord link has been replaced with the contents of the original post. Thanks!


Hey @brycearden - this sounds great. Sorry, with the preso prep yesterday I got a little behind on other things and have been offline today. I can join the discussion next week. Also, @_sean_silva wrote that slide and may have more ideas.

For reference, here is the slide.

Have to admit, I don’t have a super concrete implementation idea, as I’ve only recently been seeing what the layering looks like on the PT side (mostly by reviewing Stella’s patches).

I guess the big idea is: is there a way that we can “by default” run all PyTorch programs by falling back to ATen, while maintaining the ability to lower through the rest of the compiler stack as it incrementally becomes able to cover more of the program?

I think Stella is already doing something philosophically similar by importing into raw torch.kernel_call ops, which are then lowered to aten ops for the kernels that we do recognize. The whole thing doesn’t break because of an unknown op! Great!

I think that this round-tripping can then be very succinctly described as the specific question “what do we do about the torch.kernel_call ops that we didn’t lower to aten?”. If we carve out an e2e path that supports having those fall back to ATen, then that is effectively a round trip. For example, we could extend the reference backend to support passing torch.kernel_call ops through back to ATen.
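To make the fallback idea a bit more concrete, here is a minimal Python sketch of a toy interpreter standing in for the reference backend. Everything in it is hypothetical: the `Op` record, the `NATIVE` table of ops the backend has lowered, and the `aten_fallback` callable, which stands in for whatever mechanism would actually invoke an ATen kernel by name. It only illustrates the control flow: recognized aten ops run natively, and any leftover torch.kernel_call is handed back to ATen instead of failing.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

@dataclass(frozen=True)
class Op:
    """Toy stand-in for an IR operation (hypothetical, not the npcomp API)."""
    name: str                   # e.g. "aten.add" or "torch.kernel_call"
    operands: Tuple[str, ...]   # names of input values
    result: str                 # name of the produced value
    kernel_name: str = ""       # only used by torch.kernel_call, e.g. "aten::sub"

# Ops the (hypothetical) backend has already lowered and can run natively.
NATIVE: Dict[str, Callable[..., float]] = {
    "aten.add": lambda a, b: a + b,
    "aten.mul": lambda a, b: a * b,
}

def run(program: List[Op], inputs: Dict[str, float],
        aten_fallback: Callable[..., float]) -> Dict[str, float]:
    """Execute ops in order; un-lowered kernel calls are sent back to ATen."""
    env = dict(inputs)
    for op in program:
        args = [env[v] for v in op.operands]
        if op.name in NATIVE:
            env[op.result] = NATIVE[op.name](*args)
        elif op.name == "torch.kernel_call":
            # Not lowered by the compiler: round-trip it to ATen by kernel name.
            env[op.result] = aten_fallback(op.kernel_name, *args)
        else:
            raise NotImplementedError(f"no lowering or fallback for {op.name}")
    return env

# Usage: FAKE_ATEN is a hypothetical stand-in for "call the real ATen kernel".
FAKE_ATEN = {"aten::sub": lambda a, b: a - b}
prog = [Op("aten.add", ("x", "y"), "t0"),
        Op("torch.kernel_call", ("t0", "y"), "t1", kernel_name="aten::sub")]
env = run(prog, {"x": 3.0, "y": 2.0}, lambda name, *a: FAKE_ATEN[name](*a))
```

In a real pipeline the fallback would also have to marshal tensors across the compiler/ATen boundary, which is where the type questions raised further down in this thread come in.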

(And if desired, we can have a tiny pass that looks at aten operations and “erases” them back to torch.kernel_call ops by just looking at the op name).
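A sketch of that “erase” step, again on a toy Python IR rather than real MLIR. The `Op` record and `erase_to_kernel_calls` are hypothetical names, and the assumption that an op spelled `aten.conv2d` maps to a kernel name spelled `aten::conv2d` is just an illustrative convention; a real pass would presumably be an MLIR rewrite pattern. The point is that the rewrite needs nothing but the op name.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple

@dataclass(frozen=True)
class Op:
    """Toy IR operation (hypothetical; not the npcomp API)."""
    name: str                   # e.g. "aten.conv2d"
    operands: Tuple[str, ...]
    result: str
    kernel_name: str = ""       # set only on torch.kernel_call

def erase_to_kernel_calls(ops: List[Op]) -> List[Op]:
    """Rewrite every aten.* op into a generic torch.kernel_call, using only its name."""
    out = []
    for op in ops:
        if op.name.startswith("aten."):
            # Assumed naming convention: "aten.conv2d" -> "aten::conv2d".
            kernel = "aten::" + op.name[len("aten."):]
            out.append(replace(op, name="torch.kernel_call", kernel_name=kernel))
        else:
            out.append(op)  # non-aten ops are left untouched
    return out

ops = [Op("aten.conv2d", ("x", "w"), "y"), Op("std.constant", (), "c")]
erased = erase_to_kernel_calls(ops)
```

Operands and results are carried over unchanged, so the erased form stays a drop-in replacement for the original op.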

Does that make sense? I have no idea how quantized kernels / conv params like we discussed fit into that picture. Sorry!

As an idea on the ATen side, we may be able to piggyback off the ATen Dispatcher similar to this example from the PyTorch docs:

I think we might be better off creating a TorchTensor type, though (like @stellaraccident was hinting at in the previous PR), so that we don’t lose any important Torch tensor context while lowering.

FWIW, I’ve attempted a “round-tripping” on my end by writing some ConvertTensorToMemRef and ConvertMemRefToTensor wrappers, and that didn’t work out as well as I had hoped. Some of the problems I ran into were NCHW / NHWC conversions at the boundaries (not bad, but annoying to be messing with strides / views at this stage), and quantization constants living on the Tensor object itself: when returning a memref<?xi8> from MLIR to PyTorch, the quantization information is lost and must be “added” back to the PyTorch Tensor object somehow. I’m sure similar problems arise with gradients, but I haven’t messed with those yet.
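To illustrate why the quantization problem bites, here is a minimal Python sketch assuming per-tensor affine quantization, where a stored integer q represents the real value (q - zero_point) * scale. A bare int8 buffer (what a memref<?xi8> carries) holds only q; scale and zero_point have to travel on a side channel and be reattached. `QuantParams`, `QuantizedResult`, `reattach`, and `dequantize` are hypothetical names, not PyTorch or npcomp APIs.

```python
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class QuantParams:
    """Per-tensor affine quantization parameters (hypothetical record)."""
    scale: float
    zero_point: int

@dataclass(frozen=True)
class QuantizedResult:
    data: List[int]         # raw int8 payload: all a memref<?xi8> can hold
    qparams: QuantParams    # metadata the bare buffer cannot carry

def reattach(raw: List[int], qparams: QuantParams) -> QuantizedResult:
    """Re-pair a raw int8 buffer returned from compiled code with its qparams."""
    if any(q < -128 or q > 127 for q in raw):
        raise ValueError("value out of int8 range")
    return QuantizedResult(list(raw), qparams)

def dequantize(res: QuantizedResult) -> List[float]:
    """Affine dequantization: real = (q - zero_point) * scale."""
    s, z = res.qparams.scale, res.qparams.zero_point
    return [(q - z) * s for q in res.data]

r = reattach([-128, 0, 10], QuantParams(scale=0.5, zero_point=-128))
values = dequantize(r)
```

Without the `QuantParams` side channel, the same three integers could mean entirely different real values, which is why dropping this information at the first lowering step is hard to recover from.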

Yeah, losing that information at the first step is fairly deadly. You can always lower to a more constrained form or special case (say, identifying an island that just needs to “embrace the weird” and call out to Torch), but you have none of those options if coming in at the wrong level.

Of the types that are currently in-repo, the numpy.ndarray type is the closest in that it can represent most forward-mode dispatches, and this is what I am currently using on import. I think there should be a TorchTensor as well (and corresponding conversions, accessors, etc), and if doing that, it makes sense to design it properly (which is why I didn’t undertake it ad-hoc in the first step). In my opinion, any torch round-tripping from the torch dialect is going to need to work out the type story on this in order to be real.

From a torch dialect perspective, I am fairly certain that the built-in tensor and memref types are things that we lower to as possible/needed.
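One way to picture the “lower to as possible/needed” direction is a type descriptor that keeps Torch-specific metadata and only spells out a builtin tensor type on request. The Python sketch below is purely hypothetical: `TorchTensorType`, its fields, and the dtype table are illustrative assumptions about what such a type might carry, not a proposed design. It shows that the lowering is one-way: quantization parameters and the autograd flag have no place in the builtin type.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical element-type mapping; quantized int8 storage lowers to plain i8.
_ELEMENT_TYPES = {"f32": "f32", "f64": "f64", "i64": "i64", "qint8": "i8"}

@dataclass(frozen=True)
class TorchTensorType:
    """Hypothetical descriptor of what a TorchTensor type might need to carry."""
    shape: Tuple[Optional[int], ...]              # None marks a dynamic dimension
    dtype: str                                    # key into _ELEMENT_TYPES
    qparams: Optional[Tuple[float, int]] = None   # (scale, zero_point) if quantized
    requires_grad: bool = False

    def to_builtin_tensor(self) -> str:
        """Spell the MLIR builtin tensor type this would lower to."""
        elem = _ELEMENT_TYPES[self.dtype]
        if not self.shape:
            return f"tensor<{elem}>"  # rank-0 tensor
        dims = "x".join("?" if d is None else str(d) for d in self.shape)
        return f"tensor<{dims}x{elem}>"

    def lowering_drops_info(self) -> bool:
        """True if lowering to a builtin tensor loses Torch-specific metadata."""
        return self.qparams is not None or self.requires_grad

act = TorchTensorType((1, None, 224, 224), "f32")
qweight = TorchTensorType((4,), "qint8", qparams=(0.5, 0))
```

Here `act` lowers cleanly to `tensor<1x?x224x224xf32>`, while `qweight` lowers to `tensor<4xi8>` and reports that information was dropped, so the compiler would know it must keep the Torch-level type around (or fall back) at that point.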
