[RFC] Memref bitcasting

I need to cast a memref to a memref with a different element type without copying the data (in my specific case, between memref<?xf64> and memref<?xvector<2xi32>> for a GPU double-emulation pass), but the memref dialect doesn’t seem to have an op for this. On the other hand, arith.bitcast supports memref bitcasting for some reason (which I found mostly by accident), but it doesn’t support memrefs of vectors. Also, for plain non-memref values, arith.bitcast doesn’t support casting between vectors and non-vectors.

So proposal:

  1. Add a memref.bitcast op to the memref dialect. It can change the memref element type, but shape, layout, memory space and element size must remain the same (and we can also put the view interface on this op).

  2. Remove memref support from arith.bitcast, but relax its casting rules to allow casts between vectors and non-vectors.
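To make the intended bit-level semantics concrete, here is a minimal C++ sketch of what reinterpreting one f64 element as two i32 lanes means. The names `Vec2xI32`, `f64ToLanes` and `lanesToF64` are hypothetical and not part of any MLIR API. Unlike the proposed op, this sketch copies the bytes; it only illustrates that the bits are preserved and that the element sizes must match.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for vector<2xi32>: two 32-bit lanes, 8 bytes in total.
using Vec2xI32 = std::array<std::uint32_t, 2>;

// The proposal requires equal element sizes; check that at compile time.
static_assert(sizeof(double) == sizeof(Vec2xI32), "element sizes must match");

// Reinterpret the bits of one f64 element as two i32 lanes.
// No numeric conversion happens: the 8 bytes are carried over unchanged.
inline Vec2xI32 f64ToLanes(double value) {
  Vec2xI32 lanes;
  std::memcpy(lanes.data(), &value, sizeof(value));
  return lanes;
}

// Inverse direction: rebuild the f64 from its two 32-bit lanes.
inline double lanesToF64(const Vec2xI32 &lanes) {
  double value;
  std::memcpy(&value, lanes.data(), sizeof(value));
  return value;
}
```

Which lane holds the high half of the double depends on the target's endianness, so the sketch makes no assumption about lane order.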

Sounds good to me. I think it would be better to instead name it memref.elt_cast or memref.elt_bitcast.

+1

Please be sure to be very restrictive in what is allowed and ideally connect it to DLT to avoid alignment and padding surprises.

FYI, I proposed this before but didn’t reach consensus: [RFC][arith] Should we support scalar <--> vector `arith.bitcast`s? The suggested alternative was to introduce a helper function that emits an appropriate sequence of ops to cast between scalars and vectors.

This is not my discussion, but please do not use bitcast or cast. The C++ bit_cast is a very powerful tool. Cast sounds C-ish. Maybe convert_to instead.
https://en.cppreference.com/w/cpp/numeric/bit_cast

Regarding using vector construction/deconstruction instead of a direct bitcast: it can also have performance implications. Instead of directly emitting an llvm/spirv bitcast, we would have to rely on the underlying compiler to optimize the sequence away in the no-op case. IMO this can be much less reliable and harder to predict than the cases where the compiler has to generate additional code for a bitcast under the hood.

I didn’t quite get this C++ argument, but I think similar operations are already called bitcast in various places, including llvm and spirv.

It is totally up to you! However, cast and checked conversion are two different concepts for me.

Initial version of memref.elt_bitcast: D137472 [mlir][memref] Add `memref.elt_bitcast` op

^^^^
ping

I’m concerned that this opens the box for all sorts of weird aliasing and data layout problems. So far, memref has almost always ensured a sort of strict aliasing, except for i8 memrefs passed to memref.view. Allowing an element cast will relax that. Memrefs with dense contiguous storage are the minority of cases, and that is a precondition for reinterpreting them as vectors. LLVM has all sorts of non-trivial alignment/size behavior by default when vector sizes are not a power of two. How can we be sure that we are protected from these?

IMO, memref is the wrong abstraction for this work, which rather needs some sort of untyped pointer like !llvm.ptr or, suboptimally, memref<?xi8> that is memref.viewed to different types.
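As a rough C++ analogy for the memref<?xi8>-plus-view alternative, the sketch below keeps the storage untyped and attaches a type only at each access. `ByteBuffer`, `storeAt` and `loadAt` are hypothetical names used for illustration; this is not how memref.view is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

// Untyped storage, analogous to memref<?xi8>: it has no element type of its own.
using ByteBuffer = std::vector<unsigned char>;

// Store `value` as the index-th element of type T in the byte buffer.
template <typename T>
void storeAt(ByteBuffer &buf, std::size_t index, const T &value) {
  static_assert(std::is_trivially_copyable<T>::value,
                "T must be trivially copyable");
  assert((index + 1) * sizeof(T) <= buf.size() && "store out of bounds");
  std::memcpy(buf.data() + index * sizeof(T), &value, sizeof(T));
}

// Load the index-th element of type T from the byte buffer.
template <typename T>
T loadAt(const ByteBuffer &buf, std::size_t index) {
  static_assert(std::is_trivially_copyable<T>::value,
                "T must be trivially copyable");
  assert((index + 1) * sizeof(T) <= buf.size() && "load out of bounds");
  T value;
  std::memcpy(&value, buf.data() + index * sizeof(T), sizeof(T));
  return value;
}
```

Because every access goes through bytes, the same 16-byte buffer can be read as two doubles or as four 32-bit integers without a typed view of one kind ever being cast to another.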

Casting src->i8*->dst will work too, but it isn’t possible with the current memref dialect either. And a single op has the advantage that we can at least validate that the shapes and layout maps are compatible. For the strange vector sizes, I was planning to check them during memref-to-llvm lowering using DataLayout and reject the lowering if the sizes don’t match.
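The kind of size check described here could look roughly like the following C++ sketch. It deliberately uses a toy layout model (storage padded to the next power of two) as a labeled assumption; it does not call or reproduce LLVM's DataLayout, and `ElemType`, `rawSizeBytes`, `paddedSizeBytes` and `canEltBitcast` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of an element type: lane width in bits and lane
// count. A scalar is modeled as a single lane.
struct ElemType {
  std::uint64_t laneBits;
  std::uint64_t numLanes;
};

// Number of bytes that actually carry data.
inline std::uint64_t rawSizeBytes(ElemType t) {
  return (t.laneBits * t.numLanes + 7) / 8;
}

// ASSUMPTION: toy data layout that pads storage up to the next power of two.
// A real lowering would query the target's data layout instead.
inline std::uint64_t paddedSizeBytes(ElemType t) {
  std::uint64_t raw = rawSizeBytes(t);
  std::uint64_t padded = 1;
  while (padded < raw)
    padded *= 2;
  return padded;
}

// Restrictive check: both element types must be free of padding and must
// occupy exactly the same number of bytes; otherwise the cast is rejected.
inline bool canEltBitcast(ElemType src, ElemType dst) {
  bool srcPacked = rawSizeBytes(src) == paddedSizeBytes(src);
  bool dstPacked = rawSizeBytes(dst) == paddedSizeBytes(dst);
  return srcPacked && dstPacked && rawSizeBytes(src) == rawSizeBytes(dst);
}
```

Under this model, f64 to vector<2xi32> is accepted, while f64 to vector<3xi16> is rejected: the padded sizes agree (8 bytes) but only 6 of those bytes carry data.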

But the question of strict aliasing is a really good one. It will work in my case, because I always read/write one type on the CPU and the other one on the GPU, so aliasing rules cannot be enforced across the two anyway. But if users try to read/write both the original and the bitcasted memref in the same function, they will certainly break aliasing rules (at least the C++ ones).

The src->i8* part is intentionally impossible, hence my concern. This is introducing a reinterpret_cast at a distance (one has to allocate the right amount of memory as a vector of i8, and then elt_cast through it).

Did not look at the impl, but that is why I mentioned it.

Yeah this should be scoped out and done on pointers, not memref, if needed.

So what is the best way forward? I’m fine with either a dedicated elt_bitcast op or an intermediate pointer conversion for my needs.