I’m developing a library that models simple loops and lowers them while making heavy use of LLVM intrinsics. When I use @llvm.fma.v8f64, it often generates instructions other than vfmadd231pd, e.g. vfmadd213pd, which then requires an extra vmovapd instruction as well.
Here is an example:
L2208:
vmovapd %zmm6, %zmm18
vmovapd %zmm4, %zmm19
vmovapd %zmm3, %zmm20
vmovupd (%rsi,%rax,8), %zmm6
vmovupd (%rsi,%rbp,8), %zmm4
vmovupd (%rsi,%rdx,8), %zmm3
vmovupd (%rsi,%rcx), %zmm21 {%k1} {z}
vbroadcastsd (%r9,%rdi,8), %zmm22
vfmadd231pd %zmm22, %zmm6, %zmm14 # zmm14 = (zmm6 * zmm22) + zmm14
vfmadd231pd %zmm22, %zmm4, %zmm17 # zmm17 = (zmm4 * zmm22) + zmm17
vfmadd231pd %zmm22, %zmm3, %zmm16 # zmm16 = (zmm3 * zmm22) + zmm16
vfmadd231pd %zmm22, %zmm21, %zmm15 # zmm15 = (zmm21 * zmm22) + zmm15
vbroadcastsd (%r11,%rdi,8), %zmm22
vfmadd231pd %zmm22, %zmm6, %zmm13 # zmm13 = (zmm6 * zmm22) + zmm13
vfmadd231pd %zmm22, %zmm4, %zmm12 # zmm12 = (zmm4 * zmm22) + zmm12
vfmadd231pd %zmm22, %zmm3, %zmm11 # zmm11 = (zmm3 * zmm22) + zmm11
vfmadd231pd %zmm22, %zmm21, %zmm10 # zmm10 = (zmm21 * zmm22) + zmm10
vbroadcastsd (%r8,%rdi,8), %zmm22
vfmadd231pd %zmm22, %zmm6, %zmm9 # zmm9 = (zmm6 * zmm22) + zmm9
vfmadd231pd %zmm22, %zmm4, %zmm8 # zmm8 = (zmm4 * zmm22) + zmm8
vfmadd231pd %zmm22, %zmm3, %zmm7 # zmm7 = (zmm3 * zmm22) + zmm7
vfmadd231pd %zmm22, %zmm21, %zmm5 # zmm5 = (zmm21 * zmm22) + zmm5
vbroadcastsd (%r13,%rdi,8), %zmm22
vfmadd213pd %zmm18, %zmm22, %zmm6 # zmm6 = (zmm22 * zmm6) + zmm18
vfmadd213pd %zmm19, %zmm22, %zmm4 # zmm4 = (zmm22 * zmm4) + zmm19
vfmadd213pd %zmm20, %zmm22, %zmm3 # zmm3 = (zmm22 * zmm3) + zmm20
vfmadd231pd %zmm22, %zmm21, %zmm2 # zmm2 = (zmm21 * zmm22) + zmm2
incq %rdi
addq %r10, %rsi
cmpq %rdi, %rbx
jne L2208
Notice that three of the last four vfmadd instructions are 213 instead of 231. This is problematic: the 231 variant accumulates in place (it increments a vector by the product of two others), while the 213 variants don’t increment the accumulator and instead overwrite one of the multiplied operands (specifically, the registers filled by the three unmasked vmovupd loads that feed the four blocks of vfmadd).
Because they don’t increment the accumulation vectors in place, the vmovapd copies at the top of the block become necessary.
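To isolate the difference between the two forms (the registers below are arbitrary, chosen purely for illustration):
vfmadd231pd %zmm1, %zmm2, %zmm3 # zmm3 = (zmm2 * zmm1) + zmm3; the accumulator zmm3 is updated in place
vfmadd213pd %zmm3, %zmm2, %zmm1 # zmm1 = (zmm2 * zmm1) + zmm3; the result clobbers the multiplicand zmm1, so zmm1's old value must first be saved with a vmovapd if it is needed again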
Adding the fast flag to the fmuladd instructions exacerbates this problem (especially on AVX2).
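For concreteness, by adding the fast flag I mean attaching fast-math flags to the intrinsic call, along these lines (%a, %b, and %acc are placeholder names, not taken from the IR below):
%res = call fast <8 x double> @llvm.fmuladd.v8f64(<8 x double> %a, <8 x double> %b, <8 x double> %acc)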
Here is the optimized LLVM IR that generated the assembly above:
%offsetptr.i1442 = getelementptr inbounds double, double* %typptr.i1784, i64 %222
%ptr.i1443 = bitcast double* %offsetptr.i1442 to <8 x double>*
%res.i1444 = load <8 x double>, <8 x double>* %ptr.i1443, align 8
%223 = add i64 %221, %210
%offsetptr.i1438 = getelementptr inbounds double, double* %typptr.i1784, i64 %223
%ptr.i1439 = bitcast double* %offsetptr.i1438 to <8 x double>*
%res.i1440 = load <8 x double>, <8 x double>* %ptr.i1439, align 8
%224 = add i64 %221, %211
%offsetptr.i1434 = getelementptr inbounds double, double* %typptr.i1784, i64 %224
%ptr.i1435 = bitcast double* %offsetptr.i1434 to <8 x double>*
%res.i1436 = load <8 x double>, <8 x double>* %ptr.i1435, align 8
%225 = add i64 %221, %212
%offsetptr.i1429 = getelementptr inbounds double, double* %typptr.i1784, i64 %225
%ptr.i1430 = bitcast double* %offsetptr.i1429 to <8 x double>*
%res.i1432 = call <8 x double> @llvm.masked.load.v8f64.p0v8f64(<8 x double>* nonnull %ptr.i1430, i32 8, <8 x i1> %213, <8 x double> zeroinitializer)
%226 = add i64 %value_phi1622015, %214
%ptr.i1426 = getelementptr inbounds double, double* %typptr.i1769, i64 %226
%res.i1427 = load double, double* %ptr.i1426, align 8
%ie.i1423 = insertelement <8 x double> undef, double %res.i1427, i32 0
%v.i1424 = shufflevector <8 x double> %ie.i1423, <8 x double> undef, <8 x i32> zeroinitializer
%res.i1422 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1444, <8 x double> %v.i1424, <8 x double> %value_phi1632016)
%res.i1419 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1440, <8 x double> %v.i1424, <8 x double> %value_phi1652017)
%res.i1416 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1436, <8 x double> %v.i1424, <8 x double> %value_phi1672018)
%res.i1413 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1432, <8 x double> %v.i1424, <8 x double> %value_phi1692019)
%227 = add i64 %value_phi1622015, %216
%ptr.i1411 = getelementptr inbounds double, double* %typptr.i1769, i64 %227
%res.i1412 = load double, double* %ptr.i1411, align 8
%ie.i1408 = insertelement <8 x double> undef, double %res.i1412, i32 0
%v.i1409 = shufflevector <8 x double> %ie.i1408, <8 x double> undef, <8 x i32> zeroinitializer
%res.i1407 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1444, <8 x double> %v.i1409, <8 x double> %value_phi1712020)
%res.i1404 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1440, <8 x double> %v.i1409, <8 x double> %value_phi1732021)
%res.i1401 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1436, <8 x double> %v.i1409, <8 x double> %value_phi1752022)
%res.i1398 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1432, <8 x double> %v.i1409, <8 x double> %value_phi1772023)
%228 = add i64 %value_phi1622015, %218
%ptr.i1396 = getelementptr inbounds double, double* %typptr.i1769, i64 %228
%res.i1397 = load double, double* %ptr.i1396, align 8
%ie.i1393 = insertelement <8 x double> undef, double %res.i1397, i32 0
%v.i1394 = shufflevector <8 x double> %ie.i1393, <8 x double> undef, <8 x i32> zeroinitializer
%res.i1392 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1444, <8 x double> %v.i1394, <8 x double> %value_phi1792024)
%res.i1389 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1440, <8 x double> %v.i1394, <8 x double> %value_phi1812025)
%res.i1386 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1436, <8 x double> %v.i1394, <8 x double> %value_phi1832026)
%res.i1383 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1432, <8 x double> %v.i1394, <8 x double> %value_phi1852027)
%229 = add i64 %value_phi1622015, %220
%ptr.i1381 = getelementptr inbounds double, double* %typptr.i1769, i64 %229
%res.i1382 = load double, double* %ptr.i1381, align 8
%ie.i1378 = insertelement <8 x double> undef, double %res.i1382, i32 0
%v.i1379 = shufflevector <8 x double> %ie.i1378, <8 x double> undef, <8 x i32> zeroinitializer
%res.i1377 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1444, <8 x double> %v.i1379, <8 x double> %value_phi1872028)
%res.i1374 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1440, <8 x double> %v.i1379, <8 x double> %value_phi1892029)
%res.i1371 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1436, <8 x double> %v.i1379, <8 x double> %value_phi1912030)
%res.i1368 = call <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i1432, <8 x double> %v.i1379, <8 x double> %value_phi1932031)
%230 = add nuw nsw i64 %value_phi1622015, 1
%exitcond2127 = icmp eq i64 %230, %29
br i1 %exitcond2127, label %L1034, label %L930
My idea for a solution was to try inline asm to specifically require vfmadd231pd:
%res = call <8 x double> asm "vfmadd231pd \$3, \$2, \$0", "=x,0,x,x"(<8 x double> %2, <8 x double> %1, <8 x double> %0)
ret <8 x double> %res
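For context, wrapped into a standalone function the call looks roughly like this (the name @fma231_demo is just for illustration, and the backslashes before the $ placeholders above are presumably string escaping in the language embedding the IR, so they are dropped in plain .ll syntax):
define <8 x double> @fma231_demo(<8 x double>, <8 x double>, <8 x double>) {
  ; computes (%1 * %0) + %2; the "0" constraint ties the accumulator input %2 to the output register
  %res = call <8 x double> asm "vfmadd231pd $3, $2, $0", "=x,0,x,x"(<8 x double> %2, <8 x double> %1, <8 x double> %0)
  ret <8 x double> %res
}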
However, this caused a roughly 50% performance regression; the same example loop now looks like this:
L3040:
vmovupd (%rsi,%rax), %zmm1 {%k1} {z}
vbroadcastsd (%r11,%rcx,8), %zmm5
vmovups (%rsi,%rdi,8), %zmm4
movq 208(%rsp), %rbx
vmovups (%rsi,%rbx,8), %zmm3
movq 248(%rsp), %rbx
vmovups (%rsi,%rbx,8), %zmm2
movq %r15, 160(%rsp)
movq %r13, 152(%rsp)
movq %r9, 144(%rsp)
vfmadd231pd %zmm4, %zmm5, %zmm14 # zmm14 = (zmm5 * zmm4) + zmm14
vfmadd231pd %zmm3, %zmm5, %zmm0 # zmm0 = (zmm5 * zmm3) + zmm0
vmovaps %zmm0, 768(%rsp)
vmovaps 704(%rsp), %zmm0
vfmadd231pd %zmm2, %zmm5, %zmm0 # zmm0 = (zmm5 * zmm2) + zmm0
vmovaps %zmm0, 704(%rsp)
vfmadd231pd %zmm1, %zmm5, %zmm15 # zmm15 = (zmm5 * zmm1) + zmm15
vbroadcastsd (%r14,%rcx,8), %zmm5
vfmadd231pd %zmm4, %zmm5, %zmm13 # zmm13 = (zmm5 * zmm4) + zmm13
vfmadd231pd %zmm3, %zmm5, %zmm12 # zmm12 = (zmm5 * zmm3) + zmm12
vfmadd231pd %zmm2, %zmm5, %zmm11 # zmm11 = (zmm5 * zmm2) + zmm11
vfmadd231pd %zmm1, %zmm5, %zmm10 # zmm10 = (zmm5 * zmm1) + zmm10
vbroadcastsd (%r12,%rcx,8), %zmm5
vfmadd231pd %zmm4, %zmm5, %zmm9 # zmm9 = (zmm5 * zmm4) + zmm9
vfmadd231pd %zmm3, %zmm5, %zmm8 # zmm8 = (zmm5 * zmm3) + zmm8
vfmadd231pd %zmm2, %zmm5, %zmm7 # zmm7 = (zmm5 * zmm2) + zmm7
vfmadd231pd %zmm1, %zmm5, %zmm6 # zmm6 = (zmm5 * zmm1) + zmm6
vpbroadcastq (%r10,%rcx,8), %zmm5
vmovapd %zmm15, %zmm0
vmovaps %zmm14, %zmm15
vmovaps %zmm13, %zmm14
vmovapd %zmm12, %zmm13
vmovaps %zmm11, %zmm12
vmovapd %zmm10, %zmm11
vmovapd %zmm9, %zmm10
vmovapd %zmm8, %zmm9
vmovapd %zmm7, %zmm8
vmovapd %zmm6, %zmm7
vmovaps 576(%rsp), %zmm6
vfmadd231pd %zmm4, %zmm5, %zmm6 # zmm6 = (zmm5 * zmm4) + zmm6
vmovaps %zmm6, 576(%rsp)
vmovapd %zmm7, %zmm6
vmovapd %zmm8, %zmm7
vmovapd %zmm9, %zmm8
vmovapd %zmm10, %zmm9
vmovapd %zmm11, %zmm10
vmovaps %zmm12, %zmm11
vmovapd %zmm13, %zmm12
vmovaps %zmm14, %zmm13
vmovaps %zmm15, %zmm14
vmovapd %zmm0, %zmm15
vmovapd 768(%rsp), %zmm0
vmovaps 512(%rsp), %zmm4
vfmadd231pd %zmm3, %zmm5, %zmm4 # zmm4 = (zmm5 * zmm3) + zmm4
vmovapd %zmm4, 512(%rsp)
vmovaps 384(%rsp), %zmm3
vfmadd231pd %zmm2, %zmm5, %zmm3 # zmm3 = (zmm5 * zmm2) + zmm3
vmovapd %zmm3, 384(%rsp)
vmovaps 256(%rsp), %zmm2
vfmadd231pd %zmm1, %zmm5, %zmm2 # zmm2 = (zmm5 * zmm1) + zmm2
vmovapd %zmm2, 256(%rsp)
incq %rcx
addq %r8, %rsi
cmpq %rcx, %rdx
jne L3040
Sure enough, 16 out of 16 vfmadd instructions are now vfmadd231pd, but this didn’t exactly decrease the amount of shuffling between registers.
Needless to say, I’ll take the three vmovapd copies over that mess. (The CPU may also be able to elide a small number of the moves via register renaming?)
Does anyone have advice or suggestions on how I can optimize code like this?
This only seems to happen with the larger blocks, which makes me wonder whether LLVM is overestimating how many registers this code requires and spilling as a result. Could LLVM think it has to reserve one more register per fma than is actually required?
E.g., a 4x3 block is fine: it requires 12 + 4 + 1 = 17 registers (accumulators, loaded vectors, and a broadcast) at any given time. If the asm call confuses LLVM into thinking it needs 24 + 4 + 1 = 29 <= 32, it would still be fine.
But in the 4x4 example I’ve been showing here, instead of correctly assessing 16 + 4 + 1 = 21 <= 32, it would think it requires 32 + 4 + 1 = 37 > 32 and spill. Could the problem be something like this?