Hi,
We’re brainstorming how to represent a reshape operation in TCP. One possibility is to model reshapes using `tensor.collapse_shape` + `tensor.expand_shape`, which would obviate the need for a dedicated `tcp.reshape` op. However, we ran into a problem with the current `tensor.expand_shape` spec and wanted to get your take on a potential solution.
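To make the decomposition concrete, here’s a minimal, fully static sketch (the shapes and function name are made up for illustration):

```mlir
// Fully static example of the collapse + expand decomposition of a reshape:
// 2x12 is first flattened to 24 elements, then re-expanded as 4x6.
func.func @reshape_as_collapse_expand(%t: tensor<2x12xf32>) -> tensor<4x6xf32> {
  %flat = tensor.collapse_shape %t [[0, 1]] : tensor<2x12xf32> into tensor<24xf32>
  %r = tensor.expand_shape %flat [[0, 1]] : tensor<24xf32> into tensor<4x6xf32>
  return %r : tensor<4x6xf32>
}
```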
In particular, we believe the current specification for `tensor.expand_shape` is incomplete. For instance, this is the example in the current spec:

```mlir
%b = tensor.expand_shape %a [[0, 1], [2]]
    : tensor<?x?xf32> into tensor<?x?x?xf32>
```
As written, I don’t see a way to unambiguously determine the output shape given the input tensor’s (`%a`’s) shape: if `%a` had shape `12x5`, for example, the first dimension could be expanded to `3x4`, `2x6`, `1x12`, and so on, and nothing in the op determines which. Moreover, this means we cannot lower fully dynamic reshapes like the following to `collapse_shape` / `expand_shape`:

```mlir
%shape = torch.prim.ListConstruct %dim0, %dim1
    : (!torch.int, !torch.int) -> !torch.list<int>
%reshape = torch.aten.reshape %t, %shape
    : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>
    -> !torch.vtensor<[?, ?],f32>
```
One path forward is to inherit the `memref.expand_shape` constraint that at most one dimension in each reassociation group is dynamic. While this would complete the spec, it won’t generalize to fully dynamic graphs like the one above.
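To spell out what that constraint buys (a hedged sketch, not quoted from the memref docs): with at most one `?` per group on the expanded side, the missing extent can always be recovered from the collapsed size.

```mlir
// Hedged illustration of the memref.expand_shape-style constraint: at most
// one dynamic extent per reassociation group on the expanded side.
func.func @constraint_example(%m: memref<?xf32>) -> memref<?x4xf32> {
  // Fine: group [0, 1] has one dynamic extent, recoverable as collapsed_size / 4.
  %ok = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x4xf32>
  // Not expressible: two dynamic extents in one group would make the split of
  // the collapsed size between them ambiguous.
  // %bad = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x?xf32>
  return %ok : memref<?x4xf32>
}
```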
Given that, I propose adding N additional `index` scalar operands, prettified as `dims`, to `tensor.expand_shape`, one for each non-trivially expanded dimension in the reassociation list. That is, if an input dimension is expanded into K output dimensions, then there are K `index` operands, one corresponding to each of the expanded output dimensions.
Here’s a full example:

```mlir
%b = tensor.expand_shape %a [[0, 1], [2], [3, 4, 5]]
    dims[%d0, %d1, %d2, %d3, %d4]
    : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
```
Here, if `%a`’s shape were `12x6x28` and the `%d0` through `%d4` values were `3`, `4`, `2`, `7`, `2`, then `%b`’s shape would be `3x4x6x2x7x2`. (The products match the collapsed dimensions: 3 × 4 = 12 and 2 × 7 × 2 = 28.)
I believe this isn’t novel, as @MaheshRavishankar has suggested this as the path forward in 1:1 conversations.
When lowering from operations like `torch.reshape(vect, [-1, d])`, we’ll have to insert the computations needed to figure out the actual extent of the `-1` dimension before the `tensor.expand_shape`. This keeps the `tensor.expand_shape` semantics simple at the cost of more verbose IR. If the verbosity becomes an issue, we could add a `tensor.compute_expanded_shape` op that hides the scalar math and computes the unknown dimension in one step.
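For a 1-D `vect`, that lowering might look roughly like this (a sketch only: the function name is made up, and it assumes the `dims[...]` form proposed above):

```mlir
// Hypothetical lowering of torch.reshape(vect, [-1, d]) for a 1-D input:
// the extent of the -1 dimension is computed explicitly as
// total_elements / d before the expand_shape.
func.func @reshape_neg1(%vect: tensor<?xf32>, %d: index) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %total = tensor.dim %vect, %c0 : tensor<?xf32>
  %neg1 = arith.divui %total, %d : index
  %r = tensor.expand_shape %vect [[0, 1]] dims[%neg1, %d]
      : tensor<?xf32> into tensor<?x?xf32>
  return %r : tensor<?x?xf32>
}
```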
One `dims` input per reassociation group is redundant, since the product of the expanded dimensions in each group has to match the size of the corresponding input dimension. However, requiring all of the dimension sizes has some advantages:

- By making the computation of the last dimension explicit in the graph (in cases where it needs to be computed), we can optimize this computation separately; the sketch after this list illustrates the idea. This would not be possible if the computation were implicit in the semantics of `expand_shape`.
- It makes every dimension uniform instead of creating a special category of dimension, which would either have to be hardcoded to be the last dimension in each reassociation group or be denoted by an attribute.
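As a hypothetical illustration of the first point, two reshapes that need the same derived extent end up sharing one explicit scalar computation, which ordinary passes like CSE can exploit (this assumes `%a` and `%b` have the same element count, and again uses the proposed `dims[...]` form):

```mlir
// Hypothetical sketch: the derived extent %d0 is ordinary scalar IR, so it is
// computed once and reused by both expand_shapes rather than being re-derived
// implicitly inside each op. Assumes %a and %b have the same element count.
func.func @shared_extent(%a: tensor<?xf32>, %b: tensor<?xf32>, %d1: index)
    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %n = tensor.dim %a, %c0 : tensor<?xf32>
  %d0 = arith.divui %n, %d1 : index
  %ea = tensor.expand_shape %a [[0, 1]] dims[%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  %eb = tensor.expand_shape %b [[0, 1]] dims[%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  return %ea, %eb : tensor<?x?xf32>, tensor<?x?xf32>
}
```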
Finally, if the product of the expanded dimensions doesn’t match the unexpanded input dimension (e.g. if `%d0` were `5` in the example above), then the operation has undefined behavior. This means that, in general, the operation cannot be speculatively executed.