Background
During dynamic-shape code generation, we always want to find the relationships among the dynamic shapes. These relationships include, but are not limited to:
- Where does a dimension's dynamism come from? For the example below, we'd like to deduce that the dynamic dimensions of %0, %1, and %3 all come from %arg0.
func.func @several_ops(%arg0: tensor<?x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4xf32>) -> tensor<?x4xf32> {
%0 = "mhlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #mhlo.dot<
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [0]
>,
precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
} : (tensor<?x4xf32>, tensor<4x4xf32>) -> tensor<?x4xf32>
%1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
%2 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
%3 = mhlo.add %0, %2 : tensor<?x4xf32>
return %3 : tensor<?x4xf32>
}
- The constraints on the dynamic shapes. Let's say dimensions `a` and `b` are both dynamic. We'd like to deduce whether `a` equals `b`, whether `a + b` equals a static dimension, etc.
To address these problems, I'm considering implementing a generic symbolic shape analysis.
Related Work
I found an interesting pass in the mlir-hlo repo. It uses affine expressions to describe how a dynamic dimension can be calculated. But I'm afraid we would do a lot of repetitive work if the symbolic shape analysis were based on affine expressions, since:
- We cannot reuse the `reifyReturnTypeShapes` method in `InferShapedTypeOpInterface` (a usage sketch follows this list).
- If we want to simplify or further analyze the symbolic expressions, we cannot reuse existing passes (e.g. cse, canonicalize) or write new passes based on the MLIR pass framework.
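For reference, the `reifyReturnTypeShapes` hook can be driven roughly as below. This is a hedged sketch rather than code from the proposal; the signature is paraphrased from MLIR's `InferShapedTypeOpInterface` and should be checked against the upstream definition.

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

// Ask an op implementing InferShapedTypeOpInterface to materialize IR that
// computes its result shapes; this is the hook the auxiliary functions below
// would like to reuse.
static LogicalResult reifyShapesOf(Operation *op, OpBuilder &builder,
                                   SmallVectorImpl<Value> &reifiedShapes) {
  auto iface = dyn_cast<InferShapedTypeOpInterface>(op);
  if (!iface)
    return failure();
  // On success, `reifiedShapes` holds one shape value per result,
  // e.g. a tensor<2xindex> for a tensor<?x4xf32> result.
  return iface.reifyReturnTypeShapes(builder, op->getOperands(), reifiedShapes);
}
```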
Proposed Solution
Here I drafted a new solution to express symbolic shape calculation. It works in 3 steps:
- Create an auxiliary symbolic shape inference function for each original function.
- Run the shape-reification pass on the created auxiliary functions.
- Run analyses based on the auxiliary functions.
We still use the IR from the Background section as the input. After step 1, an auxiliary function is created as below. Its arguments are the same as the original function's, and its results consist of two parts: the symbolic shapes of the intermediate values, followed by the intermediate values themselves. In the example there are 4 intermediate results, %0, %1, %2, %3, so the return values are [shape_of %0, shape_of %1, shape_of %2, shape_of %3, %0, %1, %2, %3].
func.func private @_shape_infer_several_ops(%arg0: tensor<?x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4xf32>) -> (!shape.shape, !shape.shape, !shape.shape, !shape.shape, tensor<?x4xf32>, tensor<2xindex>, tensor<?x4xf32>, tensor<?x4xf32>) attributes {auxiliary_of = "several_ops"} {
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} : (tensor<?x4xf32>, tensor<4x4xf32>) -> tensor<?x4xf32>
%1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
%2 = shape.value_as_shape %1 : tensor<2xindex> -> !shape.shape
%3 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
%4 = shape.shape_of %3 : tensor<2xindex> -> tensor<1xindex>
%5 = shape.value_as_shape %4 : tensor<1xindex> -> !shape.shape
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
%7 = shape.shape_of %6 : tensor<?x4xf32> -> tensor<2xindex>
%8 = shape.value_as_shape %7 : tensor<2xindex> -> !shape.shape
%9 = mhlo.add %0, %6 : tensor<?x4xf32>
%10 = shape.shape_of %9 : tensor<?x4xf32> -> tensor<2xindex>
%11 = shape.value_as_shape %10 : tensor<2xindex> -> !shape.shape
return %2, %5, %8, %11, %0, %3, %6, %9 : !shape.shape, !shape.shape, !shape.shape, !shape.shape, tensor<?x4xf32>, tensor<2xindex>, tensor<?x4xf32>, tensor<?x4xf32>
}
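For concreteness, the core of step 1 could look roughly like the C++ below. This is a sketch under a few assumptions: the original function body has already been cloned into the auxiliary function, all intermediate values are ranked tensors (as in the example), and the surrounding plumbing (renaming, the auxiliary_of attribute, updating the function type) is handled elsewhere.

```cpp
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Append a shape.shape_of / shape.value_as_shape pair for every intermediate
// result in the cloned body, and collect the operands of the new return:
// [all symbolic shapes..., all intermediate values...].
static void appendSymbolicShapes(Block &auxBody, OpBuilder &b,
                                 SmallVectorImpl<Value> &newReturnOperands) {
  Location loc = auxBody.getParentOp()->getLoc();

  // Snapshot the existing ops so the loop does not visit the shape ops we are
  // about to insert.
  SmallVector<Operation *> originalOps;
  for (Operation &op : auxBody.without_terminator())
    originalOps.push_back(&op);

  b.setInsertionPoint(auxBody.getTerminator());
  SmallVector<Value> shapes, values;
  for (Operation *op : originalOps) {
    for (Value res : op->getResults()) {
      // %e = shape.shape_of %res ; %s = shape.value_as_shape %e
      auto resTy = res.getType().cast<RankedTensorType>();
      auto extentTy =
          RankedTensorType::get({resTy.getRank()}, b.getIndexType());
      Value extent = b.create<shape::ShapeOfOp>(loc, extentTy, res);
      Value sym = b.create<shape::ValueAsShapeOp>(
          loc, shape::ShapeType::get(b.getContext()), extent);
      shapes.push_back(sym);
      values.push_back(res);
    }
  }
  newReturnOperands.append(shapes.begin(), shapes.end());
  newReturnOperands.append(values.begin(), values.end());
}
```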
In step 2, a shape-reification pass and a cse pass are run on the auxiliary function, producing the following IR:
func.func private @_shape_infer_several_ops(%arg0: tensor<?x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4xf32>) -> (!shape.shape, !shape.shape, !shape.shape, !shape.shape, tensor<?x4xf32>, tensor<2xindex>, tensor<?x4xf32>, tensor<?x4xf32>) attributes {auxiliary_of = "several_ops"} {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%0 = shape.const_shape [2] : tensor<1xindex>
%1 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} : (tensor<?x4xf32>, tensor<4x4xf32>) -> tensor<?x4xf32>
%2 = tensor.dim %arg0, %c0 : tensor<?x4xf32>
%3 = tensor.from_elements %2, %c4 : tensor<2xindex>
%4 = shape.value_as_shape %3 : tensor<2xindex> -> !shape.shape
%5 = shape.value_as_shape %0 : tensor<1xindex> -> !shape.shape
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
%7 = mhlo.add %1, %6 : tensor<?x4xf32>
return %4, %5, %4, %4, %1, %3, %6, %7 : !shape.shape, !shape.shape, !shape.shape, !shape.shape, tensor<?x4xf32>, tensor<2xindex>, tensor<?x4xf32>, tensor<?x4xf32>
}
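In code, step 2 is just a small pass pipeline over the auxiliary functions. A hedged sketch, assuming a `createShapeReificationPass()` factory for whichever reification pass is used (the name is illustrative); `createCSEPass()` is the upstream MLIR pass:

```cpp
#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Assumed entry point for the shape-reification pass (e.g. the one in the
// mlir-hlo repo); replace with the actual factory in your build.
std::unique_ptr<Pass> createShapeReificationPass();

// Run reification followed by CSE. Here the pipeline is applied to every
// func.func in the module; in practice it could be restricted to functions
// carrying the auxiliary_of attribute.
static LogicalResult runStep2(ModuleOp module) {
  PassManager pm(module.getContext());
  OpPassManager &funcPm = pm.nest<func::FuncOp>();
  funcPm.addPass(createShapeReificationPass());
  funcPm.addPass(createCSEPass());
  return pm.run(module);
}
```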
Finally, in step 3 we can build analyses on top of the auxiliary functions. Note that the symbolic shape sources reported below should be Values in the original function, not Values in the auxiliary function.
============= symbolic expr sources table for @several_ops =============
original value: %0 = "mhlo.dot_general"...
symbolic shape sources:
<block argument> of type 'tensor<?x4xf32>' at index: 0
original value: %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
symbolic shape sources:
original value: %2 = "mhlo.dynamic_broadcast_in_dim"...
symbolic shape sources:
<block argument> of type 'tensor<?x4xf32>' at index: 0
original value: %3 = mhlo.add %0, %2 : tensor<?x4xf32>
symbolic shape sources:
<block argument> of type 'tensor<?x4xf32>' at index: 0
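The table above can be produced with a plain use-def walk over the reified auxiliary function: for each returned symbolic shape, follow operands backwards until block arguments are reached; since the auxiliary function has the same signature as the original, an argument index maps directly back to an argument of the original function. A minimal sketch under those assumptions (no visited set, so it is only meant for small functions):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// For each of the first `numShapes` return operands of the auxiliary function
// (the symbolic shapes), collect the block arguments it transitively depends on.
static SmallVector<llvm::SetVector<Value>>
collectShapeSources(func::FuncOp auxFn, unsigned numShapes) {
  SmallVector<llvm::SetVector<Value>> sources(numShapes);
  Operation *ret = auxFn.getBody().front().getTerminator();
  for (unsigned i = 0; i < numShapes; ++i) {
    SmallVector<Value> worklist{ret->getOperand(i)};
    while (!worklist.empty()) {
      Value v = worklist.pop_back_val();
      if (auto arg = v.dyn_cast<BlockArgument>()) {
        // The argument index is the same in the original function.
        sources[i].insert(arg);
        continue;
      }
      if (Operation *def = v.getDefiningOp())
        worklist.append(def->operand_begin(), def->operand_end());
    }
  }
  return sources;
}
```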
Questions
- How should we express constraints on the dynamic shapes? The auxiliary function in my current solution only addresses how a dynamic shape is calculated. Constraints would be useful when we meet cases like this:
a : [?, 100]
b : [?, 100]
c = concat(a, b, axis=0)
If we know `a.shape[0] + b.shape[0] = 1024` from an annotation or from deduction on previous operations, then we know that c has a static shape of [1024, 100].
- I found that `reifyReturnTypeShapes` can be implemented in very flexible ways. Instead of `shape.shape_of`, developers can use the `tensor.dim` op (or other ops I'm not aware of) to get the actual dimensions of the operands' shapes. Therefore I had to add a new DimOp pattern to the Shape Reification pass (a sketch follows these questions). Is it possible to limit the flexibility of the `reifyReturnTypeShapes` interface?
- How do we map a Value in the auxiliary function back to the original function? Some passes run on the auxiliary functions, which makes the mapping harder.
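Regarding the DimOp pattern mentioned in the second question, the sketch below shows the kind of pattern I mean. It is illustrative only (accessor names differ across MLIR versions): it resolves a `tensor.dim` of a value produced by an op implementing `InferShapedTypeOpInterface` by reifying that op's result shape and extracting the requested extent.

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

struct ReifyDimOpPattern : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto iface =
        dimOp.getSource().getDefiningOp<InferShapedTypeOpInterface>();
    if (!iface)
      return failure();
    SmallVector<Value> reified;
    if (failed(iface.reifyReturnTypeShapes(rewriter, iface->getOperands(),
                                           reified)))
      return failure();
    // Assume the dim'd value is result 0 of its defining op; its reified shape
    // is a 1-D index tensor, so the requested extent is a plain tensor.extract.
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(dimOp, reified[0],
                                                   dimOp.getIndex());
    return success();
  }
};
```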