I think that when we actually code this up, the clean_for_* ops
will turn into open-coded shape calculations (either by being lowered to them, or by literally starting out that way). So concretely, IR like:
%result0 = "frontend.add"(%unclean_lhs, %unclean_rhs) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%result = "frontend.add"(%result0, %unclean_rhs2) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
will first expand to something like the following. Note that it is just the same fully general pattern inserted twice, once for each frontend.add op:
```mlir
%ul_s = shape.shape_of %unclean_lhs
%ur_s = shape.shape_of %unclean_rhs
%broadcasted0_s = "shape.broadcast_shape"(%ul_s, %ur_s)
"shape.assert_no_error"(%broadcasted0_s)
%clean_lhs = "tcp.broadcast_to"(%unclean_lhs, %broadcasted0_s)
%clean_rhs = "tcp.broadcast_to"(%unclean_rhs, %broadcasted0_s)
%result0 = "tcp.add"(%clean_lhs, %clean_rhs) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%r0_s = shape.shape_of %result0
%ur2_s = shape.shape_of %unclean_rhs2
%broadcasted_s = "shape.broadcast_shape"(%r0_s, %ur2_s)
"shape.assert_no_error"(%broadcasted_s)
%clean_result0 = "tcp.broadcast_to"(%result0, %broadcasted_s)
%clean_rhs2 = "tcp.broadcast_to"(%unclean_rhs2, %broadcasted_s)
%result = "tcp.add"(%clean_result0, %clean_rhs2) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
```
We will then run a pass that uses shape transfer functions on tcp.add and tcp.broadcast_to to RAUW %r0_s with %broadcasted0_s, breaking the data dependency between the two tcp.broadcast_to/tcp.add groups.
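As a rough illustration of what such a pass could look like (a minimal sketch only; the tcp ops and the exact pass structure are assumptions here, so the pattern matches ops by name rather than using real tcp op classes), a rewrite pattern could resolve a shape.shape_of of a tcp.add result to the broadcasted shape that the expansion above already computed:

```cpp
// Hypothetical sketch: resolve `shape.shape_of` of a `tcp.add` result to the
// shape operand of the `tcp.broadcast_to` that produced the add's lhs. Since
// both add operands were broadcast to %broadcasted0_s, that value *is* the
// result shape of the add. The tcp dialect is matched by op name here because
// this is illustrative, not actual project code.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct ResolveShapeOfTcpAdd : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    // Shape transfer function for tcp.add: result shape == operand shape
    // (both operands have already been broadcast to a common shape).
    Operation *addOp = op->getOperand(0).getDefiningOp();
    if (!addOp || addOp->getName().getStringRef() != "tcp.add")
      return failure();
    // Shape transfer function for tcp.broadcast_to: result shape == its
    // shape operand (operand #1), i.e. %broadcasted0_s in the IR above.
    Operation *bcast = addOp->getOperand(0).getDefiningOp();
    if (!bcast || bcast->getName().getStringRef() != "tcp.broadcast_to")
      return failure();
    rewriter.replaceOp(op, bcast->getOperand(1));
    return success();
  }
};
```

After that RAUW, the IR becomes: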
```mlir
%ul_s = shape.shape_of %unclean_lhs
%ur_s = shape.shape_of %unclean_rhs
%broadcasted0_s = "shape.broadcast_shape"(%ul_s, %ur_s)
"shape.assert_no_error"(%broadcasted0_s)
%clean_lhs = "tcp.broadcast_to"(%unclean_lhs, %broadcasted0_s)
%clean_rhs = "tcp.broadcast_to"(%unclean_rhs, %broadcasted0_s)
%result0 = "tcp.add"(%clean_lhs, %clean_rhs) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%ur2_s = shape.shape_of %unclean_rhs2
%broadcasted_s = "shape.broadcast_shape"(%broadcasted0_s, %ur2_s)
"shape.assert_no_error"(%broadcasted_s)
%clean_result0 = "tcp.broadcast_to"(%result0, %broadcasted_s)
%clean_rhs2 = "tcp.broadcast_to"(%unclean_rhs2, %broadcasted_s)
%result = "tcp.add"(%clean_result0, %clean_rhs2) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
```
Then we use:
- the fact that shape.shape_of and shape.broadcast_shape are NoSideEffects to hoist them (see the sketch after this list)
- the fact that tcp.broadcast_to and tcp.add have UB for the error cases to justify hoisting the second shape.assert_no_error past the first block of tcp.broadcast_to/tcp.add
- the same UB behavior of tcp.broadcast_to and tcp.add to hoist the "tcp.broadcast_to"(%unclean_rhs2, %broadcasted_s) op as well
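Mechanically, the first bullet is just the usual "hoist a side-effect-free op whose operands dominate the target" check. A minimal sketch, assuming the standard MLIR C++ APIs (the UB-based reorderings in the other two bullets need their own justification and are not captured by this check):

```cpp
// Minimal sketch of hoisting a side-effect-free op (e.g. shape.shape_of or
// shape.broadcast_shape) to an earlier insertion point. Not project code;
// just the generic legality check plus the move.
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

// Moves `op` immediately before `insertionPoint` if it has no memory effects
// and all of its operands already dominate the new position.
static bool tryHoistBefore(Operation *op, Operation *insertionPoint,
                           DominanceInfo &domInfo) {
  if (!isMemoryEffectFree(op))
    return false;
  for (Value operand : op->getOperands())
    if (!domInfo.properlyDominates(operand, insertionPoint))
      return false;
  op->moveBefore(insertionPoint);
  return true;
}
```

After all of these reorderings, we end up with: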
```mlir
%ul_s = shape.shape_of %unclean_lhs
%ur_s = shape.shape_of %unclean_rhs
%ur2_s = shape.shape_of %unclean_rhs2
%broadcasted0_s = "shape.broadcast_shape"(%ul_s, %ur_s)
%broadcasted_s = "shape.broadcast_shape"(%broadcasted0_s, %ur2_s)
"shape.assert_no_error"(%broadcasted0_s)
"shape.assert_no_error"(%broadcasted_s)
%clean_lhs = "tcp.broadcast_to"(%unclean_lhs, %broadcasted0_s)
%clean_rhs = "tcp.broadcast_to"(%unclean_rhs, %broadcasted0_s)
%clean_rhs2 = "tcp.broadcast_to"(%unclean_rhs2, %broadcasted_s)
%result0 = "tcp.add"(%clean_lhs, %clean_rhs) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%clean_result0 = "tcp.broadcast_to"(%result0, %broadcasted_s)
%result = "tcp.add"(%clean_result0, %clean_rhs2) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
```
Now we have all the tcp.add ops in their own nice little SSA use-def subgraph, except for the intervening tcp.broadcast_to (`%clean_result0 = "tcp.broadcast_to"(%result0, %broadcasted_s)`). That one seems nontrivial to remove. For example, suppose that dynamically the types are:
```
%unclean_lhs: tensor<1x1>
%unclean_rhs: tensor<1x1>
%unclean_rhs2: tensor<3x3>
```
Then given the lowering as it stands now, the first tcp.add will return tensor<1x1> and the intervening tcp.broadcast_to will in fact need to dynamically perform a 1x1 → 3x3 broadcast. However, it could also be:
```
%unclean_lhs: tensor<3x3>
%unclean_rhs: tensor<3x3>
%unclean_rhs2: tensor<3x3>
```
Then there is no broadcast needed.
@nicolasvasilache how would it be best to model in linalg this possibility of dynamically different broadcasting behavior, depending on whether any of the inputs have dimensions that are dynamically of size 1?