Hi,

We’re brainstorming how to represent a reshape operation in TCP. One possibility is to model reshapes using `tensor.collapse_shape` + `tensor.expand_shape`, which would obviate the need for a dedicated `tcp.reshape` op. However, we ran into a problem with the current `tensor.expand_shape` spec and wanted to get your take on a potential solution.

In particular, we believe the current specification for `tensor.expand_shape` is incomplete. For instance, this is the example in the current spec:

```
%b = tensor.expand_shape %a [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x?x?xf32>
```

As written, I don’t see a way to unambiguously determine the output shape given the input tensor’s (`%a`) shape. Moreover, this means we cannot lower fully dynamic reshapes like the following to `collapse_shape` / `expand_shape`:

```
%shape = torch.prim.ListConstruct %dim0, %dim1
: (!torch.int, !torch.int) -> !torch.list<int>
%reshape = torch.aten.reshape %t, %shape
: !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>
-> !torch.vtensor<[?, ?],f32>
```
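To make the ambiguity concrete, here is a small Python sketch (just the arithmetic, not MLIR semantics; the helper name is made up) that enumerates the output shapes consistent with the spec example above for a fixed runtime shape of `%a`:

```python
# For the spec example: reassociation [[0, 1], [2]] expands a 2-D tensor
# into a 3-D one. Given only the input shape, any factorization of the
# first input dimension into two factors yields a valid output shape,
# so the result shape is not uniquely determined.

def candidate_output_shapes(input_shape):
    """Enumerate (d0, d1, d2) with d0 * d1 == input_shape[0] and
    d2 == input_shape[1]."""
    n, m = input_shape
    return [(d0, n // d0, m) for d0 in range(1, n + 1) if n % d0 == 0]

# A hypothetical runtime shape for %a:
print(candidate_output_shapes((12, 6)))
# Several distinct candidates, e.g. (3, 4, 6) and (2, 6, 6), all match.
```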

One path forward is to inherit the `memref.expand_shape` constraint – that at most one dimension in each reassociation group is dynamic. While this would complete the spec, it won’t generalize to fully dynamic graphs like the one above.

Given that, I propose adding N additional `index` scalar operands, prettified as `dims`, to `tensor.expand_shape`, one for each non-trivially expanded dimension in the reassociation list. That is, if an input dimension is expanded into K output dimensions, then there are K `index` operands, one corresponding to each of the expanded output dimensions.

Here’s a full example:

```
%b = tensor.expand_shape %a [[0, 1], [2], [3, 4, 5]]
dims[%d0, %d1, %d2, %d3, %d4]
: tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
```

Here, if `%a`’s shape were `12x6x28` and the `d`… values were `3`, `4`, `2`, `7`, `2`, then `%b`’s shape would be `3x4x6x2x7x2`.
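As a sanity check on that example, the invariant is that within each reassociation group the product of the output extents must reproduce the corresponding input extent. A plain Python sketch mirroring the arithmetic:

```python
from math import prod

# Worked example from the text: %a has shape 12x6x28, the reassociation
# is [[0, 1], [2], [3, 4, 5]], and the dims operands supply the
# non-trivially expanded output dimensions.
input_shape = [12, 6, 28]
reassociation = [[0, 1], [2], [3, 4, 5]]
output_shape = [3, 4, 6, 2, 7, 2]  # %d0..%d4 plus the trivial group

for in_dim, group in zip(input_shape, reassociation):
    group_extents = [output_shape[i] for i in group]
    assert prod(group_extents) == in_dim, (group_extents, in_dim)

print(output_shape)  # consistent: 3*4 == 12, 6 == 6, 2*7*2 == 28
```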

I believe this isn’t novel, as @MaheshRavishankar has suggested this as the path forward in 1:1 conversations.

When lowering from operations like `torch.reshape(vect, [-1, d])`, we’ll have to insert the necessary computations to figure out the actual extent of the `-1` dimension before the `tensor.expand_shape`. This keeps the `tensor.expand_shape` semantics simple at the cost of verbose IR. If the verbosity becomes an issue, we can perhaps create a `tensor.compute_expanded_shape` op that hides the scalar math and computes the unknown dimension in one step.
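The scalar math in question is just a division of element counts. A hedged Python sketch of what such a lowering would materialize (the function name is illustrative, not an actual torch-mlir helper):

```python
from math import prod

def resolve_reshape_dims(input_shape, target_dims):
    """Replace a single -1 entry in target_dims with the extent that
    makes the total element counts match, mirroring the scalar
    computation a lowering would emit before tensor.expand_shape."""
    total = prod(input_shape)
    known = prod(d for d in target_dims if d != -1)
    return [total // known if d == -1 else d for d in target_dims]

# e.g. reshaping a 4-D tensor of shape 2x3x4x5 with target [-1, 10]:
print(resolve_reshape_dims([2, 3, 4, 5], [-1, 10]))  # [12, 10]
```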

One `dims` input per reassociation group is redundant, since the product of the output extents in each group has to match the corresponding input extent. However, requiring all of the dimension bounds has some advantages:

- By making the computation of the last dimension explicit in the graph (in cases where it needs to be computed), we can now separately optimize this computation. This would not be possible if the computation were implicit in the semantics of `expand_shape`.
- It makes each dimension uniform instead of creating a special category of dimension, which would either have to be hardcoded to be the last dimension in each reassociation group or be denoted by an attribute.

Finally, if the product of the expanded dimensions doesn’t match the unexpanded dimension (e.g. if `%d0` were `5` in the example above), then the operation has undefined behavior. This means that, in general, the operation cannot be speculatively executed.