[RFC] Dialect type cast op


In my local GPU convolution prototype, I invented an op for this, std.stdcast:

  %vec = ... : vector<2xf16>
  %llvm_vec = std.stdcast %vec : vector<2xf16> to !llvm<"<2 x half>">
  %res = nvvm.mma.sync ..., %llvm_vec, ...

or, in the other direction:

  %llvm_tid = nvvm.read.ptx.sreg.tid.x  : !llvm.i32
  %tid = std.stdcast %llvm_tid : !llvm.i32 to i32
  %thread_id = index_cast %tid : i32 to index
  // use %thread_id

It casts between a dialect type and a standard type. Lowering it just asks the type converter whether the cast is a no-op after lowering; if it is not, lowering can report an error.
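Here is a minimal Python sketch of that lowering rule. It assumes types are plain strings and that the type converter is a lookup table. The helper names and the table are hypothetical; they are not MLIR's C++ `TypeConverter` API. Both sides go through the converter, so one check covers both directions of the cast:

```python
# Toy model of lowering the hypothetical std.stdcast op.
# Types are strings; the converter maps standard types to the dialect
# types they lower to and leaves already-lowered types unchanged.


class CastLoweringError(Exception):
    """Raised when a stdcast would change the value's representation."""


def make_converter(table):
    def convert(ty):
        # Types missing from the table are assumed to already be dialect
        # types, so they convert to themselves.
        return table.get(ty, ty)
    return convert


def lower_stdcast(convert, src_ty, dst_ty):
    """Lower `stdcast %x : src_ty to dst_ty`.

    If both sides lower to the same type, the cast is a no-op and its
    result can simply be replaced by its operand.
    """
    if convert(src_ty) == convert(dst_ty):
        return "replace-with-operand"
    raise CastLoweringError(
        f"stdcast {src_ty} -> {dst_ty} is not a no-op after lowering")


convert = make_converter({
    "vector<2xf16>": '!llvm<"<2 x half>">',
    "i32": "!llvm.i32",
})
print(lower_stdcast(convert, "vector<2xf16>", '!llvm<"<2 x half>">'))
print(lower_stdcast(convert, "!llvm.i32", "i32"))
```

A cast such as `i32` to `!llvm.i64` raises `CastLoweringError` instead of being silently accepted.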

My primary motivation is to stitch different dialects together in the same CFG (the first example above). Today, one has to build suitable abstractions in high-level dialects like Standard, Vector, etc., and then lower them to low-level dialects like NVVM. If I already have a specific NVVM op in mind, it’s not always trivial to build such high-level abstractions, and arguably it’s not even necessary.

std.stdcast also enables gradual lowering (the second example above). Even though everything in MLIR is better than LLVM SelectionDAG, I can still see the StandardToLLVM lowering becoming bulky around MemRefs. Ideally, one could lower all memrefs first and then the rest, or vice versa.

I’m really curious about more use cases from the community. WDYT?

I remember having a long discussion about this internally ~1 year ago :). I was trying to propose exactly this, for the same reasons, and Chris had good arguments against it, but unfortunately I don’t remember them all (for example, we went through cases where the lowering would explode one value into multiple values). I’m sure @ftynse and @River707 can fill the gap.

Ultimately we converged on the idea that these cast ops have to be introduced case by case in each dialect and managed specifically. I’d be happy to help revisit this though!

You overestimate my memory capacity :slight_smile: We can try and ping @clattner for arguments.

We indeed had this discussion in the context of progressive lowering, specifically when working on the tutorial for EuroLLVM. The specific problem was the mix of dialect conversion passes that change types, and pass ordering. We had a Toy->LLVM conversion pattern for one operation, and Toy->Std + Std->LLVM patterns for all the other operations. We could not order them because the types would clash. Reduced to a minimal example:

%0 = "op1"() : () -> !dialectA.type
%1 = "op2"(%0) : (!dialectA.type) -> !dialectA.type

op1 cannot be converted to another op producing !dialectB.type unless (a) op2 accepts that type or (b) op2 is simultaneously lowered to another op accepting that type.
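The clash can be modeled in a few lines of Python. Each op records its operands, the operand types it expects, and its result type. This is a hypothetical toy representation, not MLIR's data structures. Converting op1 alone leaves op2 expecting the old type:

```python
def type_mismatches(ops):
    """Return (consumer, expected, actual) for every operand whose
    producer's result type differs from what the consumer expects.

    `ops` maps an op name to (operands, operand_types, result_type),
    where `operands` names the producing ops.
    """
    problems = []
    for name, (operands, operand_types, _) in ops.items():
        for producer, expected in zip(operands, operand_types):
            actual = ops[producer][2]
            if actual != expected:
                problems.append((name, expected, actual))
    return problems


ir = {
    "op1": ([], [], "!dialectA.type"),
    "op2": (["op1"], ["!dialectA.type"], "!dialectA.type"),
}
assert type_mismatches(ir) == []

# Convert only op1 to an op producing !dialectB.type: op2 is now ill-typed.
ir["op1"] = ([], [], "!dialectB.type")
print(type_mismatches(ir))  # [('op2', '!dialectA.type', '!dialectB.type')]
```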

Mehdi’s principal argument was that we want a system where dialects can co-exist at all levels. Having a cast would help; otherwise, we need to perform the conversion in one shot to avoid the problem. I was arguing, based on the infra plans at the time, that we were building big pattern-rewriting machinery anyway, so we should just register all the patterns together. We agreed that even if it was feasible to put all the patterns together, it would be detrimental to the separability of abstractions (dialects).

We first had a band-aid solution where we disabled the verifier between passes. We actually ended up implementing the cast operation (IIRC it never existed as a class; it was created with a dummy name through OperationState) and using it under the hood: a low-benefit pattern matched any LLVM dialect operation that consumed a value of non-LLVM type and injected the cast between them. Conversely, the cast operation would be removed if it converted the result of an LLVM dialect operation. I actually don’t know if this feature is still in the Toy code after the update.

Another option we considered was to make every conversion pattern unconditionally emit casts for all arguments and results, and then canonicalize away A->B->A casts. This is part of the explosion Mehdi refers to: in the case of an A->B->C conversion, we would quickly end up with more cast operations than payload operations.
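The "emit casts everywhere, canonicalize later" option can be modeled as folding the chain of types that a single value passes through. This sketch assumes every cast is a representation-preserving no-op, so a round trip A->B->A folds away. `fold_cast_chain` is a hypothetical helper, not part of MLIR:

```python
def fold_cast_chain(types):
    """Fold the chain of casts a value goes through, e.g. ["A", "B", "A"].

    Assumes every cast is a no-op, so a round trip X -> Y -> X is the
    identity and X -> X is trivial. Returns the types left after
    folding; len(result) - 1 casts remain.
    """
    stack = []
    for ty in types:
        if stack and stack[-1] == ty:
            continue                # X -> X: drop the trivial cast
        if len(stack) >= 2 and stack[-2] == ty:
            stack.pop()             # X -> Y -> X: cancel the round trip
            continue
        stack.append(ty)
    return stack


# Casts emitted around every converted op cancel out pairwise...
print(fold_cast_chain(["A", "B", "A"]))  # ['A']
# ...but a genuine A -> B -> C conversion leaves both casts in place.
print(fold_cast_chain(["A", "B", "C"]))  # ['A', 'B', 'C']
```

Only round trips fold. Each intermediate step of a multi-stage conversion leaves a cast behind, and that is where the extra cast ops come from.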

In hindsight, the problem is not with dialects but with dialect-defined types, which create the actual compatibility boundary. Dialect operations compose easily as long as they operate on the same types. Type system conversions are also less frequent than dialect conversions, which decreases the potential for explosion without eliminating it completely. We also had a significantly less configurable conversion infra a year ago.

My main argument against a generic cast operation was, and still is, that it could not have specific semantics. What are the semantics of stdcast tensor<*xf32> to !llvm.i64? Or of stdcast !llvm<"{i64*, double}"> to !linalg.range? It was, and remains, too “magic”. Pushed to the limit, allowing casts from anything to anything else lets you essentially ignore types everywhere, making the IR untyped.

The semantics argument resonates with Chris’ reasoning about generic std.return and std.yield in yesterday’s call.

Chris argued for having dialect-specific casts IIRC, i.e. a cast from !dialectA.type to !dialectB.type that could be aware of the specifics of both dialects and only support sound conversions. We had two technical concerns with this. (1) It is unclear where this operation should live: in dialectA, in dialectB, or in a separate glueAB dialect? (2) We may end up with N^2 cast operations, potentially each living in its own “glue” dialect, which would be another source of complexity explosion.
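Here is a quick count for the N^2 concern. It assumes one dedicated cast op per ordered pair of dialects and compares that with routing every dialect through a single shared set of types, which is the "standard <-> any dialect" variant discussed later in the thread. Both functions are illustrative arithmetic only:

```python
def pairwise_cast_ops(n_dialects):
    """One dedicated cast op per ordered pair of distinct dialects."""
    return n_dialects * (n_dialects - 1)


def hub_cast_ops(n_dialects):
    """One bidirectional cast op between each dialect and a single shared
    hub (one of the n dialects); the hub needs none for itself."""
    return n_dialects - 1


for n in (3, 10, 30):
    print(n, pairwise_cast_ops(n), hub_cast_ops(n))
```

The linear variant avoids the quadratic count, but every A -> B path then needs two casts through the hub.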

Well, it looks like you haven’t overestimated memory capacity after all :slight_smile:

History aside, we already face the problem in the GPU conversion, as evidenced by this TODO: https://github.com/llvm/llvm-project/blob/master/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp#L93.

I would consider introducing llvm.mlir.cast, which specifically represents the cast from (supported) standard types to LLVM dialect types, and back. The verifier of this operation would enforce the same rules as the type conversion. (I have been thinking about table-generating type conversions as well, at which point we could generate the verifier too.) It can live within the LLVM dialect because that dialect already depends on standard types for attributes.
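This sketch shows what such a verifier would check. A hard-coded conversion table stands in for the real type conversion. Its entries use the `!llvm.*` spellings of the time and are illustrative, not exhaustive. The op is valid in either direction, but only along the conversion:

```python
# Hypothetical standard -> LLVM dialect conversion table (illustrative subset).
STD_TO_LLVM = {
    "i1": "!llvm.i1",
    "i32": "!llvm.i32",
    "i64": "!llvm.i64",
    "f16": "!llvm.half",
    "f32": "!llvm.float",
    "f64": "!llvm.double",
    "vector<2xf16>": '!llvm<"<2 x half>">',
}


def verify_llvm_mlir_cast(src, dst):
    """Accept std -> LLVM or LLVM -> std, but only along the conversion."""
    if src in STD_TO_LLVM:
        return STD_TO_LLVM[src] == dst
    if dst in STD_TO_LLVM:
        return STD_TO_LLVM[dst] == src
    # Neither side is a supported standard type (e.g. tensor, or LLVM -> LLVM).
    return False


print(verify_llvm_mlir_cast("i32", "!llvm.i32"))            # True
print(verify_llvm_mlir_cast("tensor<*xf32>", "!llvm.i64"))  # False
```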

Orthogonally to that, I’d like to get rid of the one-shot convert-linalg-to-llvm pass, which would require doing something similar to convert !linalg.range, but I need to think more about it and align with @nicolasvasilache.

I believe Chris’s argument against was that the op should have runtime semantics and therefore the op would need to know how to do the conversion, but there is no general op that could convert any arbitrary dialect type to another dialect type.

In the direction of lowering? You need a dialect cast op, and you need a pass that inserts that cast and knows that this particular cast op can handle the conversion.

That just means: don’t invent dialectA.float if you can use std.float, or you inflict a lot of pain on folks :wink:

Yes, this. The issue with casts is that they have operational semantics that need to be reasoned about. For example, a conversion from int to float does a lot of bit shuffling. A conversion from int to string is another ball of wax because strings are complicated :slight_smile:

An std.cast sort of operation seems like a “do what I mean” operation that would have to be specified somehow. In the case of the std type system, such a thing doesn’t even make sense, because (for example) when converting from i32 to f64 you need to know whether you’re sign or zero extending.
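The i32 -> f64 ambiguity is easy to demonstrate. The same 32 bits give different doubles depending on whether they are read as signed or unsigned, so a generic cast would have to be told which one it means. Plain Python stands in for the IR here, and the helper name is hypothetical:

```python
import struct


def i32_bits_to_f64(bits, signed):
    """Interpret a 32-bit pattern as a signed or unsigned integer, then
    convert that integer value to a double (like LLVM's sitofp/uitofp)."""
    fmt = "<i" if signed else "<I"
    (value,) = struct.unpack(fmt, struct.pack("<I", bits & 0xFFFFFFFF))
    return float(value)


bits = 0xFFFFFFFF
print(i32_bits_to_f64(bits, signed=True))   # -1.0
print(i32_bits_to_f64(bits, signed=False))  # 4294967295.0
```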


Yes, I think having something like this in the llvm dialect would make perfect sense. I’d avoid the word “cast” here though, perhaps something like convert, remap, or even bitcast or noopcast to make it clear that there is no representational change allowed.
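This Python sketch shows the distinction behind the naming suggestion. A bitcast/noopcast-style operation keeps the bits and only changes how they are typed. A conversion computes new bits. The helper names are hypothetical:

```python
import struct


def bitcast_f32_to_i32(x):
    """No representational change: reinterpret the 32 bits of an IEEE-754
    single-precision float as a signed 32-bit integer."""
    return struct.unpack("<i", struct.pack("<f", x))[0]


def convert_f32_to_i32(x):
    """Value conversion: truncate toward zero (like LLVM's fptosi), which
    generally produces a different bit pattern."""
    return int(x)


print(hex(bitcast_f32_to_i32(1.0)))  # 0x3f800000
print(convert_f32_to_i32(1.0))       # 1
```

The op discussed here is stricter still. Both sides must describe the same value with the same bits, just under different type systems.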

I can see quite a few variants popping up in the discussion, in terms of:

Which types are supported:

  • any dialect <-> any dialect
  • standard <-> any dialect
  • standard <-> LLVM dialect

Runtime behavior:

  • changes the bit representation
  • no-op

Let me try to summarize discussed variants:

a) Any dialect <-> any dialect, runtime no-op or not: No no no, no. It’s more of a backdoor than a meaningful op, since no semantics can be specified.

b) Standard <-> any dialect, runtime no-op (what I proposed):

  • It doesn’t create a cast explosion.
  • The semantics are indeed still up to individual dialects (via the type converter), but we know the cast has to be a no-op at runtime.

c) Standard <-> LLVM dialect, runtime no-op: No explosion, semantics are clear.

Personally I think both b) and c) are viable, but I haven’t had use cases outside of what c) can handle. It looks like we are converging on c).

Converting index will be interesting in this sense. There’s no equivalent in the LLVM dialect, and we convert it based on the DataLayout that we don’t even see in the MLIR. Should it be plainly forbidden (we can require an index_cast to a properly sized integer type)? Otherwise, it’s not necessarily a no-op…
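This sketch shows why `index` is awkward. It assumes the index width is a parameter supplied by the target rather than something recorded in the IR, and the function names are hypothetical. Requiring an explicit `index_cast` to a fixed-width integer first takes the parameter out of the question:

```python
def lowered_index_type(index_bitwidth):
    """LLVM dialect integer type that `index` lowers to. The width comes
    from outside the IR (target data layout or a pass option), which a
    verifier looking only at the IR cannot see."""
    return f"!llvm.i{index_bitwidth}"


def index_cast_is_noop(llvm_type, index_bitwidth):
    """A cast between index and llvm_type is a no-op only if the widths
    match for the target being compiled for."""
    return lowered_index_type(index_bitwidth) == llvm_type


# The same cast is a no-op on a 64-bit target but not on a 32-bit one.
print(index_cast_is_noop("!llvm.i64", 64))  # True
print(index_cast_is_noop("!llvm.i64", 32))  # False
```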

This means the lowering-target dialect will have to know about all dialects that lower to it, including the out-of-tree ones. For example, FIR may want to lower to LLVM directly, but it does not even live in MLIR so there’s no way the LLVM dialect can know about it.

I am generally wary about making one dialect depend on another, or effectively impose the direction of conversion (e.g., opposing lowering to raising). As an infrastructure, we want to be less constraining.

I’ve seen use cases for “any → LLVM”, not for “std → any”. Starting with a practically necessary use case, (c), and possibly generalizing when we see other needs is a good way to prevent over-design and unnecessary feature creep, IMO.

I am looking at breaking up the “include the kitchen sink” behavior of linalg-to-llvm.

I see DialectCastOp already anticipates the memref case.

I’d like to relax the verifier of DialectCastOp to return success if one of the source/result type is memref and the other one is any of the possible LLVM lowerings of memref.
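This sketch shows the relaxed check. It assumes the default lowering turns a ranked memref into a descriptor struct (allocated pointer, aligned pointer, offset, sizes, strides). It also assumes the bare-pointer convention applies only to statically shaped memrefs. Type spellings follow the old `!llvm<"...">` syntax, and the helper names are hypothetical:

```python
# Element types in the old LLVM dialect spelling; illustrative subset.
LLVM_ELEM = {"f16": "half", "f32": "float", "f64": "double",
             "i32": "i32", "i64": "i64"}


def memref_lowerings(elem, shape):
    """All LLVM types memref<shape x elem> may lower to.

    `shape` lists dimension sizes, with -1 marking a dynamic dimension.
    """
    t = LLVM_ELEM[elem]
    rank = len(shape)
    # Default convention: descriptor struct; rank-0 memrefs carry no
    # sizes/strides arrays.
    fields = [f"{t}*", f"{t}*", "i64"]
    if rank > 0:
        fields += [f"[{rank} x i64]", f"[{rank} x i64]"]
    options = {'!llvm<"{ ' + ", ".join(fields) + ' }">'}
    # Bare-pointer convention: assumed to apply only to static shapes.
    if all(d >= 0 for d in shape):
        options.add(f'!llvm<"{t}*">')
    return options


def verify_memref_cast(elem, shape, llvm_type):
    """Relaxed DialectCastOp check: any possible lowering is accepted."""
    return llvm_type in memref_lowerings(elem, shape)


print(verify_memref_cast("f32", [4], '!llvm<"float*">'))   # True
print(verify_memref_cast("f32", [-1], '!llvm<"float*">'))  # False
```

For a static memref both types pass, so the verifier alone cannot tell which convention is in use.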


Is “any of the possible LLVM lowerings” due to the bare-ptr convention (and us not knowing which convention will be used)?

Yes, this ambiguity is pointed to by @timshen as the reason memref is currently disallowed in DialectCastOp.

I am fine with it, hoping that we will eventually have aliasing info so that we can drop the bare pointer conversion entirely.

yes in the fullness of time … but we will prob have a pointer type in std before that: [RFC] Remove MemRefType element type check? Or add pointer support to ‘std’ dialect? - #6 by mehdi_amini