AVX 内在澄清,4x4 矩阵乘法奇数
Posted
技术标签:
【中文标题】AVX 内在澄清,4x4 矩阵乘法奇数【英文标题】:AVX Intrinsic Clarification, 4x4 Matrix Multiplication Oddities 【发布时间】:2017-03-23 14:32:09 【问题描述】:我在纸上画出了这个算法的长形式,在纸上它应该可以正常工作。我是在寄存器转换 (256/128/256) 方面遇到了微妙的问题,还是我实际上在某个地方弄乱了算法结构?
为方便起见,我将原版代码和 AVX 代码放在 Godbolt 查看器上,以便您可以随意查看生成的程序集。
标准代码 https://godbolt.org/g/v47RKH
我的 AVX 尝试 1: https://godbolt.org/g/oH1DpO
我的 AVX 尝试 2: https://godbolt.org/g/QFtdKr(缩短了 5 个周期并减少了铸造需求,更易于阅读)
奇怪的是,SSE 代码使用了标量运算,这让我大吃一惊,因为这绝对可以通过水平广播、muls 和 add 来加速。我正在尝试将这个概念提升一个层次。
RHS 永远不需要改变,但本质上如果 LHS 是 a, b, ..., p, 并且 LHS 是 1, 2, ..., 16,那么我们只需要 2 个寄存器来保存 RHS 的 2 个半部分,然后需要 2 个寄存器来保存 LHS 的给定行,格式为 a, a, a, a , b, b, b, b 和 c, c, c, c, d, d, d, d。这是通过 2 次广播和 256/128/256 演员来实现的。
我们得到的中间结果
a*1, a*2, a*3, a*4, b*5, b*6, b*7, b*8 => 行[0]
和
c*9, c*10, c*11, c*12, d*13, d*14, d*15, d*16 => 行[1]
这是一次 w.r.t LHS 展开,所以我们生成
e*1, ... f*8, g*9, ... h*16 => 行[2], 行[3]
接下来将 r0,r1 和 r2,r3 相加(保持 r0 和 r2 作为当前中间体)
最后,将row[0]的高半部分提取到resHalf的低半部分,将row[2]的低半部分插入resHalf的高半部分,将row[2]的高半部分插入到resHalf的高半部分row[0],然后将 row[0] 添加到 resHalf。
无论如何,这应该让我们在迭代 i = 0 结束时得到 resHalf[0] 等于以下内容
a*1 + b*2 + c*3 + d*4, a*5 + b*6 + c*7 + d*8,
a*9 + b*10 + c*11 + d*12, a*13 + b*14 + c*15 + d*16,
e*1 + ... + h*4, e*5 + ... + h*8,
e*9 + ... + h*12, e*13 + ... + h*16
然而,我的算法产生的结果如下:
2x a*1 + c*3, a*5 + c*7, a*9 + c*11, a*13 + c*15,
2x e*1 + g*3, e*5 + g*7, e*9 + g*11, e*13 + g*15
更可怕的是,如果我在三元条件中交换 rhsHolders[0/1],它根本不会改变结果。就好像编译器忽略了其中一个交换和添加。 Clang 4 和 GCC 7 都这样做,那么我在哪里搞砸了?
编辑:输出应该是 4 行 10, 26, 42, 58,但我得到 4, 12, 20, 28
【问题讨论】:
【参考方案1】:奇怪的是,SSE 代码使用了标量运算,这让我大吃一惊,因为这绝对可以通过水平广播、muls 和添加来加速。
你的意思是编译器生成的汇编代码吗? clang4.0 和 gcc7.1 输出中MatMul()
中的所有 AVX 指令都在 ymm 向量上运行。除了 clang 愚蠢的广播加载:它执行标量加载,然后执行单独的 AVX2 广播指令,这非常糟糕,因为英特尔 CPU 将广播加载作为单 uop ALU 指令处理。加载端口本身可以进行广播。但如果源是寄存器,则需要一个 ALU uop 用于 shuffle 端口。
vmovss xmm5, dword ptr [rdi + 24] # xmm5 = mem[0],zero,zero,zero
vbroadcastss xmm5, xmm5
clang 的实际输出(上图)与 gcc 使用的 AVX1 vbroadcastss xmm5, [rdi + 24]
相比真的很傻。
在main()
中,clang 确实发出标量操作。
由于您的输入矩阵都是编译时常量,唯一的谜团是为什么它没有优化到 cout << "a long string with the numbers already formatted\n";
,或者至少优化掉所有的数学运算而只有 @ 987654327@ 结果准备打印。 (是的,它们正在使用vcvtss2sd
在打印循环中从float
转换为double
。)
它通过一些内在的洗牌和数学进行优化,在编译时进行。我猜clang在洗牌的某个地方迷路了,仍然发出了一些数学运算。它们是标量的事实可能表明它在编译时没有做很多工作,但它没有对事物进行重新排序以对其进行矢量化。
请注意,有些常量不会出现在源代码中,它们在内存中也不是升序排列的。
...
.LCPI1_5:
.long 1092616192 # float 10
.LCPI1_6:
.long 1101004800 # float 20
.LCPI1_7:
.long 1098907648 # float 16
...
clang 如何将浮点值放在位模式的整数表示之后的注释中,这真是太好了。
还是我真的把算法结构搞砸了?
嗯,这部分实现看起来完全是假的。您从 rows[j]
初始化 lowerHalf
,然后在下一条语句中覆盖该值。
__m128 lowerHalf = _mm256_castps256_ps128(rows[j]);
lowerHalf = _mm_broadcast_ss(&lhs[offset+2*j]);
然后你做一个 256b 乘以 rows[j]
undefined 的上部 128b 通道。
rows[j] = _mm256_castps128_ps256(lowerHalf);
rows[j] = _mm256_mul_ps(rows[j], (chooser) ? rhsHolders[0] : rhsHolders[1]);
在来自 gcc 和 clang 的 asm 中,上面的通道全为零(因为它们明显选择使用由标量最后写入的 ymm 寄存器 -> xmm 广播,隐式零扩展到最大向量宽度) .请注意,_mm256_castps128_ps256
不保证零扩展。除非__m128
本身是从 256b 或更宽的向量提取/转换的结果,否则很有可能,但它是未定义的。请参阅How to clear the upper 128 bits of __m256 value?,了解您需要在向量中使用归零的上车道的情况。
无论如何,这意味着您将从 128b 向量乘法 (vmulps xmm, xmm, xmm
) 中得到相同的结果:在这些指令之后,上面的 4 个元素将全部为零(或 NaN)
vbroadcastss xmm0, DWORD PTR [rdi+40]
vmulps ymm0, ymm2, ymm0
这种 asm 输出(来自 gcc7.1)极不可能成为正确 matmul 实现的一部分。
我没有仔细查看你到底想在源代码中做什么,但我认为它不完全是这样。
更可怕的是,如果我在三元条件中交换 rhsHolders[0/1],它根本不会改变结果。就好像编译器忽略了其中一个交换并添加了一样。
如果更改源代码中的某些内容不会在 asm 输出中产生您期望的更改,这表明您可能弄错了源代码,并且某些内容正在优化。有时我复制/粘贴一个内在变量并忘记在新行中更改输入变量,因此我的函数忽略了它的一些计算结果并使用了另一个计算结果两次。
【讨论】:
【参考方案2】:这几乎是我昨天对 SO 的回答的复制和粘贴 :)
试试这个
void MatMul(const float* __restrict lhs , const float* __restrict rhs , float* __restrict out )
lhs = reinterpret_cast<float*>(__builtin_assume_aligned (lhs, 32));
rhs = reinterpret_cast<float*>(__builtin_assume_aligned (rhs, 32));
out = reinterpret_cast<float*>(__builtin_assume_aligned (out, 32));
for(int i = 0; i < 4; i++)
for(int j = 0; j < 4; j++)
for (int k = 0; k < 4; k++)
out[i*4 + j] += lhs[i*4 + k]*rhs[k*4 + i];
使用以下之一编译(衡量哪个对您来说最快)
-O3 -mavx
-O3 -mavx2
-O3 -mavx2 -mfma
-O3 -mavx2 -mfma -ffast-math
这在 GCC 下有效(我的意思是矢量化),cLANG 由于某种原因无法这样做。 GCC 也会展开循环。
【讨论】:
非常酷,但我仍然想知道我是如何填充我的算法的。此外,为您的代码生成的 GCC 7 非常浪费。有一个很好的 10 个循环,不需要在那里。 什么迭代? 它已展开,但有 2 个不同的部分,一个较大循环的 2 次迭代。 我添加了第二次尝试,向底部清理并刮了 5 个周期。 如果你想要的话,你应该使用-march=haswell
。您的选项不会覆盖默认的-mtune=generic
,并且会使用 vinsertf128 / vextractf128 将未对齐的 256b 加载/存储分成 128b 的一半。 (您可以使用 -mavx2 -mfma -ffast-math -mtune=haswell
调整 haswell,而无需启用 BMI2、popcnt 和其他 haswell 支持的 ISA 扩展,超出您使用 -m
的支持。)以上是关于AVX 内在澄清,4x4 矩阵乘法奇数的主要内容,如果未能解决你的问题,请参考以下文章