Now we just need to add more patterns in TorchToTosa.cpp
Two points of discussion
I was first going to implement torch.aten.mm in terms of MATMUL, but I found that MATMUL takes a leading batch dimension. I couldn’t find a way in TOSA to do a dynamically shaped “insert a leading 1 dimension” operation, so I ended up just going with TANH as the first pattern instead.
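For the static-shape case, the mm lowering could look roughly like the following. This is only a sketch: the exact op syntax and the reshape attribute spelling vary across TOSA versions, and the concrete sizes here are illustrative.

```mlir
// Hypothetical static-shape lowering of torch.aten.mm via tosa.matmul.
// tosa.matmul requires rank-3 operands with a leading batch dimension,
// so we reshape [M, K] x [K, N] into [1, M, K] x [1, K, N] and back.
%lhs3 = "tosa.reshape"(%lhs) {new_shape = [1, 4, 8]}
    : (tensor<4x8xf32>) -> tensor<1x4x8xf32>
%rhs3 = "tosa.reshape"(%rhs) {new_shape = [1, 8, 16]}
    : (tensor<8x16xf32>) -> tensor<1x8x16xf32>
%mm = "tosa.matmul"(%lhs3, %rhs3)
    : (tensor<1x4x8xf32>, tensor<1x8x16xf32>) -> tensor<1x4x16xf32>
%out = "tosa.reshape"(%mm) {new_shape = [4, 16]}
    : (tensor<1x4x16xf32>) -> tensor<4x16xf32>
```

With dynamic shapes this breaks down because `new_shape` must be a static attribute, which is exactly the “insert a leading 1 dimension” problem above.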
It’s fine if we want to limit this to static shapes, though since we seem to aspire for TosaToLinalg to support dynamic shapes (Rob said that many ops already do), we probably want this layer to support them as well.
Error handling / undefined behavior
This is largely related to dynamic shapes, actually. Consider a matmul with dynamic shapes where the K dimension mismatches at runtime. What is the TOSA program specced to do? I see two options:
A. The program must safely report an error and return control to the caller
B. Undefined behavior (it might scribble over random memory, format your hard drive, etc.)
Linalg-on-tensors has behavior B. So if we want TOSA to support A, then TosaToLinalg needs to insert error guards (which it doesn’t currently do).
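To make the two options concrete, here is a sketch of option A’s semantics in plain Python (the function name is made up for illustration; the point is just that the guard runs before the computation and reports the error to the caller):

```python
import numpy as np

def checked_matmul(a, b):
    # Option A semantics (sketch): verify the contraction (K) dimension
    # at runtime and return an error to the caller, instead of the
    # undefined behavior of option B.
    if a.shape[1] != b.shape[0]:
        raise ValueError(
            f"matmul K-dim mismatch: {a.shape[1]} vs {b.shape[0]}")
    return a @ b
```

Under option B there is no such check, and the lowered code is free to assume the shapes are legal.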
We have to broach a similar topic with respect to broadcasting behavior for elementwise ops (when a ? turns out to be 1 at runtime and broadcasts against the opposite dimension). In TorchToLinalg and MHLO-to-Linalg, we use a strategy where any dynamic size-1 broadcast is a runtime error, but we handle cases where the dimension is statically 1 (the code in TorchToLinalg is here). This is not super principled (type inference can convert a program from a runtime error to a success), but it has worked well in practice and seems reasonable to adopt in TOSA.
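A Python sketch of that rule for a single dimension (the helper name is hypothetical; `None` stands for a dynamic `?` size):

```python
def broadcast_dim(static_lhs, static_rhs, runtime_lhs, runtime_rhs):
    # Strategy used in TorchToLinalg / MHLO-to-Linalg (sketch):
    # a dimension broadcasts only if it is *statically* 1; a dynamic
    # dimension that happens to be 1 at runtime does not broadcast and
    # must match the opposite dimension exactly.
    if static_lhs == 1:
        return runtime_rhs
    if static_rhs == 1:
        return runtime_lhs
    if runtime_lhs != runtime_rhs:
        raise ValueError("dynamic size-1 broadcast is a runtime error")
    return runtime_lhs
```

This makes the unprincipled part visible: refining a `None` to a literal `1` via type inference changes a runtime error into a success.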
If we want to make TOSA closed for these cases, we may want to add some additional tosa.util.assert_equal(index, index) ops or equivalent. Emitting those as part of lowering into TOSA would allow backends the flexibility to handle them how they want (ignore, analyze/hoist/eliminate/etc). If these ops existed, we could just have a basic pass upstream that removed them, letting folks easily get a program that is “just TOSA”, albeit one where it is up to them to only pass legal shapes.
We could instead define a tosa.util.assert which has a single-block region that yields an i1. Then we could mark that recursively legal or verify it in some way without polluting the main constraints on what is legal. We could then provide standard passes to either eliminate asserts or inline them to loose ops and a std.assert. Seems like such a construct would let anyone do whatever they need.
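A hypothetical spelling of that region-based construct (every op name and the syntax here are made up for illustration, and the values captured from the enclosing scope assume the op is not IsolatedFromAbove):

```mlir
// Hypothetical tosa.util.assert: a single-block region yields an i1.
// A standard pass could erase the whole op, or inline the region to
// loose ops plus a std.assert on the yielded value.
"tosa.util.assert"() ({
  %k0 = "tosa.util.get_dim"(%lhs) {axis = 1 : i64} : (tensor<?x?xf32>) -> index
  %k1 = "tosa.util.get_dim"(%rhs) {axis = 0 : i64} : (tensor<?x?xf32>) -> index
  %eq = cmpi eq, %k0, %k1 : index
  "tosa.util.yield"(%eq) : (i1) -> ()
}) {msg = "matmul contraction dims must match"} : () -> ()
```

Because the predicate is walled off in its own region, it can be marked recursively legal (or verified separately) without polluting the constraints on what counts as “just TOSA”.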
Sorry for the late response. It took me a while to pick my jaw up from the floor after seeing @_sean_silva’s code review
A couple of intersecting plans that matter here:
We plan to move the TOSA legalize_common stuff in TensorFlow to the MLIR codebase. Ideally this would interface, within MLIR code, with TosaInferShapes and any dynamic shape expression dialect such as the one Stella refers to:
If this sounds ok, then the first part is just moving things over to the MLIR side and then making the legalizations flexible enough to carry dynamic shapes cleanly. We believe this setup would be better suited to long-term support for TOSA as a dialect fed from multiple frameworks (TF and Torch for the moment).
There’s a planned RFC for a set of TOSA utility ops for statefulness support. We hope this will align well enough with existing semantics to be a simple fit. I’ll be putting out the RFC hopefully over the weekend.
Regarding specifically implementing the symbolic shapes proposal as a tosa.util form, that seems ok, but I had raised a particular question in the original proposal discussion: these APIs interface between compiled content and runtime interaction. That implies some kind of runtime support that bakes in tosa.util as an interfacing point, e.g. to transmit runtime getDim()s. It would be helpful to understand how the proposal addresses fully compile-time-resolvable shapes versus runtime-resolved ones.
@_sean_silva I’ll give your PR a careful look on Friday, thank you so much for this enormous starting contribution. It saves me the effort of bootstrapping this.
It works internally with the Torch-MLIR repo (and has with the TensorFlow repo for quite some time) as a submodule, together with a unit test generator that emits legalization input-permutation tests to confirm that the TOSA form produces bit-accurate results vs. the original reference Torch output.
The unit test harness is Python based and is similar to the existing e2e infrastructure, except that it emits a set of shape permutations (the reference model doesn’t do shape inference) and exercises each permutation per op for bit accuracy.
In addition to ensuring legalization fidelity, we’ve found it very useful for shaking out corner-case issues, especially around conv, pooling, and other shape-sensitive ops. We’d be happy to help integrate it into Torch-MLIR CI if you’re interested.
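A minimal sketch of such a permutation harness, with all names hypothetical (`compiled_runner` stands in for legalize-to-TOSA plus reference-model execution):

```python
import itertools
import numpy as np

def run_shape_permutations(reference_op, compiled_runner, shape_choices):
    # Sketch of the described harness: since the TOSA reference model
    # does no shape inference, emit one test case per concrete shape
    # permutation and require bit-accurate agreement with the reference
    # output (rtol=0, atol=0).
    rng = np.random.default_rng(0)
    for shape in itertools.product(*shape_choices):
        x = rng.standard_normal(shape).astype(np.float32)
        expected = reference_op(x)
        actual = compiled_runner(x)
        np.testing.assert_allclose(actual, expected, rtol=0.0, atol=0.0)
    return True
```

For example, `shape_choices = [[1, 2, 16], [3, 8]]` would exercise a unary op over the six shapes in the cross product, comparing each result bit-for-bit.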