[RFC] Clarifying `IndexType`'s role in `arith` rewrites

Hello,

We recently found a couple of places where IndexType’s bit width plays a role in transformations.

Specifically,

// index_cast(index_cast(x)) -> x, if dstType == srcType.
def IndexCastOfIndexCast :
    Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x)),
        (replaceWithValue $x),
        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
// index_castui(index_castui(x)) -> x, if dstType == srcType.
def IndexCastUIOfIndexCastUI :
    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
        (replaceWithValue $x),
        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;

OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  // ...
}

OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  // ...
}
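To make the 64-bit default concrete, here is a small, self-contained C++ sketch. It does not use MLIR: it assumes a toy model where an integer of a given width (1 to 64 bits) is carried as its two's-complement value in an int64_t, and the helper names (signExtend, indexCast, foldIndexCastToIndex) are hypothetical. It compares what a folder that hard-codes a 64-bit index produces for `index_cast` of the i64 constant 2^32 with what the cast yields at run time on a target whose index is 32 bits wide.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model, not MLIR code: an integer of width `bits` (1..64) is
// carried as its canonical two's-complement value in an int64_t.

// Reinterpret the low `bits` bits of `v` as a signed `bits`-wide value.
constexpr int64_t signExtend(uint64_t v, unsigned bits) {
  if (bits >= 64)
    return static_cast<int64_t>(v);
  const uint64_t mask = (uint64_t{1} << bits) - 1;
  uint64_t low = v & mask;
  if ((low >> (bits - 1)) & 1)
    low |= ~mask; // copy the sign bit into the upper bits
  return static_cast<int64_t>(low);
}

// Run-time semantics of arith.index_cast into a `dstBits`-wide type:
// sign-extend when widening, truncate when narrowing.
constexpr int64_t indexCast(int64_t v, unsigned dstBits) {
  return signExtend(static_cast<uint64_t>(v), dstBits);
}

// What a folder that hard-codes a 64-bit index computes for an index result.
constexpr int64_t foldIndexCastToIndex(int64_t constant) {
  const unsigned resultBitwidth = 64; // Default for index integer attributes.
  return indexCast(constant, resultBitwidth);
}

// %c = arith.constant 4294967296 : i64
// %r = arith.index_cast %c : i64 to index
constexpr int64_t kTwoPow32 = int64_t{1} << 32;
static_assert(foldIndexCastToIndex(kTwoPow32) == kTwoPow32,
              "the folder keeps all 64 bits");
static_assert(indexCast(kTwoPow32, 32) == 0,
              "a 32-bit index truncates the value to 0");
```

The folded constant and the run-time value disagree as soon as index is narrower than 64 bits.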

IndexType’s bit width is not well defined

We understand that IndexType has had different meanings for different people:

  • The index type means “platform specific integer width”

  • index is “an integer the size of a pointer in the default address space”

  • index is an “infinite precision” integer

For additional context, index was originally called affineint, was usable within “MLFunc” (first-class affine function), and had vague semantics of a “mathematical integer suitable for affine reasoning”. It was then retrofitted to be a sort of intptr_t/size_t mix as we moved away from first-class affine.

Unsoundness of rewrites based on IndexType’s semantics above

Returning to the transformations mentioned above: under every interpretation listed here, the canonicalization and folding patterns are unsound.

  • The folding patterns assume that index is 64 bits wide.

  • The canonicalization patterns are unsound when one casts to a type narrower than index and then back to index: since the first cast narrows, information can be lost. (Symmetrically, a round trip iN -> index -> iN loses information when index is narrower than iN.) For example, if index has infinite precision or is 16 bits wide:

%0 = arith.constant 0x8000 : index
%1 = arith.index_cast %0 : index to i8
%2 = arith.index_cast %1 : i8 to index
// %2 == 0, but the canonicalization replaces %2 with %0 (0x8000)
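The example can be checked with the same kind of toy C++ model (again an assumption, not MLIR code: an integer of a given width is carried as its two's-complement value in an int64_t). `roundTrip` is an illustrative name for the index -> i8 -> index sequence above; a 64-bit index stands in for a very wide one.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model, not MLIR code: an integer of width `bits` (1..64) is
// carried as its canonical two's-complement value in an int64_t.
constexpr int64_t signExtend(uint64_t v, unsigned bits) {
  if (bits >= 64)
    return static_cast<int64_t>(v);
  const uint64_t mask = (uint64_t{1} << bits) - 1;
  uint64_t low = v & mask;
  if ((low >> (bits - 1)) & 1)
    low |= ~mask; // copy the sign bit into the upper bits
  return static_cast<int64_t>(low);
}

// arith.index_cast into a `dstBits`-wide type (sign-extend or truncate).
constexpr int64_t indexCast(int64_t v, unsigned dstBits) {
  return signExtend(static_cast<uint64_t>(v), dstBits);
}

// %1 = arith.index_cast %0 : index to i8
// %2 = arith.index_cast %1 : i8 to index
constexpr int64_t roundTrip(int64_t x, unsigned indexBits) {
  const int64_t narrow = indexCast(x, 8);
  return indexCast(narrow, indexBits);
}

// 64-bit index (a stand-in for a very wide one): %0 is 32768 but %2 is 0.
static_assert(roundTrip(0x8000, 64) == 0, "%2 is 0, not 0x8000");
// 16-bit index: the bit pattern 0x8000 is the value -32768.
static_assert(signExtend(0x8000, 16) == -32768, "0x8000 as a 16-bit index");
static_assert(roundTrip(signExtend(0x8000, 16), 16) == 0, "%2 is 0 here too");
```

In both cases the rewrite `index_cast(index_cast(x)) -> x` would change the observed value.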

Proposal for an exact keyword in arith.index_cast and arith.index_castui

We propose making the unsound transformations above sound by introducing the exact keyword, with the following semantics:

If the exact attribute is present, it is assumed that the index type width is such that the conversion does not lose information. When this assumption is violated, the result is poison.
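One way to read these semantics is as a checked conversion. The C++ sketch below is a toy model, not the proposed implementation: it assumes values are carried as their canonical value (int64_t for the signed view used by `index_cast`, uint64_t for the unsigned view used by `index_castui`), that `std::nullopt` stands for poison, and the function names are hypothetical. For a concrete index width, a cast marked exact yields poison exactly when converting to the destination width would change the value.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model, not MLIR code: a value of width 1..64 bits is carried as
// its canonical value (int64_t for the signed view, uint64_t for the unsigned
// view), and std::nullopt stands for poison.

// Keep only the low `bits` bits of `v`.
constexpr uint64_t truncBits(uint64_t v, unsigned bits) {
  return bits >= 64 ? v : (v & ((uint64_t{1} << bits) - 1));
}

// Reinterpret the low `bits` bits of `v` as a signed `bits`-wide value.
constexpr int64_t signExtend(uint64_t v, unsigned bits) {
  if (bits >= 64)
    return static_cast<int64_t>(v);
  uint64_t low = truncBits(v, bits);
  if ((low >> (bits - 1)) & 1)
    low |= ~((uint64_t{1} << bits) - 1);
  return static_cast<int64_t>(low);
}

// index_cast {exact} to a dstBits-wide type (index or iN): sign-extend or
// truncate, but yield poison if that changes the value.
std::optional<int64_t> indexCastExact(int64_t v, unsigned dstBits) {
  const int64_t r = signExtend(static_cast<uint64_t>(v), dstBits);
  if (r != v)
    return std::nullopt;
  return r;
}

// index_castui {exact}: zero-extend or truncate, poison if a set bit is dropped.
std::optional<uint64_t> indexCastUIExact(uint64_t v, unsigned dstBits) {
  const uint64_t r = truncBits(v, dstBits);
  if (r != v)
    return std::nullopt;
  return r;
}
```

For instance, `indexCastExact(128, 8)` is poison (128 does not fit in a signed i8), while `indexCastUIExact(255, 8)` is 255.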

The rewrites above can then be reformulated as:

// index_cast(index_cast(x, exact)) -> x, if dstType == srcType.
// The inner exact guarantees that the inner conversion is lossless,
// so the round trip through the intermediate type preserves the value.
def IndexCastOfIndexCast :
    Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x, $exact1), $exact2),
        (replaceWithValue $x),
        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
         (Constraint<CPred<"static_cast<bool>($0)">> $exact1)]>;

// index_castui(index_castui(x, exact)) -> x, if dstType == srcType.
// The inner exact guarantees that the inner conversion is lossless,
// so the round trip through the intermediate type preserves the value.
def IndexCastUIOfIndexCastUI :
    Pat<(Arith_IndexCastUIOp:$res
          (Arith_IndexCastUIOp $x, $nneg1, $exact1), $nneg2, $exact2),
        (replaceWithValue $x),
        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
         (Constraint<CPred<"static_cast<bool>($0)">> $exact1)]>;

OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant, exact) -> constant
  // ...
}

and

OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant, exact) -> constant
  // ...
}
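The reformulated patterns can be sanity-checked by brute force. This is a toy C++ model rather than MLIR code: it assumes canonical int64_t values, `std::nullopt` for poison, and illustrative function names. For every pair of small widths, it round-trips each value through an intermediate width with an exact inner cast and confirms that whenever the inner cast is not poison, the outer cast returns the original value. The same rewrite without exact fails, for example, for a 16-bit index -> i8 -> index. The folders follow the same reasoning: when an exact cast is not poison, its result equals the constant for every index width, and when it is poison, any folded value is a valid refinement.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model, not MLIR code: values of width 1..64 bits are carried as
// canonical int64_t values; std::nullopt models poison.
constexpr int64_t signExtend(uint64_t v, unsigned bits) {
  if (bits >= 64)
    return static_cast<int64_t>(v);
  const uint64_t mask = (uint64_t{1} << bits) - 1;
  uint64_t low = v & mask;
  if ((low >> (bits - 1)) & 1)
    low |= ~mask;
  return static_cast<int64_t>(low);
}

// index_cast (sign-extend or truncate) into a dstBits-wide type.
constexpr int64_t indexCast(int64_t v, unsigned dstBits) {
  return signExtend(static_cast<uint64_t>(v), dstBits);
}

// index_cast {exact}: poison when the value does not fit in dstBits.
std::optional<int64_t> indexCastExact(int64_t v, unsigned dstBits) {
  const int64_t r = indexCast(v, dstBits);
  if (r != v)
    return std::nullopt;
  return r;
}

// Checks index_cast(index_cast(x, exact)) -> x for every x of width srcBits
// (keep srcBits small so the loop stays short), via a midBits-wide type.
bool exactRoundTripSound(unsigned srcBits, unsigned midBits) {
  const int64_t lo = -(int64_t{1} << (srcBits - 1));
  const int64_t hi = (int64_t{1} << (srcBits - 1)) - 1;
  for (int64_t x = lo; x <= hi; ++x) {
    const std::optional<int64_t> inner = indexCastExact(x, midBits);
    if (!inner)
      continue; // poison: replacing it with x is a valid refinement
    if (indexCast(*inner, srcBits) != x)
      return false;
  }
  return true;
}

// The same check for the current pattern, which does not require exact.
bool plainRoundTripSound(unsigned srcBits, unsigned midBits) {
  const int64_t lo = -(int64_t{1} << (srcBits - 1));
  const int64_t hi = (int64_t{1} << (srcBits - 1)) - 1;
  for (int64_t x = lo; x <= hi; ++x)
    if (indexCast(indexCast(x, midBits), srcBits) != x)
      return false;
  return true;
}
```

Since the check is symmetric in which side is "index", `exactRoundTripSound(16, 8)` models index (16 bits) -> i8 -> index and `exactRoundTripSound(8, 16)` models i8 -> index -> i8.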

Some work toward this was done in llvm/llvm-project#183395, and related discussion happened on llvm/llvm-project#184631.

How to set exact?

Now the question is: when will these transformations fire if nothing sets the exact keyword? We propose providing a pass (e.g., arith-declare-index-bitwidth) that downstream users can run to set exact. It will:

  • optionally update the DLTI spec on the operation it runs on (if the index width isn’t already set);

  • then walk all index_cast and index_castui ops;

  • if the target type is at least as wide as the source type, set the exact flag;

  • if the target type is narrower than the source type, set the exact flag only if the operand value can be statically shown to fit in the target type (as a signed value for index_cast, as an unsigned value for index_castui), e.g., with a range analysis.
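The per-cast decision above can be sketched in C++. This is a hypothetical helper, not the proposed pass: it assumes the pass has already substituted the declared index width into `srcBits`/`dstBits`, and that a range analysis (modeled by the made-up `SignedRange`/`UnsignedRange` structs) may or may not have produced bounds for the operand.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical operand ranges, e.g. as produced by an integer range analysis.
struct SignedRange { int64_t min, max; };
struct UnsignedRange { uint64_t min, max; };

constexpr int64_t signedMin(unsigned bits) { // bits in [1, 64]
  return bits >= 64 ? std::numeric_limits<int64_t>::min()
                    : -(int64_t{1} << (bits - 1));
}
constexpr int64_t signedMax(unsigned bits) {
  return bits >= 64 ? std::numeric_limits<int64_t>::max()
                    : (int64_t{1} << (bits - 1)) - 1;
}
constexpr uint64_t unsignedMax(unsigned bits) {
  return bits >= 64 ? std::numeric_limits<uint64_t>::max()
                    : (uint64_t{1} << bits) - 1;
}

// index_cast: widening (or same width) sign-extends and never loses
// information; narrowing is exact only if the signed range fits.
bool canMarkIndexCastExact(unsigned srcBits, unsigned dstBits,
                           std::optional<SignedRange> range) {
  if (dstBits >= srcBits)
    return true;
  return range && range->min >= signedMin(dstBits) &&
         range->max <= signedMax(dstBits);
}

// index_castui: the same idea with zero extension and an unsigned range.
bool canMarkIndexCastUIExact(unsigned srcBits, unsigned dstBits,
                             std::optional<UnsignedRange> range) {
  if (dstBits >= srcBits)
    return true;
  return range && range->max <= unsignedMax(dstBits);
}
```

Without range information, only non-narrowing casts get the flag, which keeps the pass conservative.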

The DLTI dialect provides DataLayoutSpecAttr and DataLayoutEntryAttr, which allow the width of index to be set explicitly in the IR. For example:

module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 64 : i32>> } {
  // ... a module where index is 64 bits wide.
}

This lets users:

  • control when the folders and canonicalizations run;

  • make sure that the folders and canonicalizations are only applied when they are safe.

If the pass is not run, the folders and canonicalizations do not fire, since no operation carries the exact flag.

We also discussed allowing a lower and an upper bound on index’s bit width instead of a single value. This would allow compilers to apply these transformations even when the target width is not precisely known. This accommodates the conflicting semantics nicely:

  • If a downstream project treats index as an integer of infinite precision, it can set index’s bit width arbitrarily high (with the lower and upper bounds equal).

  • For an iN to index cast (index_cast or index_castui), if N is less than or equal to the lower bound, the cast never narrows for any admissible width, so we set the exact flag. Likewise, for an index to iM cast, if M is greater than or equal to the upper bound, we set the exact flag.

  • Otherwise, the cast may narrow for some admissible width, so one needs to statically determine that the operand value fits in the narrowest possible destination (lower-bound bits for iN to index, iM for index to iM), e.g., with a range analysis, before setting the exact flag.
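The bounds-based rules can be sketched the same way. This C++ sketch is hypothetical and only covers the signed `index_cast` case (`index_castui` is analogous with unsigned ranges); `IndexWidthBounds` and `SignedRange` are made-up stand-ins for the declared bounds and for a range-analysis result. The key observation it encodes is that, for casts into index, the narrowest admissible index (the lower bound) is the worst case, while for casts out of index the widest one (the upper bound) is.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical: index width is only known to lie in [lower, upper].
struct IndexWidthBounds { unsigned lower, upper; };
// Hypothetical signed operand range from a range analysis.
struct SignedRange { int64_t min, max; };

// Does every value in `r` fit in a signed `bits`-wide integer? (bits in [1, 64])
bool fitsSigned(SignedRange r, unsigned bits) {
  if (bits >= 64)
    return true;
  const int64_t lo = -(int64_t{1} << (bits - 1));
  const int64_t hi = (int64_t{1} << (bits - 1)) - 1;
  return r.min >= lo && r.max <= hi;
}

// index_cast iN -> index: the narrowest admissible index is the worst case.
bool canMarkToIndexExact(unsigned srcBits, IndexWidthBounds b,
                         std::optional<SignedRange> range) {
  if (srcBits <= b.lower)
    return true; // never narrows for any width in [lower, upper]
  return range && fitsSigned(*range, b.lower);
}

// index_cast index -> iM: the widest admissible index is the worst case.
bool canMarkFromIndexExact(unsigned dstBits, IndexWidthBounds b,
                           std::optional<SignedRange> range) {
  if (dstBits >= b.upper)
    return true; // never narrows for any width in [lower, upper]
  return range && fitsSigned(*range, dstBits);
}
```

With bounds {32, 64}, i32 -> index and index -> i64 are always exact, while i64 -> index and index -> i32 need a range that fits in 32 bits.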

Please let me know your thoughts and comments. Thanks!

Can we just always have the semantics currently proposed for exact? Alternatively, could exact be the default, with an explicit inexact?

I think something missing from the RFC is that the exact flag would also enable lowering index_cast and index_castui to truncations with the corresponding no-wrap flags (nsw for index_cast, nuw for index_castui). If we instead changed the default semantics of index_cast and index_castui, one would technically always be able to lower them to truncations with no-wrap flags, which is not always safe. We would definitely need an inexact flag for the cases where the truncated bits are non-zero (index_castui) or differ from the MSB of the truncated result (index_cast).
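The correspondence between exact and the no-wrap flags can be checked by brute force. The sketch below is a hypothetical C++ model, not MLIR or LLVM code. `truncNuwDefined` and `truncNswDefined` encode the LLVM IR rules for `trunc nuw` (poison if any truncated bit is non-zero) and `trunc nsw` (poison if any truncated bit differs from the top bit of the result), while the two exact predicates are written in terms of the operand value fitting in the destination type.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model: `bits` holds a srcBits-wide bit pattern (srcBits <= 63),
// and every check assumes dstBits < srcBits, i.e. a narrowing cast.

// Bit-level LLVM IR rules for `trunc nuw` and `trunc nsw`.
bool truncNuwDefined(uint64_t bits, unsigned dstBits) {
  return (bits >> dstBits) == 0; // all truncated bits are zero
}
bool truncNswDefined(uint64_t bits, unsigned srcBits, unsigned dstBits) {
  const uint64_t resultTop = (bits >> (dstBits - 1)) & 1;
  for (unsigned i = dstBits; i < srcBits; ++i)
    if (((bits >> i) & 1) != resultTop)
      return false; // a truncated bit differs from the result's sign bit
  return true;
}

// Value-level exact rules: the operand value must fit in the destination.
bool indexCastUIExactDefined(uint64_t bits, unsigned dstBits) {
  return bits <= (uint64_t{1} << dstBits) - 1; // unsigned value fits
}
bool indexCastExactDefined(uint64_t bits, unsigned srcBits, unsigned dstBits) {
  // Signed value of the srcBits-wide pattern.
  const int64_t value =
      ((bits >> (srcBits - 1)) & 1)
          ? static_cast<int64_t>(bits) - (int64_t{1} << srcBits)
          : static_cast<int64_t>(bits);
  return value >= -(int64_t{1} << (dstBits - 1)) &&
         value <= (int64_t{1} << (dstBits - 1)) - 1;
}

// Compares both formulations for every srcBits-wide pattern and narrower dst.
bool exactMatchesNoWrap(unsigned srcBits) {
  for (unsigned dstBits = 1; dstBits < srcBits; ++dstBits)
    for (uint64_t bits = 0; bits < (uint64_t{1} << srcBits); ++bits) {
      if (truncNuwDefined(bits, dstBits) !=
          indexCastUIExactDefined(bits, dstBits))
        return false;
      if (truncNswDefined(bits, srcBits, dstBits) !=
          indexCastExactDefined(bits, srcBits, dstBits))
        return false;
    }
  return true;
}
```

For small source widths, exact `index_castui` agrees with `trunc nuw` and exact `index_cast` agrees with `trunc nsw` on every input.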

I’d prefer the exact flag over inexact, as it mirrors LLVM IR and there is already an exact unit attribute in the arith dialect for other operations.

I like this idea. I mostly need index where it isn’t 64 bits wide, so not having such an assumption is good. So this says: if we don’t know it’s safe, essentially nothing happens in the device-independent part; if a pass (which could be inserted anywhere, or could be an assumption of the compiler/lowering pipeline) asserts it’s safe, it ungates these optimizations/canonicalizations.

I’m a bit confused about this: if you fix the width of index, how does it differ from i64 here?
As soon as the index size gets fixed, I would imagine we should be able to lower all of the index operations to arith on regular integers?

A lot of dialects, such as memref, operate exclusively on index-typed values, using them like Rust’s isize/usize, so no, we can’t just do that with a bit of search-and-replace.

I’ll also note that this will require a breaking change that probably should have been made anyway: the DLTI analysis currently falls back to index == i64 when nothing has been specified, but those queries should instead return an optional width.
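The suggested API change can be illustrated with a toy C++ model. `DataLayoutSpec`, `getIndexBitwidthOrDefault`, `getIndexBitwidth` and `indexAtLeast` are hypothetical stand-ins, not the actual DLTI interfaces; the point is only the difference between silently defaulting to 64 and returning an optional that callers must check.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Toy stand-in for a DLTI data layout spec: type name -> width in bits.
using DataLayoutSpec = std::map<std::string, unsigned>;

// Current behavior described above: fall back to 64 when unspecified.
unsigned getIndexBitwidthOrDefault(const DataLayoutSpec &spec) {
  const auto it = spec.find("index");
  return it == spec.end() ? 64u : it->second;
}

// Suggested behavior: report "unknown" instead of guessing.
std::optional<unsigned> getIndexBitwidth(const DataLayoutSpec &spec) {
  const auto it = spec.find("index");
  if (it == spec.end())
    return std::nullopt;
  return it->second;
}

// A client such as a folder can then refuse to act on an unknown width.
bool indexAtLeast(const DataLayoutSpec &spec, unsigned bits) {
  const std::optional<unsigned> width = getIndexBitwidth(spec);
  return width.has_value() && *width >= bits;
}
```

With an empty spec, the defaulting query answers 64 while the optional query answers "unknown", so a transformation that depends on the width simply does not fire.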

Assuming I interpret you correctly here (I don’t know if your auto-correct went wrong or what?): you’re saying that index can’t currently be replaced by a sized integer in some dialects, is that correct?

If so, that’s just describing current implementation details, though: we could make it always possible and build an op interface to support materializing index as a sized integer type.

Yeah, that was a typo caused by scribbling down thoughts on mobile early in the morning.

Now, if we want memref, affine, tensor, etc. to also support iN everywhere they take index, I’d be in favor of that (especially since it would allow memory spaces to specify the width of index they operate on: we could, for example, materialize the fact that shared memory is indexed with 32 bits on GPUs while global memory is indexed with 64 bits), but that’s a bigger and potentially more controversial RFC, no?