Verifiers and conditionally legal operations

I have a use case where I would like to take a (possibly unranked) tensor and branch on its rank, in order to handle rank-0 specially. The natural way to do this would be to check the rank and use scf.if as the branching mechanism. The type can be safely refined in the body of the if-statement using tensor.cast.

This works fine in the case where the given tensor (arg) is in fact unranked:

func.func @foo(%arg : tensor<*xf32>) {
  %c0 = arith.constant 0 : index
  %rank = tensor.rank %arg : tensor<*xf32>
  %0 = arith.cmpi eq, %rank, %c0 : index
  scf.if %0 {
    // Cast arg to rank-0 and use it
    %3 = tensor.cast %arg : tensor<*xf32> to tensor<f32>
    // Pseudocode placeholder for whatever consumes the rank-0 tensor
    do_something_with_rank_zero_arg(%3)
  } else {
    // Do something with rank > 0 tensor
  }
  return
}

However, if the arg is not unranked, but statically known to have a fixed shape like tensor<2xf32>, this same process constructs code that fails to verify:

func.func @foo(%arg : tensor<2xf32>) {
  %c0 = arith.constant 0 : index
  %rank = tensor.rank %arg : tensor<2xf32>
  %0 = arith.cmpi eq, %rank, %c0 : index
  scf.if %0 {
    // Cast arg to rank-0 and use it
    // error: 'tensor.cast' op operand type 'tensor<2xf32>' and result type 'tensor<f32>' are cast incompatible
    %3 = tensor.cast %arg : tensor<2xf32> to tensor<f32>
    // Pseudocode placeholder for whatever consumes the rank-0 tensor
    do_something_with_rank_zero_arg(%3)
  } else {
    // Do something with rank > 0 tensor
  }
  return
}

There are two potential ways I’ve found to make this work:

  1. Manually constant fold away the scf.if statement when the rank is statically known. That works for simple cases, but may just push the verification failures to later in the pipeline (if some types are ever refined) and complicates the lowering process.
  2. Use a different cast operator which lacks a similar verifier check. In this case tosa.cast works.

I’ve not encountered them, but I suspect there are other potential issues, such as the fact that tensor.cast is ConditionallySpeculatable, and so can be hoisted out of the guarding branches.

Is there a known way to write lowerings like this? It would be nice to just write the code I want during lowering and let the dead code fold away without worrying about accidentally triggering verifiers.

You could write it by casting the input tensor to unranked explicitly in the if branch maybe? It’ll be a no-op when the tensor is actually unranked, but make the code legal otherwise.
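For reference, here is a minimal sketch of what that suggestion might look like for the statically shaped case. The op name "my.use_rank_zero" is a hypothetical placeholder, not a real op, so this would need --allow-unregistered-dialect to parse. The value names are also just illustrative.

```mlir
func.func @foo(%arg : tensor<2xf32>) {
  %c0 = arith.constant 0 : index
  %rank = tensor.rank %arg : tensor<2xf32>
  %is_rank0 = arith.cmpi eq, %rank, %c0 : index
  scf.if %is_rank0 {
    // First erase the static shape. Ranked -> unranked is always cast compatible.
    %unranked = tensor.cast %arg : tensor<2xf32> to tensor<*xf32>
    // Unranked -> rank-0 is also cast compatible, so the verifier accepts it.
    %scalar = tensor.cast %unranked : tensor<*xf32> to tensor<f32>
    // Hypothetical consumer of the rank-0 tensor (placeholder op).
    "my.use_rank_zero"(%scalar) : (tensor<f32>) -> ()
  } else {
    // Do something with rank > 0 tensor
  }
  return
}
```

When %arg is already tensor<*xf32>, the first cast is an identity cast, so the same lowering code can be emitted for both the ranked and unranked cases.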

That does indeed satisfy the verifiers. My initial concern was that such a cast sequence would likely be eliminated by canonicalization, putting us back in the same situation a little later down the line. It looks like the canonicalization for chained casts explicitly accounts for cases like this, though.