[RFC][arith] Add extended multiplication ops

Hi folks,

I propose to add 2 new ops to the arith dialect:

  • %low, %high = arith.mului_extended %a, %b : iN – extended unsigned multiplication
  • %low, %high = arith.mulsi_extended %a, %b : iN – extended signed multiplication

These ops return two results: the low and high halves, which together form the full ‘lossless’ multiplication result. We proposed and implemented something similar for addition: [RFC] Add integer add with carry op to arith. The main difference is that arith.addui_carry returns an overflow bit instead of a full integer for the second result.
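
To make the proposed semantics concrete, here is a minimal Python sketch of the two ops as described above. The Python function names simply mirror the op names, and the bit width `n` is passed explicitly; this is an illustrative model of the behavior stated in this post, not the dialect's implementation.

```python
def to_signed(x, n):
    """Interpret the n-bit pattern x as a two's-complement integer."""
    x &= (1 << n) - 1
    return x - (1 << n) if x >> (n - 1) else x


def mului_extended(a, b, n):
    """Model of unsigned extended multiplication: returns (low, high)."""
    mask = (1 << n) - 1
    product = (a & mask) * (b & mask)      # exact product, fits in 2n bits
    return product & mask, product >> n


def mulsi_extended(a, b, n):
    """Model of signed extended multiplication: returns (low, high) bit patterns."""
    mask = (1 << n) - 1
    product = to_signed(a, n) * to_signed(b, n)  # exact signed product
    # Python's >> on negative ints is an arithmetic shift, so the high half
    # carries the sign of the full 2n-bit result.
    return product & mask, (product >> n) & mask


# i8 examples: 255 * 255 = 0xFE01 unsigned, (-1) * 2 = -2 = 0xFFFE signed.
print(mului_extended(0xFF, 0xFF, 8))  # (1, 254)
print(mulsi_extended(0xFF, 0x02, 8))  # (254, 255)
```

Concatenating `high:low` always reproduces the exact 2N-bit product, which is the ‘lossless’ property mentioned above.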

The new ops map to OpUMulExtended and OpSMulExtended in SPIR-V. There’s no direct counterpart in LLVM: llvm.umul.with.overflow.* and its signed equivalent return a single overflow bit. Returning a single bit is less useful because it does not preserve the mathematical meaning of multiplication: iN * iN -> i2N. In addition to SPIR-V, there are many ISAs that return a full result, e.g.:

An alternative to adding extended multiplication intrinsics would be to sign/zero-extend the iN operands and perform the multiplication over i2N types. That pattern would then have to be peephole-optimized in each of the dialects further down the lowering chain from arith, which may be problematic because it introduces wider integer types.
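
For comparison, the extension-based alternative can be sketched in Python as follows. The helpers `extui`, `extsi`, and `muli` are illustrative models named after the corresponding arith ops, and `mul_via_extension` is a hypothetical name; the point is only to show that extending to i2N, multiplying, and splitting yields the same two halves.

```python
def extui(x, n):
    """Zero-extend: keep the n-bit pattern, upper bits are zero."""
    return x & ((1 << n) - 1)


def extsi(x, n, m):
    """Sign-extend the n-bit pattern x to m bits."""
    x &= (1 << n) - 1
    if x >> (n - 1):                        # sign bit set: fill upper bits with ones
        x |= ((1 << m) - 1) ^ ((1 << n) - 1)
    return x


def muli(a, b, bits):
    """Wrapping multiplication on `bits`-wide integers."""
    return (a * b) & ((1 << bits) - 1)


def mul_via_extension(a, b, n, signed):
    """Compute (low, high) by multiplying in i2N and splitting the result."""
    wide = 2 * n
    if signed:
        wa, wb = extsi(a, n, wide), extsi(b, n, wide)
    else:
        wa, wb = extui(a, n), extui(b, n)
    product = muli(wa, wb, wide)            # the only op that needs i2N
    low = product & ((1 << n) - 1)          # truncate to iN
    high = product >> n                     # logical shift right by N, then truncate
    return low, high


print(mul_via_extension(0xFF, 0xFF, 8, signed=False))  # (1, 254)
print(mul_via_extension(0xFF, 0x02, 8, signed=True))   # (254, 255)
```

The wide `muli` is the step a backend without native i2N support would need to recognize and rewrite.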

My primary and immediate use cases for this are:

  1. The Wide Integer Emulation pass, where using the extended multiplication ops would yield noticeable savings in both the resulting code size and the complexity of the implementation.
  2. Efficient lowering of tosa.apply_scale, which requires extended multiplication.
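
As a rough illustration of the first use case, the following Python sketch emulates a wrapping (i2N, i2N) -> i2N multiplication using only N-bit values plus one unsigned extended multiplication. It is the textbook decomposition, assumed here for illustration rather than taken from the pass itself; `emulated_muli_2n` and the (low, high) pair representation are hypothetical.

```python
def mului_extended(a, b, n):
    """Model of unsigned extended multiplication: returns (low, high)."""
    mask = (1 << n) - 1
    product = (a & mask) * (b & mask)
    return product & mask, product >> n


def emulated_muli_2n(a, b, n):
    """Wrapping 2n-bit multiply; a and b are (low, high) pairs of n-bit words."""
    mask = (1 << n) - 1
    a_lo, a_hi = a
    b_lo, b_hi = b
    # The low words contribute to both halves of the result.
    res_lo, carry_hi = mului_extended(a_lo, b_lo, n)
    # Cross terms only affect the high word, so plain wrapping iN
    # multiplications and additions suffice.
    cross = (a_lo * b_hi + a_hi * b_lo) & mask
    res_hi = (carry_hi + cross) & mask
    # a_hi * b_hi would land entirely above bit 2n and is dropped.
    return res_lo, res_hi


# 0x1234 * 0x00FF = 0x1221CC, which wraps to 0x21CC in 16 bits.
print(emulated_muli_2n((0x34, 0x12), (0xFF, 0x00), 8))  # (204, 33)
```

Without the extended op, the `carry_hi` term would itself have to be built from several narrower multiplications.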

-Jakub

cc: @antiagainst @Mogball @_sean_silva @ftynse @benvanik


+1. Thanks Jakub for the RFC! Having this to enable a more principled lowering path makes sense to me.

Thanks for the proposal, Jakub! Just passing by and thought this information could be useful: RISC-V (and more architectures, lately) also has “widening” instructions where you can represent an i2N <- iN * iN with a single op (see the vwmul instruction, for example). It also has the respective counterparts for additions, fmadds and reduction instructions. Given that splitting vs. not splitting seems to be becoming a target-specific decision, I wonder if at the MLIR::Arith level we should represent the computation as a single operation and defer the splitting decision down the road. For example:

%a = arith.muli %b, %c : iN -> i2N

Not sure if this approach would cause any problems but I think it aligns well with the MLIR progressive lowering philosophy.

Thanks for the context, @dcaballe. I think I can understand the appeal of returning i2N, but I also see a few issues:

  1. This would create asymmetry compared to the existing add with overflow ops: "addui_carry"(%a : iN, %b : iN) -> iN, i1, which I think would have to follow and become either (iN, iN) -> i(N+1) or (iN, iN) -> i2N. In this case it would map directly to neither LLVM nor SPIR-V.
  2. It doesn’t map directly to either of the main lowering targets, LLVM and SPIR-V.
  3. In the wide integer emulation pass specifically, this would be problematic to handle because we would have to emulate it with a series of plain multiplications (iN, iN) -> iN first. This doesn’t help much with the goal I had in mind, emulating (i2N, i2N) -> i2N with iN ops, and I’d probably have to rely on peephole optimizations further down the lowering chain to eventually replace it with Op*MulExtended.
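
To illustrate what “a series of plain multiplications (iN, iN) -> iN” can look like, here is a Python sketch of the textbook half-word decomposition for the unsigned case. It is one well-known way to do this, assumed for illustration and not necessarily what the emulation pass does; `mul_extended_from_narrow` is a hypothetical name and `n` is assumed to be even.

```python
def mul_extended_from_narrow(a, b, n):
    """Unsigned (low, high) of a * b using only values that fit in n bits."""
    assert n % 2 == 0, "sketch assumes an even bit width"
    h = n // 2
    hmask = (1 << h) - 1
    mask = (1 << n) - 1

    a0, a1 = a & hmask, (a & mask) >> h    # low and high half-words of a
    b0, b1 = b & hmask, (b & mask) >> h    # low and high half-words of b

    t = a0 * b0                            # every product below fits in n bits
    w0, k = t & hmask, t >> h
    t = a1 * b0 + k
    w1, w2 = t & hmask, t >> h
    t = a0 * b1 + w1
    k = t >> h

    high = a1 * b1 + w2 + k                # four multiplications in total
    low = ((t & hmask) << h) | w0
    assert high <= mask                    # no intermediate exceeded n bits
    return low, high


print(mul_extended_from_narrow(0xFF, 0xFF, 8))  # (1, 254)
```

A single op that returns both halves would replace this whole sequence.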

None of these would be deal breakers, but they make this definition less practical for my intended use cases.

Nice to see support for this, and looking forward to the emulation support.

What is the plan for lowering to LLVM?

This one only needs “mulhi” operations. Is the idea that we will pattern-match this case by looking for no uses of “%low”?

Also, for this use case, it is important for the signed multiplication to include a “doubling” in the op semantics. This is because you are really doing a (i1 sign, i{N-1}) * (i1 sign, i{N-1}) → (i1 sign, i{2*(N-1) + 1}) multiplication. Two sign bits on the input turn into one sign bit on the output, so you need to fill the extra bit somehow. In order for %high to correctly represent the mathematical result of the fixed-point multiplication, the 2*(N-1)-bit result of the significand multiplication needs to be shifted left by 1 to be adjacent to the sign bit. E.g. ARM SQDMULH (used in Ruy). There is also a variant that rounds instead of truncates (SQRDMULH), but we do not use that because ironically it leads to worse precision in this use case (long explanation – this is what the double_round attribute on tosa.apply_scale refers to). Also, this use case requires saturation to match the intended asm sequence that we want in practice (not sure where in the pipeline we will introduce the saturation semantics though).
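
The effect of the doubling can be seen in a small Python model. `mulsi_high` models the plain signed high half, and `sqdmulh` models a saturating doubling high multiply in the style of the ARM instruction named above; both are illustrative helpers, not an existing API, and the Q15 values are chosen only as an example.

```python
def to_signed(x, n):
    """Interpret the n-bit pattern x as a two's-complement integer."""
    x &= (1 << n) - 1
    return x - (1 << n) if x >> (n - 1) else x


def mulsi_high(a, b, n):
    """High half of the plain signed extended multiplication."""
    return ((to_signed(a, n) * to_signed(b, n)) >> n) & ((1 << n) - 1)


def sqdmulh(a, b, n):
    """Saturating, doubling, truncating high multiply (model)."""
    doubled = 2 * to_signed(a, n) * to_signed(b, n)
    result = doubled >> n                                # truncate, no rounding
    max_v, min_v = (1 << (n - 1)) - 1, -(1 << (n - 1))
    result = max(min_v, min(max_v, result))              # only MIN * MIN saturates
    return result & ((1 << n) - 1)


# Q15 fixed point: 0x4000 is 0.5, so 0.5 * 0.5 should give 0.25 = 0x2000.
print(hex(mulsi_high(0x4000, 0x4000, 16)))  # 0x1000 (0.125: off by a factor of 2)
print(hex(sqdmulh(0x4000, 0x4000, 16)))     # 0x2000 (0.25)
print(hex(sqdmulh(0x8000, 0x8000, 16)))     # 0x7fff (-1.0 * -1.0 saturates)
```

Without the doubling, the plain high half is off by one bit position for fixed-point operands, which matches the shift-left-by-1 argument above.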

I was under the impression that in llvm it’s generally fine to introduce wide integer types and rely on the pass pipeline to eventually legalize them if necessary**. If that’s the case, we could zero/sign extend operands and perform a regular multiplication over i2N types, and extract the result(s). If not, we can perform a few iN multiplications instead, similar to how we handle multiplication in the wide integer emulation pass today.

The exact lowering path I had in mind is here: llvm-project/TosaToArith.cpp at 57dc4a8cab1257c5412471139ef4b6d6060997c9 · llvm/llvm-project · GitHub.

In general, I think it should be easier to tell that one of the results is unused compared to telling that low bits of a single result are ignored.


** This is not the case in SPIR-V BTW, where non-32-bit integer types require specific capabilities.

Could you instead introduce an intrinsic in LLVM? The intent would be clear. It gives more control to LLVM how to handle the operation.

We can do that if the others who have stakes in the LLVM side of things are on board. However, I’d like to avoid arith and llvm being codependent, as I argued here: [RFC] Define precise arith semantics. I think it would be unfortunate if the lack of a perfectly matching llvm intrinsic blocked this or future changes to the arith dialect.


FYI, I’ve just submitted the first patch for review: D139688 [mlir][arith] Define mului_extended op