Hello,
I have a question regarding the supported implementations of the Cast operation. According to the specification, cast is not defined for any float->bool combination (for good reason, I suppose).
However, in the TOSA dialect, a cast from float->bool is accepted as valid (it passes tosa-validate) and gets lowered to linalg:
// -----// IR Dump After TosaValidation (tosa-validate) ('builtin.module' operation) //----- //
#loc = loc(unknown)
module {
func.func @main(%arg0: tensor<8x2xf16> loc(unknown)) -> tensor<*xi1> {
%0 = tosa.cast %arg0 : (tensor<8x2xf16>) -> tensor<8x2xi1> loc(#loc1)
%cast = tensor.cast %0 : tensor<8x2xi1> to tensor<*xi1> loc(#loc1)
return %cast : tensor<*xi1> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc1 = loc("")
// -----// IR Dump After TosaToLinalg (tosa-to-linalg) ('func.func' operation: @main) //----- //
#loc = loc(unknown)
#loc1 = loc("")
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @main(%arg0: tensor<8x2xf16> loc(unknown)) -> tensor<*xi1> {
%0 = tensor.empty() : tensor<8x2xi1> loc(#loc1)
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<8x2xf16>) outs(%0 : tensor<8x2xi1>) {
^bb0(%in: f16 loc(unknown), %out: i1 loc("")):
%cst = arith.constant 0.000000e+00 : f16 loc(#loc1)
%2 = arith.cmpf une, %in, %cst : f16 loc(#loc1)
linalg.yield %2 : i1 loc(#loc1)
} -> tensor<8x2xi1> loc(#loc1)
%cast = tensor.cast %1 : tensor<8x2xi1> to tensor<*xi1> loc(#loc1)
return %cast : tensor<*xi1> loc(#loc)
} loc(#loc)
} loc(#loc)
Do you think the dialect should detect that this isn’t a valid combination?