I have a use case where I would like to take a (possibly unranked) tensor and branch on its rank in order to handle rank-0 specially. The natural way to do this would be to check the rank and use scf.if as the branching mechanism. The type can be safely refined in the body of the if-statement using tensor.cast.
This works fine in the case where the given tensor (arg) is in fact unranked:
func.func @foo(%arg : tensor<*xf32>) {
  %c0 = arith.constant 0 : index
  %rank = tensor.rank %arg : tensor<*xf32>
  %0 = arith.cmpi eq, %rank, %c0 : index
  %1 = scf.if %0 -> (i1) {
    // Cast arg to rank-0 and use it
    %3 = tensor.cast %arg : tensor<*xf32> to tensor<f32>
    do_something_with_rank_zero_arg(%3)
  } else {
    // Do something with rank > 0 tensor
  }
}
However, if arg is not unranked, but statically known to have a fixed shape like tensor<2xf32>, this same process constructs code that fails to verify:
func.func @foo(%arg : tensor<2xf32>) {
  %c0 = arith.constant 0 : index
  %rank = tensor.rank %arg : tensor<2xf32>
  %0 = arith.cmpi eq, %rank, %c0 : index
  %1 = scf.if %0 -> (i1) {
    // Cast arg to rank-0 and use it
    // error: 'tensor.cast' op operand type 'tensor<2xf32>' and result type 'tensor<f32>' are cast incompatible
    %3 = tensor.cast %arg : tensor<2xf32> to tensor<f32>
    do_something_with_rank_zero_arg(%3)
  } else {
    // Do something with rank > 0 tensor
  }
}
There are two potential ways I've found to make this work:
- Manually constant fold away the scf.if statement when the rank is statically known. That works for simple cases, but may just push the verification failures to later in the pipeline (if some types are ever refined) and complicates the lowering process.
- Use a different cast operator which lacks a similar verifier check. In this case, tosa.cast works.
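To make the first option concrete, here is a rough sketch of what the lowering would emit for the tensor<2xf32> case if it inspected the static type of arg before building anything. It assumes the lowering can see that the rank is 1 at construction time, so neither the scf.if nor the rank-0 tensor.cast is ever created. The body is left elided, as in the examples above, and this is only an illustration of the intended output, not something I have run through a particular MLIR version:

```mlir
func.func @foo(%arg : tensor<2xf32>) {
  // The rank is statically known to be 1, so the lowering skips the
  // tensor.rank / arith.cmpi / scf.if sequence entirely and emits only
  // the rank > 0 path. No tensor.cast to tensor<f32> exists, so there is
  // nothing for the verifier to reject.
  // ... rank > 0 handling goes here (elided) ...
  return
}
```

The downside mentioned in the first bullet still applies: the unranked case needs the full scf.if form, so the lowering ends up carrying two code paths.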
I've not encountered them, but I suspect there are other potential issues, such as the fact that tensor.cast is ConditionallySpeculatable, and so can be hoisted out of the guarding branches.
Is there a known way to write lowerings like this? It would be nice to just write the code I want during lowering and let the dead code fold away without worrying about accidentally triggering verifiers.