TOSA shape inference and dynamic shape support

Recent triaging of end-to-end performance for TOSA has surfaced a couple of common issues surrounding shape inference and dynamic shapes. It is common for model sources to under-specify shapes at import time that could be statically determined. While it is preferable to have shapes already computed by the source, we are unlikely to always have fully constrained types, and would rely on dynamic shape support to handle those cases.

This raises the question of whether TOSA should have a method for propagating shapes / performing shape inference on its IR, allowing a program containing only TOSA / StandardOps to determine all possible static dimensions. The benefit is that during codegen or lower-level compilation, more efficient code can be generated by avoiding runtime shape inference.
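
For example, a TOSA-level inference pass could refine IR like the following (a hypothetical sketch; the ops and shapes are illustrative):

// Before inference: the importer left the result types dynamic even
// though the operand shapes fully determine them.
func @example(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<?x?xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<?x?xf32>
  %1 = "tosa.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// After inference: the static shape propagates through the chain, and
// the result types (and function signature) are refined.
func @example(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
  %1 = "tosa.abs"(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
  return %1 : tensor<4x8xf32>
}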

NPComp has recently been working on the shape propagation problem, headed by @_sean_silva, which may provide a targeted way to approach this.

Hi Rob,

Yeah, the pass is RefineTypes (pass, test), and it would be fairly trivial to adapt it to TOSA (the main work is defining the transfer functions).
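
As a flavor of what such a transfer function computes (a hypothetical sketch), the one for tosa.matmul would map batched operand shapes [N, H, C] x [N, C, W] to an [N, H, W] result:

// Hypothetical sketch: the transfer function for tosa.matmul refines
// [1, 3, 4] x [1, 4, 5] to a [1, 3, 5] result type.
func @refine_matmul(%lhs: tensor<1x3x4xf32>, %rhs: tensor<1x4x5xf32>) -> tensor<?x?x?xf32> {
  // Before refinement: -> tensor<?x?x?xf32>; after: -> tensor<1x3x5xf32>.
  %0 = "tosa.matmul"(%lhs, %rhs) : (tensor<1x3x4xf32>, tensor<1x4x5xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}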

We also have a pretty good recipe in npcomp for lowering error-triggering ops with dynamically shaped (but ranked) operands to linalg-on-tensors + std ops. It would be pretty mechanical to build out a mini version of that upstream for TOSA, giving us a full dynamic-shape-aware flow lowering to linalg-on-tensors. If there’s pull from the TOSA community for dynamicity of that kind, I would be happy to spearhead an initial spike. To give a flavor of what that looks like, we can turn the following Python code (just a single torch.mm) into the resulting linalg-on-tensors + std IR shown below it. (I’m glossing over some of the Python/Torch-specific details that are not relevant to TOSA.)

    @annotate_args([
        None,
        ([-1, -1], torch.float32),
        ([-1, -1], torch.float32),
    ])
    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)
module  {
  func @forward(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %cst = constant 0.000000e+00 : f32
    %c1 = constant 1 : index
    %c0 = constant 0 : index
    %0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
    %1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
    %2 = memref.dim %arg1, %c0 : tensor<?x?xf32>
    %3 = memref.dim %arg1, %c1 : tensor<?x?xf32>
    %4 = cmpi eq, %1, %2 : index
    assert %4, "mismatching contracting dimension for torch.aten.mm"
    %5 = linalg.init_tensor [%0, %3] : tensor<?x?xf32>
    %6 = linalg.fill(%5, %cst) : tensor<?x?xf32>, f32 -> tensor<?x?xf32> 
    %7 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%6 : tensor<?x?xf32>) -> tensor<?x?xf32>
    return %7 : tensor<?x?xf32>
  }
}

@sjarus we have been discussing this offline.

I would love to get something like this landed for TOSA. It would be great to land something simple here to get the ball rolling and see the rest of the design space.

Summarizing a significant amount of conversation with @stellaraccident on this:

a) Bit-accurate verification of the TOSA form (using the TOSA reference model) needs a shape-resolved pure TOSA form to feed the ref model.
b) Code generation can pass through dynamic shapes when working from a frontend (e.g. TF and TFLite currently).

Originally, we primarily focused on (a), but now that we have a stable infrastructure, (b) is increasingly important. So far we’ve depended on model conditioning offering statically shaped networks, or used --tf-enable-shape-inference-on-import together with known input shapes to fully resolve the network prior to legalization to TOSA.

@stellaraccident mentioned that it’s undesirable to fix shapes at the frontend level, and that’s something we agree with. Until recently, the TOSA dialect form was more restrictive than the spec mandates, something I fixed in https://reviews.llvm.org/D102958 .

I’ll start a separate thread on the TOSA spec mlplatform to discuss whether additional ops are worth freeing up, but the current set of changes should be relatively flexible for a lot of cases, e.g. there are no constraints on eltwise ops.

The legalizations may need local patches to fix corner cases where we don’t handle dynamically shaped content cleanly - there was at least one network that failed as a result. This would be on the TensorFlow side, at least for the TF/TFL-to-TOSA legalizations.

Having said this, we might be able to satisfy both the (a) and (b) requirements above from a TOSA form containing dynamic shapes, with a shape inference pass like the TF shape inference pass that takes input and output node information plus the input shapes in order to generate a shape-resolved form. Thus:

Path 1: Resolve shapes at the frontend and carry them through into TOSA (current mode).
Path 2: Ensure legalizations can convey dynamic shapes into general TOSA form, and implement a pass that can take input shapes and generate a shape-resolved TOSA form where needed. This can be used to drive TOSA regression/conformance testing. In a more dynamic e2e stack this resolution might be deferred to a combination of static and dynamic resolution. A rough sketch of Path 2 follows below.
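
Roughly, such a pass could take a generalized, dynamically shaped TOSA function plus user-provided input shapes and emit the resolved form (a hypothetical sketch; the names and shapes are illustrative):

// A generalized legalization produces a fully dynamically shaped
// TOSA function.
func @net(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %0 = "tosa.abs"(%arg0) : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}

// The pass pins the input to a user-provided shape (say 1x224x224x3)
// and propagates it through the body.
func @net(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32> {
  %0 = "tosa.abs"(%arg0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
  return %0 : tensor<1x224x224x3xf32>
}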

Thanks @sjarus.

To be clear, when you say “resolve shapes” / “shape-resolved” you mean “make the program fully statically shaped”?

There is also the --tf-shape-inference pass; it runs later and uses the infer return type trait at the moment (I have been meaning to expand it to the shaped return type one too, but I wanted to update that interface first, as it currently requires everything to be materialized). These two approaches (infer on import vs. as a pass) are implemented differently but share some of the underlying shape functions.

Yes, I think this allows static to fall out of optimizing dynamic. The reify methods are a step in this direction, and it is a bit similar to what is done in KernelGen too.
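
For example (a hypothetical sketch), once the argument types of the torch.mm example above are refined to static shapes, the memref.dim ops fold to constants, the assert folds away, and canonicalization could leave a fully static form:

// Hypothetical result of canonicalizing the earlier dynamic IR once
// the inputs are refined to tensor<4x8xf32> and tensor<8x16xf32>.
func @forward(%arg0: tensor<4x8xf32>, %arg1: tensor<8x16xf32>) -> tensor<4x16xf32> {
  %cst = constant 0.000000e+00 : f32
  // The dim ops folded to constants, the shape check folded to true,
  // and the assert was erased; only the static computation remains.
  %0 = linalg.init_tensor [4, 16] : tensor<4x16xf32>
  %1 = linalg.fill(%0, %cst) : tensor<4x16xf32>, f32 -> tensor<4x16xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<8x16xf32>) outs(%1 : tensor<4x16xf32>) -> tensor<4x16xf32>
  return %2 : tensor<4x16xf32>
}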

Yeah, the reference model doesn’t know what to do with ? dims - it’s not a runtime but just a graph executor that expects to consume a graph with fully statically-resolved shapes.

Right now we take care of it by only consuming content that is statically shape-resolved during model conditioning (mostly TFLite - the TFLite-to-MLIR translator doesn’t do shape resolution) or by using --tf-enable-shape-inference-on-import .

The existing legalizations conveniently assume this but they need not do so - a set of generalized legalizations with a TOSA-level shape inference pass feeding in input shapes and node names would also work.

This specific pass is still using the tf shape functions, right? That is a couple of levels too late (and doesn’t help with specializing tflite derived tosa programs, etc). But are the mechanics of the pass able to be repurposed? (I haven’t looked at that one).

Neither have I, but given our conversation and @jpienaar’s input, it sounds like the general mechanics could indeed be used to feed in input/output node names + shapes and generate a fully statically-shaped form of the network.

This would indeed be a pass as opposed to a translation time task as --tf-enable-shape-inference-on-import does.

@jpienaar @_sean_silva Looking at the shape inference passes of both TF and NPComp, overall they are pretty similar: both compute the expected shape/type of each operation and resolve the final type. The main differences are TF’s multiple iterations to propagate types and its use of InferReturnType to compute an individual op’s shape propagation.

Overall, NPComp’s implementation is a little cleaner / more concise, likely because the TF pass works around some TF-isms. That said, using InferReturnType to compute the result type during shape propagation feels like a much cleaner implementation than branching on operation types within the pass. It would also allow us to support shape propagation for non-TOSA types more easily in the future.

My main concern with either of these passes is multiple uses of a single under-specified function. In some cases we could propagate shapes further by guaranteeing that each function has a unique use, to avoid under-propagating known shapes. Have either of you encountered this in the past?

Yes, see GuaranteeAllFuncsOneUse (pass, test)
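
Concretely (a hypothetical sketch of both the problem and the fix), duplicating the shared callee gives each call site its own copy that shape inference can then specialize:

// Problem: one dynamically shaped callee is shared by two call sites
// with different static shapes, so neither caller's shapes can be
// propagated into it.
func @shared(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
func @caller(%arg0: tensor<4xf32>, %arg1: tensor<8xf32>) {
  %0 = tensor.cast %arg0 : tensor<4xf32> to tensor<?xf32>
  %1 = call @shared(%0) : (tensor<?xf32>) -> tensor<?xf32>
  %2 = tensor.cast %arg1 : tensor<8xf32> to tensor<?xf32>
  %3 = call @shared(%2) : (tensor<?xf32>) -> tensor<?xf32>
  return
}

// After duplication, each call site has a unique callee (the clone
// name is illustrative), so shape inference can specialize each copy
// and the casts in the caller fold away.
func @shared(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
func @shared_0(%arg0: tensor<8xf32>) -> tensor<8xf32> {
  %0 = "tosa.abs"(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>
}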

Definitely - the TF one needs a cleanup, but unfortunately it worked “well enough” that this wasn’t high priority (we really do need to clean it up). It is currently being used in a couple of different workflows, but it needs an update. A data flow framework would make it much nicer, I think (I also want to roll the tensor list part into the main loop with that).

+1. Conceptually we could try doing it lazily, but I found folks didn’t like the shape inference pass doing whole-graph mutations 🙂