Several discussions here have touched upon the topic of dynamic broadcasts. I wanted to get some feedback on whether the prevalent broadcasting semantics adds unnecessary complications to the front-end and codegen. Specifically, I am talking about the broadcast semantics where, if one of the dimensions is of size 1, then that dimension is broadcast. For example:
# Snippet 1:
>>> c = np.add(a, b)
>>> a.shape
(7, 1, 6)
>>> b.shape
(7, 5, 6)
>>> c.shape
(7, 5, 6)
This is effectively c[i][j][k] = a[i][0][k] + b[i][j][k].
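To make the size-1 semantics concrete, here is a small NumPy sketch (illustrative values only) checking that the implicit broadcast matches the explicit indexing above:

```python
import numpy as np

# Operands with the shapes from Snippet 1.
a = np.arange(7 * 1 * 6, dtype=np.float32).reshape(7, 1, 6)
b = np.arange(7 * 5 * 6, dtype=np.float32).reshape(7, 5, 6)

# Size-1 broadcast performed implicitly by np.add.
c = np.add(a, b)

# The same computation with the broadcast written out explicitly:
# c[i][j][k] = a[i][0][k] + b[i][j][k].
c_explicit = np.empty_like(b)
for i in range(7):
    for j in range(5):
        for k in range(6):
            c_explicit[i, j, k] = a[i, 0, k] + b[i, j, k]

assert c.shape == (7, 5, 6)
assert np.array_equal(c, c_explicit)
```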
One way of expressing this in XLA (which for the most part makes broadcasts explicit) is:
// Snippet 2:
%t = xla_hlo.broadcast_in_dim %a
  {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} :
  tensor<?x1x?xf32> -> tensor<?x?x?xf32>
%c = xla_hlo.add %t, %b :
  (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
In the above snippet I left all dimensions dynamic, except the broadcasting dimension. If that dimension were dynamic as well, i.e. tensor<?x?x?xf32> instead of tensor<?x1x?xf32>, then this falls into the "dynamic broadcast" category. This adds a lot more complexity to code generation, since you effectively need to check at runtime whether the size is 1, which changes the semantics of the operation.
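To illustrate that complexity, here is a hedged Python sketch (not actual generated code) of what a lowered elementwise add has to do when the shapes are only known at runtime: every dimension needs a dynamic size-1 check that changes how the operand is indexed.

```python
import numpy as np

def dynamic_broadcast_add(a, b):
    """Elementwise add with NumPy-style size-1 broadcasting, written
    the way generated code would have to handle it when all dims are
    dynamic: check each dim for size 1 at runtime."""
    assert a.ndim == b.ndim
    out_shape = tuple(max(sa, sb) for sa, sb in zip(a.shape, b.shape))
    out = np.empty(out_shape, dtype=a.dtype)
    for idx in np.ndindex(out_shape):
        # A size-1 dim is indexed with 0 regardless of the output
        # index -- this data-dependent select on the dimension size
        # is exactly the extra complexity described above.
        a_idx = tuple(0 if a.shape[d] == 1 else idx[d] for d in range(a.ndim))
        b_idx = tuple(0 if b.shape[d] == 1 else idx[d] for d in range(b.ndim))
        out[idx] = a[a_idx] + b[b_idx]
    return out

res = dynamic_broadcast_add(np.ones((3, 1, 4)), np.ones((3, 2, 4)))
assert res.shape == (3, 2, 4)
assert np.all(res == 2.0)
```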
An alternative would be to move away from the special status of size-1 dims and have a purer form of broadcasting which doesn't rely on it. So the above computation could be rewritten as:
// Snippet 3:
%t = xla_hlo.broadcast %a ... : tensor<?x?xf32> -> tensor<?x?x?xf32>
%c = xla_hlo.add %t, %b :
  (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
I am using xla_hlo.broadcast here, but its specification actually doesn't allow this kind of broadcast. It should be possible to generalize it to express such semantics, though. What you would need is a map from each dimension of the operand to the result dimensions it is broadcast to. The above broadcast would then be captured by the maps { 0 -> {0, 1}, 1 -> 2 }, which say that dims 0 and 1 of the result get their values from dim 0 of the operand, and that dim 2 of the result gets its value from dim 1 of the operand.
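As a rough executable model of this generalized broadcast (the function name and map encoding here are my own, not part of xla_hlo), one plausible reading of the map is that the first listed result dim indexes the operand dim and the value is replicated along the remaining listed dims, with no size-1 checks anywhere:

```python
import numpy as np

def pure_broadcast(operand, result_shape, dim_map):
    """Broadcast `operand` to `result_shape` given a
    {operand_dim: [result_dims]} map. Assumed reading: the first
    listed result dim carries the operand index; the value is
    replicated along the other listed result dims."""
    out = np.empty(result_shape, dtype=operand.dtype)
    for idx in np.ndindex(result_shape):
        op_idx = tuple(idx[dims[0]] for _, dims in sorted(dim_map.items()))
        out[idx] = operand[op_idx]
    return out

# The broadcast from Snippet 3: { 0 -> {0, 1}, 1 -> 2 }.
a = np.arange(7 * 6, dtype=np.float32).reshape(7, 6)
t = pure_broadcast(a, (7, 5, 6), {0: [0, 1], 1: [2]})
assert t.shape == (7, 5, 6)
# Every slice along the broadcast dim 1 is a copy of the operand.
assert np.array_equal(t[:, 0, :], a)
assert np.array_equal(t[:, 3, :], a)
```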
This specification is immune to the dynamic-broadcast problem that Snippet 2 suffers from. Even when expressed using Linalg, Snippet 3 can be written as:
#accesses =
  [affine_map<(d0, d1, d2) -> (d0, d2)>,
   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
   affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
#trait = { ...,
  indexing_maps = #accesses, ...
}
%c = linalg.generic #trait %a, %b {
^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
  %2 = addf %arg2, %arg3 : f32
  linalg.yield %2 : f32
} : (tensor<?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
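As a sanity check on the indexing maps, here is an illustrative Python model of the loop nest this generic op describes (values are placeholders): a is read at (d0, d2), while b and c use (d0, d1, d2), and the loop bounds come from the operand shapes.

```python
import numpy as np

# Illustrative operands matching tensor<?x?xf32> and tensor<?x?x?xf32>.
a = np.arange(7 * 6, dtype=np.float32).reshape(7, 6)
b = np.ones((7, 5, 6), dtype=np.float32)
c = np.empty_like(b)

# Loop bounds inferred from the operand shapes.
D0, D1, D2 = b.shape
for d0 in range(D0):
    for d1 in range(D1):
        for d2 in range(D2):
            # Apply the affine maps: a at (d0, d2), b/c at (d0, d1, d2).
            c[d0, d1, d2] = a[d0, d2] + b[d0, d1, d2]

# Matches NumPy's implicit broadcast after re-inserting the unit dim.
assert np.array_equal(c, a[:, None, :] + b)
```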
Apart from the advantage of avoiding the dynamic broadcasting problem, the above Linalg specification is IMO more canonical. It plays well with Linalg's approach of determining loop bounds based on the shapes of the operands (not going into more detail here). Indeed, I spent a few days prototyping canonicalizations and foldings at the Linalg level to undo the presence of such size-1 dims, since they lead to problems with loop-range inference. (Looking back at the original computation to see why this happened in the first place led me to the aspects highlighted in this post.)
My takeaway from this is that the special semantics for size-1 dims are problematic. They seem to be a complication introduced by abstractions rather than anything fundamental to the computational needs of the end user. In my experience, end users typically end up doing a lot of reshapes to get the shape into the right form for broadcasting. (This also leads to other problems of having to deal with reshapes in codegen, which deserves a whole other post.) If this is not the case, I'd be happy to know about more use cases that do not allow the usage of broadcast semantics similar to Snippet 3.
I have two questions here:
- Can we rely on the dialects/conversions at the very top of the stack (like TCP/TCF) to take inputs that might have such size-1-dim-based broadcast semantics and "canonicalize" them away, so that broadcast operations similar to Snippet 3 are used instead? Though I suspect that in the dynamic broadcasting case there is very little that can be done.
- I think it would actually benefit end users as well to not have to use size-1-dim-based broadcasting. I would like to know more cases where the computation cannot be expressed using a purer broadcasting form.
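To make the first question concrete, here is a hypothetical sketch (names and encoding are mine) of the kind of frontend canonicalization I have in mind: when shapes are static, squeeze the unit dims out of the operand and record an operand-dim-to-result-dim map in the spirit of Snippet 3. In the fully dynamic case this rewrite is not possible, which is exactly the caveat above.

```python
def canonicalize_unit_dims(shape, result_shape):
    """Given an operand shape with statically-known size-1 broadcast
    dims and the target result shape, return the squeezed operand
    shape plus an {operand_dim: result_dim} map (Snippet 3 style).
    Only applicable when the unit dims are static."""
    squeezed, dim_map = [], {}
    for d, (s, r) in enumerate(zip(shape, result_shape)):
        if s == 1 and r != 1:
            continue  # broadcast dim: dropped from the operand
        dim_map[len(squeezed)] = d
        squeezed.append(s)
    return tuple(squeezed), dim_map

# The tensor<7x1x6> operand broadcast to 7x5x6 from Snippet 1:
shape, dim_map = canonicalize_unit_dims((7, 1, 6), (7, 5, 6))
assert shape == (7, 6)
assert dim_map == {0: 0, 1: 2}  # matches the (d0, d2) indexing map
```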