We’re brainstorming how to represent a reshape operation in TCP. One possibility is to model reshapes using
tensor.expand_shape, which would obviate the need for a dedicated
tcp.reshape op. However, we ran into a problem with the current
tensor.expand_shape spec and wanted to get your take on a potential solution.
In particular, we believe the current specification for
tensor.expand_shape is incomplete. For instance, this is the example in the current spec:
%b = tensor.expand_shape %a [[0, 1], [2]] : tensor<?x?xf32> into tensor<?x?x?xf32>
As written, I don’t see a way to unambiguously determine the output shape given the input tensor’s (%a) shape.
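To make the ambiguity concrete, suppose %a has runtime shape 6x4 (a hypothetical concretization, not from the spec):

%b = tensor.expand_shape %a [[0, 1], [2]] : tensor<?x?xf32> into tensor<?x?x?xf32>
// %a is 6x4 at runtime; %b could be 2x3x4, 3x2x4, or 6x1x4, and the op
// carries no information that picks one of these factorizations.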
Moreover, this means we cannot lower fully dynamic reshapes like the following to collapse_shape / expand_shape:
%shape = torch.prim.ListConstruct %dim0, %dim1 : (!torch.int, !torch.int) -> !torch.list<int>
%reshape = torch.aten.reshape %t, %shape : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
One path forward is to inherit the memref.expand_shape constraint that at most one dimension in each reassociation group may be dynamic. While this would complete the spec, it won’t generalize to fully dynamic graphs like the one above.
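For illustration, under that constraint an example like the following (hypothetical, not from the spec) stays well-defined, because the single dynamic dimension in each group can be recovered by division:

%b = tensor.expand_shape %a [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// dim 0 of %b is necessarily dim(%a, 0) / 4, so the output shape is
// uniquely determined by the input shape.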
Given that, I propose to add N additional index scalar operands to tensor.expand_shape, prettified as dims[...], one for each non-trivially expanded dimension in the reassociation list. I.e. if an input dimension is expanded into K output dimensions then there are K index operands, one corresponding to each of the expanded output dimensions.
Here’s a full example:
%b = tensor.expand_shape %a [[0, 1], [2], [3, 4, 5]] dims[%d0, %d1, %d2, %d3, %d4] : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
If %a’s shape was 12x6x28 and the %d… values were, say, 3, 4, 7, 2, 2, then %b’s shape will be 3x4x6x7x2x2 (note that 3 * 4 = 12 and 7 * 2 * 2 = 28).
I believe this isn’t novel, as @MaheshRavishankar has suggested this as the path forward in 1:1 conversations.
When lowering from operations like
torch.reshape(vect, [-1, d]), we’ll have to insert the necessary computations to figure out the actual extent of the
-1 dimension before the
tensor.expand_shape. This keeps the
tensor.expand_shape semantics simple at the cost of verbose IR. If the verbosity becomes an issue, we can perhaps create a
tensor.compute_expanded_shape op that hides the scalar math and computes the unknown dimension in one step.
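For concreteness, here is a sketch of what that lowering might produce for a 2-D input %t reshaped to [-1, %d] (hypothetical IR: the dims[...] clause is the syntax proposed above, everything else is existing arith / tensor ops):

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// Total element count of the input.
%n0 = tensor.dim %t, %c0 : tensor<?x?xf32>
%n1 = tensor.dim %t, %c1 : tensor<?x?xf32>
%numel = arith.muli %n0, %n1 : index
// Extent of the -1 dimension, made explicit as scalar math.
%lead = arith.divui %numel, %d : index
// Flatten, then expand with fully specified extents.
%flat = tensor.collapse_shape %t [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
%r = tensor.expand_shape %flat [[0, 1]] dims[%lead, %d] : tensor<?xf32> into tensor<?x?xf32>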
One dims input is redundant for each reassociation group, since the product of the expanded dimensions in the output shape has to match the extent of the corresponding input dimension. However, requiring all of the dimension bounds has some advantages:
- By making the computation of the last dimension explicit in the graph (in cases where it needs to be computed), we can now separately optimize this computation. This would not be possible if this computation were implicit in the semantics of tensor.expand_shape.
- This makes each dimension uniform instead of creating a special category of dimension, which would either have to be hardcoded to be the last dimension in each reassociation group or be denoted by an attribute.
Finally, if the product of the expanded dimensions doesn’t match the corresponding unexpanded input dimension (e.g. if %d0 were 5 in the example above, since 5 * 4 ≠ 12) then the operation has undefined behavior. This means that, in general, the operation cannot be speculatively executed.
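As an illustration (not part of the proposal), a consumer that cannot tolerate the UB could make the precondition explicit with a guard before the op, e.g. for the first reassociation group of the example above:

%c0 = arith.constant 0 : index
%in0 = tensor.dim %a, %c0 : tensor<?x?x?xf32>
%prod = arith.muli %d0, %d1 : index
// Trap if the expanded extents do not multiply back to the input extent.
%ok = arith.cmpi eq, %prod, %in0 : index
cf.assert %ok, "expand_shape: dims of group [0, 1] do not multiply to input dim 0"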