Unnecessary `vmovapd` instructions generated; can you hint in favor of vfmadd231pd?

I’m developing a library that models simple loops and lowers them, making heavy use of LLVM intrinsics. When I use @llvm.fma.v8f64 (or @llvm.fmuladd.v8f64, as in the example below), LLVM often likes to generate instructions other than vfmadd231pd, e.g. vfmadd213pd, which requires an extra vmovapd instruction to also be generated.
Here is an example:

L2208:
        vmovapd %zmm6, %zmm18
        vmovapd %zmm4, %zmm19
        vmovapd %zmm3, %zmm20
        vmovupd (%rsi,%rax,8), %zmm6
        vmovupd (%rsi,%rbp,8), %zmm4
        vmovupd (%rsi,%rdx,8), %zmm3
        vmovupd (%rsi,%rcx), %zmm21 {%k1} {z}
        vbroadcastsd    (%r9,%rdi,8), %zmm22
        vfmadd231pd     %zmm22, %zmm6, %zmm14 # zmm14 = (zmm6 * zmm22) + zmm14
        vfmadd231pd     %zmm22, %zmm4, %zmm17 # zmm17 = (zmm4 * zmm22) + zmm17
        vfmadd231pd     %zmm22, %zmm3, %zmm16 # zmm16 = (zmm3 * zmm22) + zmm16
        vfmadd231pd     %zmm22, %zmm21, %zmm15 # zmm15 = (zmm21 * zmm22) + zmm15
        vbroadcastsd    (%r11,%rdi,8), %zmm22
        vfmadd231pd     %zmm22, %zmm6, %zmm13 # zmm13 = (zmm6 * zmm22) + zmm13
        vfmadd231pd     %zmm22, %zmm4, %zmm12 # zmm12 = (zmm4 * zmm22) + zmm12
        vfmadd231pd     %zmm22, %zmm3, %zmm11 # zmm11 = (zmm3 * zmm22) + zmm11
        vfmadd231pd     %zmm22, %zmm21, %zmm10 # zmm10 = (zmm21 * zmm22) + zmm10
        vbroadcastsd    (%r8,%rdi,8), %zmm22
        vfmadd231pd     %zmm22, %zmm6, %zmm9 # zmm9 = (zmm6 * zmm22) + zmm9
        vfmadd231pd     %zmm22, %zmm4, %zmm8 # zmm8 = (zmm4 * zmm22) + zmm8
        vfmadd231pd     %zmm22, %zmm3, %zmm7 # zmm7 = (zmm3 * zmm22) + zmm7
        vfmadd231pd     %zmm22, %zmm21, %zmm5 # zmm5 = (zmm21 * zmm22) + zmm5
        vbroadcastsd    (%r13,%rdi,8), %zmm22
        vfmadd213pd     %zmm18, %zmm22, %zmm6 # zmm6 = (zmm22 * zmm6) + zmm18
        vfmadd213pd     %zmm19, %zmm22, %zmm4 # zmm4 = (zmm22 * zmm4) + zmm19
        vfmadd213pd     %zmm20, %zmm22, %zmm3 # zmm3 = (zmm22 * zmm3) + zmm20
        vfmadd231pd     %zmm22, %zmm21, %zmm2 # zmm2 = (zmm21 * zmm22) + zmm2
        incq    %rdi
        addq    %r10, %rsi
        cmpq    %rdi, %rbx
        jne     L2208

Notice that three of the last four vfmadd instructions are 213 instead of 231. This is problematic: the 231 variant accumulates into its destination (dst = a * b + dst, i.e. it increments a vector by the product of two others), whereas the 213 variant overwrites one of the multiplicands (dst = a * dst + c). Specifically, these 213s write their results over the three registers filled by the unmasked vmovupd loads, which are the vectors reused across the four blocks of vfmadds.
Because they don’t accumulate into the loop-carried vectors in place, the vmovapds at the top of the block become necessary.

Adding the `fast` flag to the fmuladd calls exacerbates this problem (especially on AVX2).
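
For concreteness, by “the fast flag” I mean fast-math flags on the fmuladd calls; a minimal sketch of one accumulation step (the function and value names here are just for illustration):

    declare <8 x double> @llvm.fmuladd.v8f64(<8 x double>, <8 x double>, <8 x double>)

    ; one accumulator update: acc = a * b + acc, with fast-math flags on the call
    define <8 x double> @acc_step(<8 x double> %a, <8 x double> %b, <8 x double> %acc) {
      %res = call fast <8 x double> @llvm.fmuladd.v8f64(<8 x double> %a, <8 x double> %b, <8 x double> %acc)
      ret <8 x double> %res
    }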

Here is an example of the optimized LLVM IR that generated it:

  %offsetptr.i1442 = getelementptr inbounds double, double* %typptr.i1784, i64 %222
  %ptr.i1443 = bitcast double* %offsetptr.i1442 to <8 x double>*
  %res.i1444 = load <8 x double>, <8 x double>* %ptr.i1443, align 8
  %223 = add i64 %221, %210
  %offsetptr.i1438 = getelementptr inbounds double, double* %typptr.i1784, i64 %223
  %ptr.i1439 = bitcast double* %offsetptr.i1438 to <8 x double>*
  %res.i1440 = load <8 x double>, <8 x double>* %ptr.i1439, align 8
  %224 = add i64 %221, %211
  %offsetptr.i1434 = getelementptr inbounds double, double* %typptr.i1784, i64 %224
  %ptr.i1435 = bitcast double* %offsetptr.i1434 to <8 x double>*
  %res.i1436 = load <8 x double>, <8 x double>* %ptr.i1435, align 8
  %225 = add i64 %221, %212
  %offsetptr.i1429 = getelementptr inbounds double, double* %typptr.i1784, i64 %225
  %ptr.i1430 = bitcast double* %offsetptr.i1429 to <8 x double>*
  %res.i1432 = call <8 x double> @llvm.masked.load.v8f64.p0v8f64(<8 x double>* nonnull %ptr.i1430, i32 8, <8 x i1> %213, <8 x double> zeroinitializer)
  %226 = add i64 %value_phi1622015, %214
  %ptr.i1426 = getelementptr inbounds double, double* %typptr.i1769, i64 %226
  %res.i1427 = load double, double* %ptr.i1426, align 8
  %ie.i1423 = insertelement <8 x double> undef, double %res.i1427, i32 0
  %v.i1424 = shufflevector <8 x double> %ie.i1423, <8 x double> undef, <8 x i32> zeroinitializer
  %res.i1422 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1444, <8 x double> %v.i1424, <8 x double> %value_phi1632016)
  %res.i1419 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1440, <8 x double> %v.i1424, <8 x double> %value_phi1652017)
  %res.i1416 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1436, <8 x double> %v.i1424, <8 x double> %value_phi1672018)
  %res.i1413 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1432, <8 x double> %v.i1424, <8 x double> %value_phi1692019)
  %227 = add i64 %value_phi1622015, %216
  %ptr.i1411 = getelementptr inbounds double, double* %typptr.i1769, i64 %227
  %res.i1412 = load double, double* %ptr.i1411, align 8
  %ie.i1408 = insertelement <8 x double> undef, double %res.i1412, i32 0
  %v.i1409 = shufflevector <8 x double> %ie.i1408, <8 x double> undef, <8 x i32> zeroinitializer
  %res.i1407 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1444, <8 x double> %v.i1409, <8 x double> %value_phi1712020)
  %res.i1404 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1440, <8 x double> %v.i1409, <8 x double> %value_phi1732021)
  %res.i1401 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1436, <8 x double> %v.i1409, <8 x double> %value_phi1752022)
  %res.i1398 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1432, <8 x double> %v.i1409, <8 x double> %value_phi1772023)
  %228 = add i64 %value_phi1622015, %218
  %ptr.i1396 = getelementptr inbounds double, double* %typptr.i1769, i64 %228
  %res.i1397 = load double, double* %ptr.i1396, align 8
  %ie.i1393 = insertelement <8 x double> undef, double %res.i1397, i32 0
  %v.i1394 = shufflevector <8 x double> %ie.i1393, <8 x double> undef, <8 x i32> zeroinitializer
  %res.i1392 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1444, <8 x double> %v.i1394, <8 x double> %value_phi1792024)
  %res.i1389 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1440, <8 x double> %v.i1394, <8 x double> %value_phi1812025)
  %res.i1386 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1436, <8 x double> %v.i1394, <8 x double> %value_phi1832026)
  %res.i1383 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1432, <8 x double> %v.i1394, <8 x double> %value_phi1852027)
  %229 = add i64 %value_phi1622015, %220
  %ptr.i1381 = getelementptr inbounds double, double* %typptr.i1769, i64 %229
  %res.i1382 = load double, double* %ptr.i1381, align 8
  %ie.i1378 = insertelement <8 x double> undef, double %res.i1382, i32 0
  %v.i1379 = shufflevector <8 x double> %ie.i1378, <8 x double> undef, <8 x i32> zeroinitializer
  %res.i1377 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1444, <8 x double> %v.i1379, <8 x double> %value_phi1872028)
  %res.i1374 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1440, <8 x double> %v.i1379, <8 x double> %value_phi1892029)
  %res.i1371 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1436, <8 x double> %v.i1379, <8 x double> %value_phi1912030)
  %res.i1368 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1432, <8 x double> %v.i1379, <8 x double> %value_phi1932031)
  %230 = add nuw nsw i64 %value_phi1622015, 1
  %exitcond2127 = icmp eq i64 %230, %29
  br i1 %exitcond2127, label %L1034, label %L930

My idea for a solution was to try inline asm to specifically require vfmadd231pd:

    %res = call <8 x double> asm "vfmadd231pd \$3, \$2, \$0", "=x,0,x,x"(<8 x double> %2, <8 x double> %1, <8 x double> %0)
    ret <8 x double> %res
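
Spelled out, that snippet is the body of a small wrapper function, roughly like the sketch below (the @fma231 name is mine, and the operand references are written as plain $N in .ll syntax rather than the escaped \$N above). The idea is then to call this wrapper in place of each @llvm.fmuladd.v8f64 call in the loop body.

    ; %2 is the accumulator; the output register is tied to it, so the asm
    ; computes (%1 * %0) + %2 without clobbering either multiplicand
    define <8 x double> @fma231(<8 x double> %0, <8 x double> %1, <8 x double> %2) {
      %res = call <8 x double> asm "vfmadd231pd $3, $2, $0", "=x,0,x,x"(<8 x double> %2, <8 x double> %1, <8 x double> %0)
      ret <8 x double> %res
    }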

However, this caused a roughly 50% performance regression; the same example section now looks like this:

L3040:
        vmovupd (%rsi,%rax), %zmm1 {%k1} {z}
        vbroadcastsd    (%r11,%rcx,8), %zmm5
        vmovups (%rsi,%rdi,8), %zmm4
        movq    208(%rsp), %rbx
        vmovups (%rsi,%rbx,8), %zmm3
        movq    248(%rsp), %rbx
        vmovups (%rsi,%rbx,8), %zmm2
        movq    %r15, 160(%rsp)
        movq    %r13, 152(%rsp)
        movq    %r9, 144(%rsp)
        vfmadd231pd     %zmm4, %zmm5, %zmm14 # zmm14 = (zmm5 * zmm4) + zmm14
        vfmadd231pd     %zmm3, %zmm5, %zmm0 # zmm0 = (zmm5 * zmm3) + zmm0
        vmovaps %zmm0, 768(%rsp)
        vmovaps 704(%rsp), %zmm0
        vfmadd231pd     %zmm2, %zmm5, %zmm0 # zmm0 = (zmm5 * zmm2) + zmm0
        vmovaps %zmm0, 704(%rsp)
        vfmadd231pd     %zmm1, %zmm5, %zmm15 # zmm15 = (zmm5 * zmm1) + zmm15
        vbroadcastsd    (%r14,%rcx,8), %zmm5
        vfmadd231pd     %zmm4, %zmm5, %zmm13 # zmm13 = (zmm5 * zmm4) + zmm13
        vfmadd231pd     %zmm3, %zmm5, %zmm12 # zmm12 = (zmm5 * zmm3) + zmm12
        vfmadd231pd     %zmm2, %zmm5, %zmm11 # zmm11 = (zmm5 * zmm2) + zmm11
        vfmadd231pd     %zmm1, %zmm5, %zmm10 # zmm10 = (zmm5 * zmm1) + zmm10
        vbroadcastsd    (%r12,%rcx,8), %zmm5
        vfmadd231pd     %zmm4, %zmm5, %zmm9 # zmm9 = (zmm5 * zmm4) + zmm9
        vfmadd231pd     %zmm3, %zmm5, %zmm8 # zmm8 = (zmm5 * zmm3) + zmm8
        vfmadd231pd     %zmm2, %zmm5, %zmm7 # zmm7 = (zmm5 * zmm2) + zmm7
        vfmadd231pd     %zmm1, %zmm5, %zmm6 # zmm6 = (zmm5 * zmm1) + zmm6
        vpbroadcastq    (%r10,%rcx,8), %zmm5
        vmovapd %zmm15, %zmm0
        vmovaps %zmm14, %zmm15
        vmovaps %zmm13, %zmm14
        vmovapd %zmm12, %zmm13
        vmovaps %zmm11, %zmm12
        vmovapd %zmm10, %zmm11
        vmovapd %zmm9, %zmm10
        vmovapd %zmm8, %zmm9
        vmovapd %zmm7, %zmm8
        vmovapd %zmm6, %zmm7
        vmovaps 576(%rsp), %zmm6
        vfmadd231pd     %zmm4, %zmm5, %zmm6 # zmm6 = (zmm5 * zmm4) + zmm6
        vmovaps %zmm6, 576(%rsp)
        vmovapd %zmm7, %zmm6
        vmovapd %zmm8, %zmm7
        vmovapd %zmm9, %zmm8
        vmovapd %zmm10, %zmm9
        vmovapd %zmm11, %zmm10
        vmovaps %zmm12, %zmm11
        vmovapd %zmm13, %zmm12
        vmovaps %zmm14, %zmm13
        vmovaps %zmm15, %zmm14
        vmovapd %zmm0, %zmm15
        vmovapd 768(%rsp), %zmm0
        vmovaps 512(%rsp), %zmm4
        vfmadd231pd     %zmm3, %zmm5, %zmm4 # zmm4 = (zmm5 * zmm3) + zmm4
        vmovapd %zmm4, 512(%rsp)
        vmovaps 384(%rsp), %zmm3
        vfmadd231pd     %zmm2, %zmm5, %zmm3 # zmm3 = (zmm5 * zmm2) + zmm3
        vmovapd %zmm3, 384(%rsp)
        vmovaps 256(%rsp), %zmm2
        vfmadd231pd     %zmm1, %zmm5, %zmm2 # zmm2 = (zmm5 * zmm1) + zmm2
        vmovapd %zmm2, 256(%rsp)
        incq    %rcx
        addq    %r8, %rsi
        cmpq    %rcx, %rdx
        jne     L3040

Sure enough, 16/16 vfmadd instructions are now vfmadd231pd, but this didn’t exactly decrease the amount of shuffling between registers.
Needless to say, I’ll take the 3 vmovapds over that mess. (The CPU may also be able to elide a small number of them via register renaming?)

Does anyone have advice or suggestions on how I can optimize code like this?

This only seems to happen with the larger blocks, which makes me wonder whether LLVM is overestimating how many registers this code requires and spilling. Could LLVM think it has to reserve one more register per fma than is actually required?

E.g., a 4x3 block is fine: it requires 12 + 4 + 1 = 17 registers at any given time. If using the asm call confuses LLVM into thinking it needs 24 + 4 + 1 = 29 <= 32 registers, that would still be fine.
But in the 4x4 example I’ve been showing here, that would mean that instead of correctly assessing 16 + 4 + 1 = 21 <= 32, it would think it requires 32 + 4 + 1 = 37 > 32 and spill. Could the problem be something like this?

I’ve put up a patch here that should fix your issue: https://reviews.llvm.org/D75016

At least part of the problem with your inline assembly is that you need to use ‘v’ instead of ‘x’ in the constraint string. ‘x’ only allows zmm0-zmm15. ‘v’ is all evex encodable registers.
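
Concretely, that just means swapping the constraint string in your snippet from "=x,0,x,x" to "=v,0,v,v"; the asm template itself is unchanged:

    %res = call <8 x double> asm "vfmadd231pd \$3, \$2, \$0", "=v,0,v,v"(<8 x double> %2, <8 x double> %1, <8 x double> %0)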


Your revision has now been accepted; thank you so much, that’s incredible!

‘x’ only allows zmm0-zmm15. ‘v’ is all evex encodable registers.

Thank you, that fixes the problem! E.g.:

L2688:
        vmovupd (%rsi,%r11), %zmm17 {%k1} {z}
        movq    232(%rsp), %rax
        vbroadcastsd    (%rax,%rdi,8), %zmm18
        vmovdqu64       (%rsi,%rbx,8), %zmm19
        movq    32(%rsp), %r13
        vmovdqu64       (%rsi,%r13,8), %zmm20
        movq    112(%rsp), %rax
        vmovups (%rsi,%rax,8), %zmm21
        movq    %r14, 160(%rsp)
        movq    %r9, 152(%rsp)
        movq    %r15, 144(%rsp)
        vfmadd231pd     %zmm19, %zmm18, %zmm13 # zmm13 = (zmm18 * zmm19) + zmm13
        vfmadd231pd     %zmm20, %zmm18, %zmm16 # zmm16 = (zmm18 * zmm20) + zmm16
        vfmadd231pd     %zmm21, %zmm18, %zmm15 # zmm15 = (zmm18 * zmm21) + zmm15
        vfmadd231pd     %zmm17, %zmm18, %zmm14 # zmm14 = (zmm18 * zmm17) + zmm14
        vbroadcastsd    (%rcx,%rdi,8), %zmm18
        vfmadd231pd     %zmm19, %zmm18, %zmm12 # zmm12 = (zmm18 * zmm19) + zmm12
        vfmadd231pd     %zmm20, %zmm18, %zmm11 # zmm11 = (zmm18 * zmm20) + zmm11
        vfmadd231pd     %zmm21, %zmm18, %zmm10 # zmm10 = (zmm18 * zmm21) + zmm10
        vfmadd231pd     %zmm17, %zmm18, %zmm9 # zmm9 = (zmm18 * zmm17) + zmm9
        vbroadcastsd    (%r12,%rdi,8), %zmm18
        vfmadd231pd     %zmm19, %zmm18, %zmm8 # zmm8 = (zmm18 * zmm19) + zmm8
        vfmadd231pd     %zmm20, %zmm18, %zmm7 # zmm7 = (zmm18 * zmm20) + zmm7
        vfmadd231pd     %zmm21, %zmm18, %zmm6 # zmm6 = (zmm18 * zmm21) + zmm6
        vfmadd231pd     %zmm17, %zmm18, %zmm5 # zmm5 = (zmm18 * zmm17) + zmm5
        vpbroadcastq    (%r10,%rdi,8), %zmm18
        vfmadd231pd     %zmm19, %zmm18, %zmm4 # zmm4 = (zmm18 * zmm19) + zmm4
        vfmadd231pd     %zmm20, %zmm18, %zmm3 # zmm3 = (zmm18 * zmm20) + zmm3
        vfmadd231pd     %zmm21, %zmm18, %zmm2 # zmm2 = (zmm18 * zmm21) + zmm2
        vfmadd231pd     %zmm17, %zmm18, %zmm1 # zmm1 = (zmm18 * zmm17) + zmm1
        incq    %rdi
        addq    %r8, %rsi
        cmpq    %rdi, %rdx
        jne     L2688

It does look a little funny to see that, of the four vector loads, two are vmovdqu64, one is vmovups, and the masked one is vmovupd.

I should have realized that zmm15 was the largest-numbered register showing up in that code.

Some playing around has shown that I need to be very sparing about using the asm instead of the intrinsic. LLVM will often fuse the intrinsic with other operations, like memory loads or selects, but of course it will not fuse the asm.
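
To illustrate the kind of fusion I mean with a minimal, hypothetical example (not taken from the code above): when one operand of the fmuladd comes straight from memory, instruction selection can fold that load into the fma’s memory operand, but it can’t do that through an asm call, which forces the value into a register first.

    declare <8 x double> @llvm.fmuladd.v8f64(<8 x double>, <8 x double>, <8 x double>)

    define <8 x double> @fold_example(<8 x double>* %p, <8 x double> %a, <8 x double> %acc) {
      %b = load <8 x double>, <8 x double>* %p, align 8
      ; with the intrinsic, the backend can fold this load into the fma's memory
      ; operand (a single vfmadd with a memory source); with the asm call the
      ; load stays a separate instruction
      %r = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %a, <8 x double> %b, <8 x double> %acc)
      ret <8 x double> %r
    }
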
So it’ll be a nice upgrade when I no longer need to rely on the asm, and can delete the hackish logic that decides between it and the intrinsic.