[RFC] Add explicit shape inputs to tensor.expand_shape

Hi,

We’re brainstorming how to represent a reshape operation in TCP. One possibility is to model reshapes using tensor.collapse_shape + tensor.expand_shape, which would obviate the need for a dedicated tcp.reshape op. However, we ran into a problem with the current tensor.expand_shape spec and wanted to get your take on a potential solution.

In particular, we believe the current specification for tensor.expand_shape is incomplete. For instance, this is the example in the current spec:

%b = tensor.expand_shape %a [[0, 1], [2]]
  : tensor<?x?xf32> into tensor<?x?x?xf32>

As written, I don’t see a way to unambiguously determine the output shape given the input tensor’s (%a) shape. Moreover, this means we cannot lower fully dynamic reshapes like the following to collapse_shape / expand_shape:

%shape = torch.prim.ListConstruct %dim0, %dim1
  : (!torch.int, !torch.int) -> !torch.list<int>
%reshape = torch.aten.reshape %t, %shape
  : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>
  -> !torch.vtensor<[?, ?],f32>

One path forward is to inherit the memref.expand_shape constraint – that at most one dimension in each reassociation group is dynamic. While this will complete the spec, it won’t generalize to fully dynamic graphs like the one above.

Given that, I propose adding N additional index scalar operands to tensor.expand_shape, printed as dims[...], one for each output dimension of a non-trivially expanded input dimension in the reassociation list. That is, if an input dimension is expanded into K output dimensions, then there are K index operands, one per expanded output dimension.

Here’s a full example:

%b = tensor.expand_shape %a [[0, 1], [2], [3, 4, 5]]
    dims[%d0, %d1, %d2, %d3, %d4]
    : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>

Here, if %a’s shape is 12x6x28 and the d… values are 3, 4, 2, 7, 2, then %b’s shape will be 3x4x6x2x7x2.
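To make the proposed operand semantics concrete, here is a minimal Python sketch (not MLIR code) that computes the result shape from the input shape, the reassociation list, and the dims values. It assumes the reassociation groups partition the output dimensions and that singleton groups consume no dims operand. It raises on a product mismatch purely to flag the undefined-behavior case. The name expanded_shape is a hypothetical helper, not an existing API.

```python
from typing import List, Sequence


def expanded_shape(
    input_shape: Sequence[int],
    reassociation: Sequence[Sequence[int]],
    dims: Sequence[int],
) -> List[int]:
    """Result shape of the proposed tensor.expand_shape with dims operands."""
    if len(reassociation) != len(input_shape):
        raise ValueError("need one reassociation group per input dimension")
    out = [0] * sum(len(group) for group in reassociation)
    next_dim = 0  # index of the next unused entry in `dims`
    for in_dim, group in enumerate(reassociation):
        if len(group) == 1:
            # Trivial group: the input extent is carried over, no operand.
            out[group[0]] = input_shape[in_dim]
            continue
        product = 1
        for out_dim in group:
            if next_dim >= len(dims):
                raise ValueError("too few dims operands")
            out[out_dim] = dims[next_dim]
            product *= dims[next_dim]
            next_dim += 1
        if product != input_shape[in_dim]:
            # The op itself would have undefined behavior here; the sketch
            # only reports the violated condition.
            raise ValueError(
                f"group {list(group)}: extents multiply to {product}, "
                f"expected {input_shape[in_dim]}"
            )
    if next_dim != len(dims):
        raise ValueError("too many dims operands")
    return out


# The example above: 12x6x28 with dims 3, 4, 2, 7, 2.
print(expanded_shape([12, 6, 28], [[0, 1], [2], [3, 4, 5]], [3, 4, 2, 7, 2]))
# [3, 4, 6, 2, 7, 2]
```

Passing 5 instead of 3 as the first dims value makes the first group multiply to 20 instead of 12, which is the mismatch case described further down.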

I believe this isn’t novel, as @MaheshRavishankar has suggested this as the path forward in 1:1 conversations.

When lowering from operations like torch.reshape(vect, [-1, d]), we’ll have to insert the necessary computations to figure out the actual extent of the -1 dimension before the tensor.expand_shape. This keeps the tensor.expand_shape semantics simple at the cost of more verbose IR. If the verbosity becomes an issue, we can perhaps create a tensor.compute_expanded_shape op that hides the scalar math and computes the unknown dimension in one step.
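The scalar math in question is small. Below is a Python sketch of what a lowering would materialize (as index arithmetic on tensor.dim results) before the tensor.expand_shape, or what a hypothetical tensor.compute_expanded_shape op would hide. The helper name resolve_inferred_dim is hypothetical, and numel is assumed to be the product of the input’s extents (for a 1-D vect, just its single extent).

```python
from math import prod
from typing import List, Sequence


def resolve_inferred_dim(numel: int, target: Sequence[int]) -> List[int]:
    """Replace a single -1 in a reshape target with its concrete extent."""
    if any(d < -1 for d in target):
        raise ValueError("extents must be non-negative or -1")
    inferred = [i for i, d in enumerate(target) if d == -1]
    if len(inferred) > 1:
        raise ValueError("at most one -1 is allowed")
    known = prod(d for d in target if d != -1)  # product of nothing is 1
    if not inferred:
        if known != numel:
            raise ValueError(f"cannot reshape {numel} elements into {list(target)}")
        return list(target)
    if known == 0 or numel % known != 0:
        raise ValueError(f"cannot infer -1 for {numel} elements and {list(target)}")
    resolved = list(target)
    resolved[inferred[0]] = numel // known
    return resolved


# torch.reshape(vect, [-1, d]) with a 24-element vect and d = 6:
print(resolve_inferred_dim(24, [-1, 6]))  # [4, 6]
```

The resolved extents are exactly the values that would feed the dims operands, so expand_shape itself never needs to know about -1.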

One dims operand per reassociation group is redundant, since the product of a group’s output dimensions has to equal the size of the corresponding input dimension. However, requiring all of the dimension extents has some advantages:

  1. By making the computation of the last dimension explicit in the graph (in cases where it needs to be computed), we can optimize this computation separately. This would not be possible if the computation were implicit in the semantics of expand_shape.
  2. This makes each dimension uniform instead of creating a special category of dimension, which would either have to be hardcoded to be the last dimension in each reassociation group or be denoted by an attribute.

Finally, if the product of the expanded dimensions doesn’t match the unexpanded dimension (e.g. if %d0 were 5 in the example above), then the operation has undefined behavior. This means that, in general, the operation cannot be speculatively executed.


Yup, that is a known issue… Today it does support the case where one dynamic extent is expanded into a single dynamic extent plus other static extents. So this is valid (or should be, unless there is a bug):

%b = tensor.expand_shape %a [[0, 1], [2]]
  : tensor<?x?xf32> into tensor<?x4x?xf32>

Overall, I agree with the proposal. Could we make it simpler by just adding a list of mixed static and dynamic values (as is done for tensor.extract_slice, tensor.empty, etc.) that specifies the output shape? Much lower cognitive overhead. For static cases, we can verify that the expansions are well defined. (Sorry if that is what you suggested and I didn’t gather that.)

More concretely, the above example would be

%b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%d0, 4, %d1] : tensor<?x?xf32> into tensor<?x4x?xf32>

%d1 is the same as tensor.dim %a, 1, %d0 is (tensor.dim %a, 0) / 4, and tensor.dim %a, 0 has to be a multiple of 4 (if not, it’s UB).
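As for verifying the static cases, here is a minimal Python sketch of the kind of compile-time check this allows. It models ? as None and reports only violations that are provable from static extents. A runtime mismatch in the dynamic cases remains UB, as described above. The helper name verify_static_expansion is hypothetical, and the check is not MLIR’s actual verifier.

```python
from math import prod
from typing import List, Optional, Sequence

Dim = Optional[int]  # None models a dynamic extent ("?")


def verify_static_expansion(
    input_shape: Sequence[Dim],
    reassociation: Sequence[Sequence[int]],
    output_shape: Sequence[Dim],
) -> List[str]:
    """Return compile-time errors for an expand_shape with a mixed output shape."""
    errors: List[str] = []
    for in_dim, group in enumerate(reassociation):
        in_extent = input_shape[in_dim]
        if in_extent is None:
            continue  # dynamic input extent: only checkable at runtime
        outs = [output_shape[d] for d in group]
        static_product = prod(e for e in outs if e is not None)
        if all(e is not None for e in outs):
            # Fully static group: the extents must multiply out exactly.
            if static_product != in_extent:
                errors.append(f"dim {in_dim}: {outs} does not multiply to {in_extent}")
        elif static_product != 0 and in_extent % static_product != 0:
            # Partially dynamic group: the static part must divide the input.
            errors.append(f"dim {in_dim}: {in_extent} is not a multiple of {static_product}")
    return errors


# output_shape [%d0, 4, %d1], with the dynamic entries modeled as None.
print(verify_static_expansion([8, None], [[0, 1], [2]], [None, 4, None]))  # []
print(verify_static_expansion([6, None], [[0, 1], [2]], [None, 4, None]))
# ['dim 0: 6 is not a multiple of 4']
```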


This sounds like a great idea! I think it will make the IR more readable.

Is there a standard way to express these kinds of “mixed” inputs, i.e. where some inputs are compile-time constants and some are SSA values?

(Changing topics.)

I suspect we’ll need to eventually allow even statically known malformed cases to robustly support control flow.

E.g. let’s say the input program was:

a = something() # Shape is <?xf32>
if a.shape[0] % 4 == 0:
  return torch.reshape(a, (4, -1))
else:
  return torch.reshape(a, (1, -1))

As written, the program is fine. The torch.reshape(a, (4, -1)) is an expand shape going from <?xf32> to <4x?xf32>, which is legit.

Now let’s say we constant fold something() and deduce that a’s shape is <7xf32>. After this constant folding, the first torch.reshape will fail verification since it is now expanding a <7xf32> to <4x1xf32>.

If we want to keep failing the verifier in such cases then we’d have to essentially mandate DCE. That is, as soon as a is constant folded, we also have to eliminate the if a.shape[0] % 4 == 0 branch to keep the program well-formed. However, IME this introduces a weird and complex coupling between local transformations like constant folding (and other such analyses) and global transformations like DCE.

You can look at tensor.extract_slice. It has a static array attribute list of index values and an SSA value list of dynamic values. The static array is the same size as the result, and a ShapedType::kDynamicSize entry indicates that the corresponding size is dynamic. There should be as many entries in the SSA value list as there are ShapedType::kDynamicSize entries in the static array. All of these have assembly format and verifier helpers. tensor.pad does something similar for representing low and high padding values.
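The static-array-plus-SSA-list encoding described above can be sketched in a few lines of Python. This is an illustration of the idea, not MLIR’s helpers. K_DYNAMIC is a hypothetical stand-in for ShapedType::kDynamicSize, SSA values are modeled as strings, and split_mixed, verify_mixed and print_mixed are hypothetical names.

```python
from typing import List, Sequence, Tuple, Union

# Hypothetical stand-in for ShapedType::kDynamicSize: any value that can
# never be a real static size works as the sentinel.
K_DYNAMIC = -1

Mixed = Union[int, str]  # an int is a static size, a str names an SSA value


def split_mixed(mixed: Sequence[Mixed]) -> Tuple[List[int], List[str]]:
    """Split a mixed list into (static attribute array, SSA operand list)."""
    static: List[int] = []
    dynamic: List[str] = []
    for entry in mixed:
        if isinstance(entry, str):
            static.append(K_DYNAMIC)  # placeholder in the static array
            dynamic.append(entry)     # the actual value lives in the operands
        else:
            static.append(entry)
    return static, dynamic


def verify_mixed(static: Sequence[int], dynamic: Sequence[str], rank: int) -> None:
    """Check the two invariants: one static entry per dimension, and one
    SSA operand per sentinel entry."""
    if len(static) != rank:
        raise ValueError(f"expected {rank} static entries, got {len(static)}")
    if static.count(K_DYNAMIC) != len(dynamic):
        raise ValueError("SSA operand count does not match dynamic entries")


def print_mixed(static: Sequence[int], dynamic: Sequence[str]) -> str:
    """Re-interleave both lists, as a custom assembly format would."""
    operands = iter(dynamic)
    entries = [next(operands) if s == K_DYNAMIC else str(s) for s in static]
    return "[" + ", ".join(entries) + "]"


static, dynamic = split_mixed(["%d0", 4, "%d1"])
verify_mixed(static, dynamic, rank=3)
print(static, dynamic)               # [-1, 4, -1] ['%d0', '%d1']
print(print_mixed(static, dynamic))  # [%d0, 4, %d1]
```

Keeping the static entries in an attribute means the verifier and folders can reason about them without inspecting operands, while the printer still shows a single readable list.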

That seems like wrong IR and should fail verification… I’d rather fail verification on those so that we don’t have compilation errors.

I also think that splitting the support for -1 out of the expand_shape operation is the right thing to do. expand_shape lives until fairly late in the pipeline, and in my mind the -1 dimension is merely a front-end convenience that makes this behavior easier to specify.

StableHLO also models this via a separate compute_reshape_shape operation (defined here).

Yes, this is a well-known limitation; it is unfortunate that the doc is misleading there… I added you to llvm/llvm-project issue #51564 (“tensor.expand_shape cannot expand to multiple dynamic dimensions”), where we are tracking the limitation. It has not been critical enough for us yet, so thank you for tackling this!

We have a more general use for this type of behavior across ops. I was thinking of making an interface for such behavior that could come with a nicer printer/parser, inspired by what we can see in IREE tests:

flow.dispatch @executable_2::@dispatch[%arg0_dim0](%arg0, %arg0_dim0, %arg1, %arg1_dim0) : 
  (tensor<?xf32>{%arg0_dim0}, index, tensor<?xf32>{%arg1_dim0}, index) 
  -> tensor<?xf32>{%arg0_dim0}

Implementing this behavior as a standardized reusable interface would be very useful.
Could we rope you onto this path?

Yes, +1. Also, a standardized interface would be very useful.

Thanks for the feedback everyone! I’ll tackle this next after I’m done with the UB/speculation patch series.

I can give it a shot. Can you share a pointer to how IREE does this?

Btw, by “interface” did you mean literally an op interface? My intuition was that a trait would be better suited; I don’t see when we’d want to use this functionality where we don’t also know the op type.

In general, with folding (in the strict MLIR FooOp::fold sense of the word), you do not change the static type of the Value. In Torch-MLIR we have a trait AllowsTypeRefinement that lets us know when it is safe to refine the types like this. But all these “mixed static and dynamic”-type ops do not allow type refinement. So if you folded this, you would actually be inserting a tensor.cast to erase the 7 back into a ? for any users. After DCE of the dead branch, we would have some sort of canonicalizer that “absorbs” the tensor.cast into the mixed-static-and-dynamic op (I believe we have prior art for this in Linalg).
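Here is a small Python sketch of the bookkeeping behind such a fold. Once an SSA size is known to be a constant, it moves from the operand list into the static array, and the op’s result type becomes more static than before. That type change is why existing users would need a tensor.cast back to the original type. The cast insertion and the Linalg canonicalizer are not modeled. K_DYNAMIC, fold_constant_sizes and shape_str are hypothetical names.

```python
from typing import Dict, List, Sequence, Tuple

K_DYNAMIC = -1  # hypothetical stand-in for ShapedType::kDynamicSize


def fold_constant_sizes(
    static: Sequence[int],
    dynamic: Sequence[str],
    constants: Dict[str, int],
) -> Tuple[List[int], List[str]]:
    """Move SSA sizes with known constant values into the static array."""
    new_static: List[int] = []
    new_dynamic: List[str] = []
    operands = iter(dynamic)
    for size in static:
        if size != K_DYNAMIC:
            new_static.append(size)
            continue
        value = next(operands)
        if value in constants:
            new_static.append(constants[value])  # now statically known
        else:
            new_static.append(K_DYNAMIC)
            new_dynamic.append(value)
    return new_static, new_dynamic


def shape_str(static: Sequence[int]) -> str:
    """Render a static array the way it appears in a tensor type."""
    return "x".join("?" if s == K_DYNAMIC else str(s) for s in static)


before = [K_DYNAMIC, 4, K_DYNAMIC]
after, remaining = fold_constant_sizes(before, ["%d0", "%d1"], {"%d0": 7})
print(shape_str(before), "->", shape_str(after), remaining)  # ?x4x? -> 7x4x? ['%d1']
```

Because the refined type (7x4x? here) no longer matches what users expect (?x4x?), the fold cannot simply rewrite the op in place without the cast.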

What’s the reason behind this? Is it just to avoid issues like the one I shared above or are there other reasons besides that?

AFAIK it’s just to avoid such issues. There is no general way to know, for a given Value in MLIR, what might need to be updated to avoid verifier-invalid IR if its type changes – in the extreme case, it could be the whole program (imagine a Value that is a successor arg, or a call arg, requiring updating the types in the successor or callee).

I think what Nicolas suggested is a bit more heavy weight. Its basically on the path towards dependent types. I think that should be part of a larger effort that can be decoupled from the need here. Just re-iterating having a “mixed static dynamic” list is probably a good step for now.


It looks like there are a few places that generically work across tensor.expand_shape and memref.expand_shape, e.g. here. So this means I’ll have to do one of the following:

  1. Also update memref.expand_shape to have explicit shape inputs like tensor.expand_shape.
  2. Bifurcate the common code into a tensor version and a memref version.

Currently I’m leaning towards (1), but would like to get some feedback before spending too much time on this.

(Edit: with (2) we’ll still be able to share some code, I meant bifurcation only in the conceptual sense.)

Definitely 1) please.

The only reason we have 2 versions of these ops is that we do not have a dialect that can depend on both tensor and memref. It is important that the semantics and functionality remain very close (e.g. for bufferization).

I’d love to see such a dialect be available in MLIR so we can avoid the duplication.

+1 to (1).