tosa.reduce_sum with a tuple axis

Hello. The TOSA dialect's tosa.reduce_sum doesn't support a tuple axis (the axis must be a single integer). What should I do if my goal is to implement something like "tosa.reduce_sum"(%arg0) {axis = (0, 2)}, where %arg0 is a tensor<2x3x4xf32>?

I know one option is to invoke reduce_sum twice, like this:

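// TOSA keeps each reduced axis with size 1, so axis 2 still names the last dimension of %0.
%0 = "tosa.reduce_sum"(%arg0) {axis = 0} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32>
%1 = "tosa.reduce_sum"(%0) {axis = 2} : (tensor<1x3x4xf32>) -> tensor<1x3x1xf32>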
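To see why this decomposition is exact, here is a minimal pure-Python sketch (not TOSA or MLIR code) that models a rank-preserving reduce_sum on nested lists, in the way TOSA keeps the reduced axis with size 1. All helper names (`shape_of`, `build`, `reduce_sum_keepdims`, `reduce_sum_multi`) are hypothetical, chosen only for illustration, and the example uses integers so the results compare exactly.

```python
from itertools import product


def shape_of(t):
    """Return the shape of a rectangular nested list."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def get(t, idx):
    """Read the element at multi-index idx."""
    for i in idx:
        t = t[i]
    return t


def build(shape, fn, prefix=()):
    """Build a nested list of the given shape; element idx is fn(idx)."""
    if len(prefix) == len(shape):
        return fn(prefix)
    return [build(shape, fn, prefix + (i,)) for i in range(shape[len(prefix)])]


def reduce_sum_keepdims(t, axis):
    """Sum over a single axis, keeping it with size 1 (output has the input's rank)."""
    in_shape = shape_of(t)
    out_shape = in_shape[:axis] + (1,) + in_shape[axis + 1:]

    def elem(idx):
        return sum(get(t, idx[:axis] + (k,) + idx[axis + 1:])
                   for k in range(in_shape[axis]))

    return build(out_shape, elem)


def reduce_sum_multi(t, axes):
    """Reference: sum over several axes at once, keeping each with size 1."""
    in_shape = shape_of(t)
    out_shape = tuple(1 if d in axes else n for d, n in enumerate(in_shape))
    ranges = [range(in_shape[d]) for d in axes]

    def elem(idx):
        total = 0
        for ks in product(*ranges):
            full = list(idx)
            for d, k in zip(axes, ks):
                full[d] = k
            total += get(t, full)
        return total

    return build(out_shape, elem)


# A 2x3x4 tensor holding 0..23, mirroring tensor<2x3x4xf32> in the question.
x = build((2, 3, 4), lambda idx: idx[0] * 12 + idx[1] * 4 + idx[2])
two_step = reduce_sum_keepdims(reduce_sum_keepdims(x, 0), 2)
print(two_step)                               # [[[60], [92], [124]]]
print(two_step == reduce_sum_multi(x, (0, 2)))  # True
```

Because every reduction keeps the rank, the axis numbers never shift between steps, and the order of the single-axis reductions does not change the mathematical result. With floating-point data the summation order does differ, so a fused kernel and the two-step chain may differ in the last bits of rounding.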

But won't two reduce_sum ops perform worse than a single op?

One of the goals of TOSA is to keep the operator set as minimal as possible to enable simple implementations. As you described, a reduction across multiple axes can be decomposed into a chain of single-axis reduce operators, so we decided to require a single axis.
If a backend can perform a reduction across multiple axes in one step, it should be reasonably straightforward to pattern-match the chain of reduce operators and emit the optimized version.
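As a rough illustration of that pattern match, the sketch below works on a tiny hypothetical IR (a list of `ReduceSum` records in program order, with SSA-style value names); it is not MLIR's API, where the equivalent would typically be a C++ `OpRewritePattern`. The multi-axis `ReduceSum` it produces stands for a backend's own fused reduction, not a TOSA op. It assumes an intermediate can be folded only when exactly one op consumes it and it is not a function output.

```python
from collections import Counter
from dataclasses import dataclass
from typing import List, Set, Tuple


@dataclass(frozen=True)
class ReduceSum:
    """Hypothetical IR node: result = sum of src over axes, keeping rank."""
    result: str
    src: str
    axes: Tuple[int, ...]


def fuse_reduce_chains(ops: List[ReduceSum], outputs: Set[str]) -> List[ReduceSum]:
    """Fold reduce_sum(reduce_sum(x, a), b) into one reduce over a and b."""
    uses = Counter(op.src for op in ops)
    producer = {}  # value name -> op in `fused` that defines it
    fused: List[ReduceSum] = []
    for op in ops:
        prev = producer.get(op.src)
        if prev is not None and uses[op.src] == 1 and op.src not in outputs:
            # The intermediate has no other users, so drop its producer and
            # reduce the original input over the union of both axis sets.
            # Rank-preserving reductions mean axis numbers stay comparable.
            fused.remove(prev)
            del producer[op.src]
            merged_axes = tuple(sorted(set(prev.axes) | set(op.axes)))
            op = ReduceSum(op.result, prev.src, merged_axes)
        fused.append(op)
        producer[op.result] = op
    return fused


ops = [ReduceSum("%0", "%arg0", (0,)), ReduceSum("%1", "%0", (2,))]
print(fuse_reduce_chains(ops, outputs={"%1"}))
# [ReduceSum(result='%1', src='%arg0', axes=(0, 2))]
```

If `%0` were also returned or used elsewhere, the chain is left unchanged, since removing it would lose a value that is still needed.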


I got it, thanks.