在AVX2中重现_mm256_sllv_epi16和_mm256_sllv_epi8

Posted

技术标签:

【中文标题】在AVX2中重现_mm256_sllv_epi16和_mm256_sllv_epi8【英文标题】:Reproduce _mm256_sllv_epi16 and _mm256_sllv_epi8 in AVX2 【发布时间】:2018-08-10 15:24:43 【问题描述】:

我很惊讶地发现 _mm256_sllv_epi16/8(__m256i v1, __m256i v2)_mm256_srlv_epi16/8(__m256i v1, __m256i v2) 不在 Intel Intrinsics Guide 中,而且我找不到任何解决方案来仅使用 AVX2 重新创建 AVX512 内在函数。

此函数将所有 16/8bits packed int 左移 v2 中相应数据元素的计数值。

epi16 示例:

__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
v1 = _mm256_sllv_epi16(v1, v2);

那么 v1 等于 -> (1111111111111111, 1111111111111110, 1111111111111100, 1111111111111000, ...................., 1000000000000000);

【问题讨论】:

@1201ProgramAlarm: true,但是 OP 想用 AVX2 模拟它们,所以它们的代码可以在 Haswell / Ryzen 上运行,而不仅仅是 AVX512BW (SKX)。并且没有 CPU 有 _mm256_sllv_epi8 / vpsllvb 因为它不存在,甚至在 AVX512VBMI2 中也不存在。我删除了 avx512 标签,因为这不是 avx512 问题。 【参考方案1】:

_mm256_sllv_epi8 的情况下,使用pshufb 指令作为一个小查找表,通过乘法替换移位并不难。也可以用乘法和许多其他指令来模拟_mm256_srlv_epi8 的右移,请参见下面的代码。我希望至少 _mm256_sllv_epi8 比 Nyan 的 solution 更有效。


可以使用或多或少相同的想法来模拟_mm256_sllv_epi16,但在这种情况下,选择正确的乘数就不那么简单了(另请参见下面的代码)。

下面的解决方案_mm256_sllv_epi16_emu 不一定比 Nyan 的solution 更快,也不一定更好。 性能取决于周围的代码和使用的 CPU。 尽管如此,这里的解决方案可能会很有趣,至少在较旧的计算机系统上是这样。 例如,vpsllvd 指令在 Nyan 的解决方案中使用了两次。此指令在 Intel Skylake 系统或更新版本上速度很快。 在 Intel Broadwell 或 Haswell 上,这条指令很慢,因为它解码为 3 个微操作。这里的解决方案避免了这种缓慢的指令。

如果已知移位计数小于或等于 15,则可以使用 mask_lt_15 跳过这两行代码。

缺少内在的_mm256_srlv_epi16 留给读者作为练习。


/*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell shift_v_epi8.c     */
#include <immintrin.h>
#include <stdio.h>
int print_epi8(__m256i  a);
int print_epi16(__m256i  a);

__m256i _mm256_sllv_epi8(__m256i a, __m256i count) 
    __m256i mask_hi        = _mm256_set1_epi32(0xFF00FF00);
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);

    __m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));     /* AVX shift counts are not masked. So a_i << n_i = 0 for n_i >= 8. count_sat is always less than 9.*/ 
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);  /* Select the right multiplication factor in the lookup table.                                      */
    __m256i x_lo           = _mm256_mullo_epi16(a, multiplier);               /* Unfortunately _mm256_mullo_epi8 doesn't exist. Split the 16 bit elements in a high and low part. */

    __m256i multiplier_hi  = _mm256_srli_epi16(multiplier, 8);                /* The multiplier of the high bits.                                                                 */
    __m256i a_hi           = _mm256_and_si256(a, mask_hi);                    /* Mask off the low bits.                                                                           */
    __m256i x_hi           = _mm256_mullo_epi16(a_hi, multiplier_hi);
    __m256i x              = _mm256_blendv_epi8(x_lo, x_hi, mask_hi);         /* Merge the high and low part.                                                                     */
            return x;



__m256i _mm256_srlv_epi8(__m256i a, __m256i count) 
    __m256i mask_hi        = _mm256_set1_epi32(0xFF00FF00);
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,128, 0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,128);

    __m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));     /* AVX shift counts are not masked. So a_i >> n_i = 0 for n_i >= 8. count_sat is always less than 9.*/ 
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);  /* Select the right multiplication factor in the lookup table.                                      */
    __m256i a_lo           = _mm256_andnot_si256(mask_hi, a);                 /* Mask off the high bits.                                                                          */
    __m256i multiplier_lo  = _mm256_andnot_si256(mask_hi, multiplier);        /* The multiplier of the low bits.                                                                  */
    __m256i x_lo           = _mm256_mullo_epi16(a_lo, multiplier_lo);         /* Shift left a_lo by multiplying.                                                                  */
            x_lo           = _mm256_srli_epi16(x_lo, 7);                      /* Shift right by 7 to get the low bits at the right position.                                      */

    __m256i multiplier_hi  = _mm256_and_si256(mask_hi, multiplier);           /* The multiplier of the high bits.                                                                 */
    __m256i x_hi           = _mm256_mulhi_epu16(a, multiplier_hi);            /* Variable shift left a_hi by multiplying. Use a instead of a_hi because the a_lo bits don't interfere */
            x_hi           = _mm256_slli_epi16(x_hi, 1);                      /* Shift left by 1 to get the high bits at the right position.                                      */
    __m256i x              = _mm256_blendv_epi8(x_lo, x_hi, mask_hi);         /* Merge the high and low part.                                                                     */
            return x;



__m256i _mm256_sllv_epi16_emu(__m256i a, __m256i count) 
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);
    __m256i byte_shuf_mask = _mm256_set_epi8(14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0, 14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0);

    __m256i mask_lt_15     = _mm256_cmpgt_epi16(_mm256_set1_epi16(16), count);
            a              = _mm256_and_si256(mask_lt_15, a);                    /* Set a to zero if count > 15.                                                                      */
            count          = _mm256_shuffle_epi8(count, byte_shuf_mask);         /* Duplicate bytes from the even postions to bytes at the even and odd positions.                    */
            count          = _mm256_sub_epi8(count,_mm256_set1_epi16(0x0800));   /* Subtract 8 at the even byte positions. Note that the vpshufb instruction selects a zero byte if the shuffle control mask is negative.     */
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count);         /* Select the right multiplication factor in the lookup table. Within the 16 bit elements, only the upper byte or the lower byte is nonzero. */
    __m256i x              = _mm256_mullo_epi16(a, multiplier);                  
            return x;



int main()

    printf("Emulating _mm256_sllv_epi8:\n");
    __m256i a     = _mm256_set_epi8(32,31,30,29, 28,27,26,25, 24,23,22,21, 20,19,18,17, 16,15,14,13, 12,11,10,9, 8,7,6,5, 4,3,2,1);
    __m256i count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
    __m256i x     = _mm256_sllv_epi8(a, count);
    printf("a     = \n"); print_epi8(a    );
    printf("count = \n"); print_epi8(count);
    printf("x     = \n"); print_epi8(x    );
    printf("\n\n"); 


    printf("Emulating _mm256_srlv_epi8:\n");
            a     = _mm256_set_epi8(223,224,225,226, 227,228,229,230, 231,232,233,234, 235,236,237,238, 239,240,241,242, 243,244,245,246, 247,248,249,250, 251,252,253,254);
            count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
            x     = _mm256_srlv_epi8(a, count);
    printf("a     = \n"); print_epi8(a    );
    printf("count = \n"); print_epi8(count);
    printf("x     = \n"); print_epi8(x    );
    printf("\n\n"); 



    printf("Emulating _mm256_sllv_epi16:\n");
            a     = _mm256_set_epi16(1601,1501,1401,1301, 1200,1100,1000,900, 800,700,600,500, 400,300,200,100);
            count = _mm256_set_epi16(17,16,15,13,  11,10,9,8, 7,6,5,4, 3,2,1,0);
            x     = _mm256_sllv_epi16_emu(a, count);
    printf("a     = \n"); print_epi16(a    );
    printf("count = \n"); print_epi16(count);
    printf("x     = \n"); print_epi16(x    );
    printf("\n\n"); 

    return 0;



int print_epi8(__m256i  a)
  char v[32];
  int i;
  _mm256_storeu_si256((__m256i *)v,a);
  for (i = 0; i<32; i++) printf("%4hhu",v[i]);
  printf("\n");
  return 0;


int print_epi16(__m256i  a)
  unsigned short int  v[16];
  int i;
  _mm256_storeu_si256((__m256i *)v,a);
  for (i = 0; i<16; i++) printf("%6hu",v[i]);
  printf("\n");
  return 0;

输出是:

Emulating _mm256_sllv_epi8:
a     = 
   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32
count = 
   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
   1   4  12  32  80 192 192   0   0   0   0   0  13  28  60 128  16  64 192   0   0   0   0   0  25  52 108 224 208 192 192   0


Emulating _mm256_srlv_epi8:
a     = 
 254 253 252 251 250 249 248 247 246 245 244 243 242 241 240 239 238 237 236 235 234 233 232 231 230 229 228 227 226 225 224 223
count = 
   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
 254 126  63  31  15   7   3   1   0   0   0   0 242 120  60  29  14   7   3   1   0   0   0   0 230 114  57  28  14   7   3   1


Emulating _mm256_sllv_epi16:
a     = 
   100   200   300   400   500   600   700   800   900  1000  1100  1200  1301  1401  1501  1601
count = 
     0     1     2     3     4     5     6     7     8     9    10    11    13    15    16    17
x     = 
   100   400  1200  3200  8000 19200 44800 36864 33792 53248 12288 32768 40960 32768     0     0

确实缺少一些 AVX2 指令。 但是,请注意,通过模拟“缺失”的 AVX2 指令来填补这些空白并不总是一个好主意。有时它是 以避免这些模拟指令的方式重新设计代码更有效。例如,通过使用更宽的向量 元素(_epi32 而不是 _epi16),具有原生支持。

【讨论】:

我们可以利用vpmaddubsw 为我们做一些掩蔽吗?不,我们需要以两种方式屏蔽 vpshufb 结果以在每个其他元素中创建 0。对于每对的高半部分,我们需要一个 256 * n 的乘数,但这不适合一个字节。 @PeterCordes 我认为vpmaddubsw 是个好主意,谢谢!可能有可能使用该指令改进_mm256_srlv_epi8x_lo 的计算(去掉一个andnot)。 哦,是的,我没有考虑右移,但是将奇数和偶数元素扩展到 16 位对的底部一点也不差。立即使用vpsllw 将高字节放回顶部,我们确实至少提前了一条指令,对吗?但请注意,Haswell 仅在 p0 上运行移位和乘法,因此它们相互竞争。或者,如果 shift/mul 压力比 shuffle 压力差,我们可以使用vpslldq @PeterCordes 是的,但是在使用vpmaddubsw 时,至少应该屏蔽掉其他字节元素中的一个,所以这并不总是一个胜利。稍后我会回到它... 也许vpmaddubsw 用于下半部分,vpmulhuw 用于上半部分?是的,我认为这是一种直接替换,可以节省一个 andnot【参考方案2】:

奇怪的是,他们错过了这一点,尽管似乎许多 AVX 整数指令仅适用于 32/64 位宽度。 AVX512BW 中至少添加了 16 位(尽管我仍然不明白英特尔为什么拒绝添加 8 位移位)。

我们可以仅使用 AVX2 来模拟 16 位变量移位,方法是使用 32 位变量移位和一些掩码和混合。

我们需要在包含每个 16 位元素的 32 位元素的底部进行右移计数,我们可以使用 AND(用于低位元素)和立即移位用于高半部分。 (与标量移位不同,x86 向量移位使它们的计数饱和,而不是包装/屏蔽)。

在进行高半移位之前,我们还需要屏蔽低 16 位数据,因此我们不会将垃圾转移到包含 32 位元素的高 16 位一半。

__m256i _mm256_sllv_epi16(__m256i a, __m256i count) 
    const __m256i mask = _mm256_set1_epi32(0xffff0000);
    __m256i low_half = _mm256_sllv_epi32(
        a,
        _mm256_andnot_si256(mask, count)
    );
    __m256i high_half = _mm256_sllv_epi32(
        _mm256_and_si256(mask, a),
        _mm256_srli_epi32(count, 16)
    );
    return _mm256_blend_epi16(low_half, high_half, 0xaa);

__m256i _mm256_sllv_epi16(__m256i a, __m256i count) 
    const __m256i mask = _mm256_set1_epi32(0xffff0000); // alternating low/high words of a dword
    // shift low word of each dword: low_half = (a << (count & 0xffff)) [for each 32b element]
    // note that, because `a` isn't being masked here, we may get some "junk" bits, but these will get eliminated by the blend below
    __m256i low_half = _mm256_sllv_epi32(
        a,
        _mm256_andnot_si256(mask, count)
    );
    // shift high word of each dword: high_half = ((a & 0xffff0000) << (count >> 16)) [for each 32b element]
    __m256i high_half = _mm256_sllv_epi32(
        _mm256_and_si256(mask, a),     // make sure we shift in zeros
        _mm256_srli_epi32(count, 16)   // need the high-16 count at the bottom of a 32-bit element
    );
    // combine low and high words
    return _mm256_blend_epi16(low_half, high_half, 0xaa);


__m256i _mm256_srlv_epi16(__m256i a, __m256i count) 
    const __m256i mask = _mm256_set1_epi32(0x0000ffff);
    __m256i low_half = _mm256_srlv_epi32(
        _mm256_and_si256(mask, a),
        _mm256_and_si256(mask, count)
    );
    __m256i high_half = _mm256_srlv_epi32(
        a,
        _mm256_srli_epi32(count, 16)
    );
    return _mm256_blend_epi16(low_half, high_half, 0xaa);

GCC 8.2 将其编译为您所期望的或多或少:

_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
        vmovdqa       ymm3, YMMWORD PTR .LC0[rip]
        vpand   ymm2, ymm0, ymm3
        vpand   ymm3, ymm1, ymm3
        vpsrld  ymm1, ymm1, 16
        vpsrlvd ymm2, ymm2, ymm3
        vpsrlvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret
_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
        vmovdqa       ymm3, YMMWORD PTR .LC1[rip]
        vpandn  ymm2, ymm3, ymm1
        vpsrld  ymm1, ymm1, 16
        vpsllvd ymm2, ymm0, ymm2
        vpand   ymm0, ymm0, ymm3
        vpsllvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret

表示仿真结果为 1x 负载 + 2x AND/ANDN + 2x 可变移位 + 1x 右移 + 1x 混合。

Clang 6.0 做了一些有趣的事情 - 它通过使用混合消除了内存负载(和相应的屏蔽):

_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
        vpxor   xmm2, xmm2, xmm2
        vpblendw        ymm3, ymm1, ymm2, 170
        vpsllvd ymm3, ymm0, ymm3
        vpsrld  ymm1, ymm1, 16
        vpblendw        ymm0, ymm2, ymm0, 170
        vpsllvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm3, ymm0, 170
        ret
_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
        vpxor   xmm2, xmm2, xmm2
        vpblendw        ymm3, ymm0, ymm2, 170
        vpblendw        ymm2, ymm1, ymm2, 170
        vpsrlvd ymm2, ymm3, ymm2
        vpsrld  ymm1, ymm1, 16
        vpsrlvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret

这导致:1x 清晰 + 3x 混合 + 2x 可变移位 + 1x 右移位。

我没有对哪种方法更快进行任何基准测试,但我怀疑它可能取决于 CPU,特别是 CPU 上 PBLENDW 的成本。

当然,如果您的用例受到更多限制,上述内容可以简化,例如如果您的移位量都是常量,您可以删除使其工作所需的屏蔽/移位(假设编译器不会自动为您执行此操作)。 对于左移,如果移位量是恒定的,您可以改用_mm256_mullo_epi16,将移位量转换为可以相乘的值,例如对于您给出的示例:

__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(1<<0,1<<1,1<<2,1<<3,1<<4,1<<5,1<<6,1<<7,1<<8,1<<9,1<<10,1<<11,1<<12,1<<13,1<<14,1<<15);
v1 = _mm256_mullo_epi16(v1, v2);

更新:Peter 提到(见下面的评论)右移也可以用_mm256_mulhi_epi16 实现(例如,执行v&gt;&gt;1 乘以v 乘以1&lt;&lt;15 并取高字)。


对于 8 位变量移位,这在 AVX512 中也不存在(同样,我不知道为什么英特尔没有 8 位 SIMD 移位)。 如果 AVX512BW 可用,您可以使用与上述类似的技巧,使用 _mm256_sllv_epi16对于 AVX2,我想不出比第二次应用 16 位仿真更好的方法,因为您最终必须将 32 位移位提供给您的移位操作 4 倍。 请参阅@wim 的回答,了解 AVX2 中 8 位的一个很好的解决方案。

这是我想出的(AVX512上8位基本上采用16位版本):

__m256i _mm256_sllv_epi8(__m256i a, __m256i count) 
    const __m256i mask = _mm256_set1_epi16(0xff00);
    __m256i low_half = _mm256_sllv_epi16(
        a,
        _mm256_andnot_si256(mask, count)
    );
    __m256i high_half = _mm256_sllv_epi16(
        _mm256_and_si256(mask, a),
        _mm256_srli_epi16(count, 8)
    );
    return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00));


__m256i _mm256_srlv_epi8(__m256i a, __m256i count) 
    const __m256i mask = _mm256_set1_epi16(0x00ff);
    __m256i low_half = _mm256_srlv_epi16(
        _mm256_and_si256(mask, a),
        _mm256_and_si256(mask, count)
    );
    __m256i high_half = _mm256_srlv_epi16(
        a,
        _mm256_srli_epi16(count, 8)
    );
    return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00));

(Peter Cordes 在下面提到,在纯 AVX512BW(+VL) 实现中,_mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00)) 可以替换为 _mm256_mask_blend_epi8(0xaaaaaaaa, low_half, high_half),这可能更快)

【讨论】:

vmovdqa64: 你编译时启用了 AVX512。没关系,因为看起来您没有使用任何需要 AVX512 的内在函数。如果您要显示 asm 输出,最好在 godbolt.org 上包含代码的永久链接,这样人们就可以自己动手玩了。 (使用full-link to prevent any link-rot,而不是短链接)。例如How to convert 32-bit float to 8-bit signed char?. 如果shift-count向量被重复使用很多次,可以预先计算1&lt;&lt; count_vec并使用vpmullw乘以2的适当幂。对于右移,你可以做类似的事情vpmulhw. 如果您使用 AVX512BW 进行 16 位可变移位,请使用 vpblendmb 进行字节混合 (1 uop),并使用 0 位和 1 位交替的掩码寄存器。它比 AVX2 vpblendvb (2 微秒) 更有效。请在 reviews.llvm.org/D50074 上查看我的 cmets。希望 LLVM 在某个时候能够在编译时将 _mm256_blendv_epi8 优化为 vpblendmb,尤其是在使用常量掩码的情况下。 再次掩码:您只需要一个掩码:set1_epi32(0x000000ff),您使用 after 移位而不是之前。 (嗯,您可能会通过使用另一个掩码获得更多 ILP,因此 AND 可以与计数向量的第一个移位并行运行。但最多 2 个掩码向量似乎是个好主意。) re:PBLENDW 的成本:它是所有 AVX2 CPU (agner.org/optimize) 上的单 uop / 1c 延迟。但是英特尔 CPU 仅在一个端口 (p5) 上运行它,因此它可能成为吞吐量瓶颈(再次查看我的 cmets 对 LLVM 的评论)。实际上在 AMD CPU 上,256 位版本是 2 微指令,因为像往常一样,它们拆分 256 位运算。即使对于最便宜的矢量微指令,Bulldozer 系列也有 2c 的延迟。但是 AMD 对 pblendw 有很好的吞吐量。

以上是关于在AVX2中重现_mm256_sllv_epi16和_mm256_sllv_epi8的主要内容,如果未能解决你的问题,请参考以下文章

有没有办法用 AVX2 编写 _mm256_shldi_epi8(a,b,1) ? (向量之间每 8 位元素移位一位)

AVX2 的汇编错误

AVX2:U8绝对差

AVX2 1x mm256i 32bit 到 2x mm256i 64bit

用 AVX2 有条件地选择一个常数值

如何处理 SIGSEGV,Segmentation fault。使用 Avx2 时