Eliding casts

I have an IR snippet that looks like this:

%casted = tensor_cast %input : tensor<...> to tensor<...>
%element = extract_element %casted[...]

extract_element isn’t sensitive to the static type of its input tensor, so eliding the cast is fine.

One way to do this would be to walk all uses of a tensor_cast and call a function like this with the source type:

// Return true if `operand` can be changed to type `type`.
bool canChangeToType(OpOperand &operand, Type type);

If the predicate returns true for every use, then the cast can be folded.
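A minimal sketch of the walk-all-uses idea, using a hypothetical toy model rather than the MLIR API: the `Use` and `CastOp` structs, the string-typed "types", and the specific policy inside `canChangeToType` are all illustrative assumptions. It shows only the control flow: ask each use whether it tolerates the cast's source type, and fold only if every use says yes.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy model, not the MLIR API: types are plain strings and each
// use records only the name of the op that consumes the value.
struct Use {
  std::string opName;  // e.g. "extract_element"
  int operandIndex;    // which operand of the user this use occupies
};

struct CastOp {
  std::string sourceType;  // static type before the cast
  std::string resultType;  // static type after the cast
  std::vector<Use> uses;   // every use of the cast's result
};

// Toy stand-in for the canChangeToType predicate. Assumption: extract_element
// only depends on the element type, which tensor_cast never changes, so it
// accepts any shape; every other op is conservatively rejected.
bool canChangeToType(const Use &use, const std::string & /*type*/) {
  return use.opName == "extract_element";
}

// The cast can be elided only if every use accepts the cast's source type.
bool canElideCast(const CastOp &cast) {
  for (const Use &use : cast.uses)
    if (!canChangeToType(use, cast.sourceType))
      return false;
  return true;
}
```

A cast whose uses are all extract_element ops is foldable; a single use by an op the predicate does not recognize blocks the fold.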

Another approach is to change the result type of %input’s producer, which needs a similar predicate on OpResult.
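A rough sketch of the producer-side alternative, again as a hypothetical toy model (the `Value` struct, both predicates, and the choice of tensor_cast as the only retypeable producer are assumptions for illustration). The extra wrinkle it highlights is that retyping the producer retypes the value for all of its users, so the operand-side predicate is still needed for the users other than the cast.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy model, not the MLIR API: ops are identified by name and
// static types are plain strings such as "tensor<10xf32>".
struct Value {
  std::string producerOp;              // op whose result this value is
  std::string type;                    // current static result type
  std::vector<std::string> otherUsers; // users of the value besides the cast
};

// Toy analogue of an OpResult predicate. Assumption: only another tensor_cast
// can freely change its result type; the toy does not check that the new type
// is compatible with that inner cast's own source.
bool canChangeResultType(const Value &v, const std::string & /*type*/) {
  return v.producerOp == "tensor_cast";
}

// Operand-side predicate, as in the walk-all-uses approach.
bool canUseAcceptType(const std::string &userOp, const std::string & /*type*/) {
  return userOp == "extract_element";
}

// Retyping the producer also retypes every other use of the value, so each of
// those uses must accept the new type too. If this holds, the outer cast
// becomes a no-op and can be removed.
bool canRetypeProducer(const Value &input, const std::string &castResultType) {
  if (!canChangeResultType(input, castResultType))
    return false;
  for (const std::string &user : input.otherUsers)
    if (!canUseAcceptType(user, castResultType))
      return false;
  return true;
}
```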

Thoughts on how to approach this?

I didn’t get your question. Can you fill in the ... in your snippet? extract_element of course gives you a value of the element type of the tensor, so it is sensitive to the static type of the input tensor. Did you mean shape instead of type, and is your goal to eliminate the cast if extract_element is its only use?

tensor_cast can only change the shape, not the element type (both the shape and the element type are part of the “type” as I’m using that term here). Sorry, I didn’t make that clearer. Consider this rewritten snippet:

%c0 = constant 0 : index
%casted = tensor_cast %input : tensor<?xf32> to tensor<10xf32>
%element = extract_element %casted[%c0] : tensor<10xf32>

The observation is that extract_element is not sensitive to the static shape of the input.

This situation is specific to types like tensor, where multiple static types can refer to the same dynamic (runtime) type. If a cast merely changes the static type (as tensor_cast does) without implying any runtime behavior, then it is safe to remove the cast as long as every use can accept the cast’s source type.
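To make the "many static types, one runtime type" point concrete, here is a small sketch of a compatibility check between two ranked tensor types. It is meant to mirror, as I understand it, the rule tensor_cast relies on (same element type, shapes agreeing wherever both dimensions are static); the `TensorType` struct and the `kDynamic` sentinel are hypothetical, and unranked tensors are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical sentinel standing in for a `?` dimension.
constexpr int64_t kDynamic = -1;

// Hypothetical toy representation of a ranked tensor type.
struct TensorType {
  std::vector<int64_t> shape;
  std::string elementType;
};

// Two static types can describe the same runtime tensor if the element types
// match, the ranks match, and each dimension is equal or dynamic on one side.
bool castCompatible(const TensorType &a, const TensorType &b) {
  if (a.elementType != b.elementType)
    return false;
  if (a.shape.size() != b.shape.size())
    return false;
  for (std::size_t i = 0; i < a.shape.size(); ++i)
    if (a.shape[i] != kDynamic && b.shape[i] != kDynamic &&
        a.shape[i] != b.shape[i])
      return false;
  return true;
}
```

Under this check, tensor<?xf32> and tensor<10xf32> are compatible, which is why a tensor_cast between them changes nothing at runtime; tensor<10xf32> and tensor<5xf32> are not.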

Seems like a canonicalization on extract_element, maybe?

EDIT: ignore this post.

That is of limited applicability. Consider the case where two extract_element ops use the same tensor (or where the tensor is used by some other op with the same property as extract_element).

Oh, sorry, I spoke too soon. I see what you mean, and I think that would work reasonably well. Though it is a bit sad to have to write that for every op that might use tensor_cast.
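The per-op canonicalization discussed above can be sketched with another hypothetical toy model (the `Op` struct and both helpers are illustrative, not MLIR's rewrite-pattern API). Each extract_element independently rewires its tensor operand to the cast's source, so two uses of the same cast are each handled, and the cast is left with no uses for dead-code elimination to erase.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy model, not the MLIR API: each op records its name and the
// ops that define its operands.
struct Op {
  std::string name;
  std::vector<Op *> operands;
};

// Toy canonicalization on extract_element: if the tensor operand is produced
// by a tensor_cast, read from the cast's source instead. Returns true if it
// rewrote anything.
bool foldCastIntoExtract(Op &extract) {
  if (extract.name != "extract_element" || extract.operands.empty())
    return false;
  Op *tensor = extract.operands[0];
  if (tensor == nullptr || tensor->name != "tensor_cast" ||
      tensor->operands.empty())
    return false;
  extract.operands[0] = tensor->operands[0];
  return true;
}

// Number of operand slots across `ops` that refer to `value`.
int countUses(const std::vector<Op *> &ops, const Op *value) {
  int n = 0;
  for (const Op *op : ops)
    for (const Op *operand : op->operands)
      if (operand == value)
        ++n;
  return n;
}
```

The trade-off raised in the thread is visible here: the fold lives on the consumer, so each op that tolerates a shape change needs its own copy of this rule.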