_mm512_dpbusd_epi32 AVX-512VNNI 指令的 AVX-512BW 仿真

Posted

技术标签:

【中文标题】_mm512_dpbusd_epi32 AVX-512VNNI 指令的 AVX-512BW 仿真【英文标题】:AVX-512BW emulation of _mm512_dpbusd_epi32 AVX-512VNNI instruction 【发布时间】:2021-06-16 09:04:39 【问题描述】:

从 Cascade Lake Intel CPU 开始有 AVX-512 VNNI 指令,可以加速 CPU 上量化神经网络的推理。 特别是有一个指令_mm512_dpbusd_epi32 (vpdpbusd) 允许执行 8 位有符号和无符号整数的乘法并将它们累加到 32 位整数累加器中。 该指令的伪代码如下:

void _mm512_dpbusd_epi32(int32_t sum[16], uint8_t a[16][4], int8_t b[16][4])

    for(int i = 0; i < 16; ++i)
        sum[i] += 
            (int)a[i][0]*b[i][0] + (int)a[i][1]*b[i][1] +
            (int)a[i][2]*b[i][2] + (int)a[i][3]*b[i][3];

不幸的是,直到 Cascade Lake 之前的英特尔 CPU 都没有此指令,因此有一个问题是使用以前的扩展名(例如 AVX-512BW)来模拟这个指令。 所以我的问题是:如何使这种模拟尽可能有效?

【问题讨论】:

【参考方案1】:

我认为这个问题没有一个正确答案。

一方面,使用 AVX-512BW 扩展对_mm512_dpbusd_epi32 的快速仿真可以看作:

inline __m512i _mm512_dpbusd_epi32_bw_fast(__m512i i32, __m512i u8, __m512i i8)

    __m512i i16 = _mm512_maddubs_epi16(u8, i8); //possible overflow of INT16.
    __m512i _1 = _mm512_set1_epi16(1);
    return _mm512_add_epi32(i32, _mm512_madd_epi16(i16, _1));

这个实现只使用了 3 条指令(而且它们都很快)。 但是由于_mm512_maddubs_epi16指令中INT16可能溢出,它会给出不正确的结果。

另一方面,正确的仿真看起来很糟糕,需要 14 条指令(其中一些指令非常慢):

inline __m512i _mm512_hadd_epi32(__m512i a, __m512i b)

    static const __m512i IDX0 = _mm512_setr_epi32(
        0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, 
        0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E);
    static const __m512i IDX1 = _mm512_setr_epi32(
        0x01, 0x03, 0x05, 0x07, 0x09, 0x0B, 0x0D, 0x0F, 
        0x11, 0x13, 0x15, 0x17, 0x19, 0x1B, 0x1D, 0x1F);
    __m512i ab0 = _mm512_permutex2var_epi32(a, IDX0, b);
    __m512i ab1 = _mm512_permutex2var_epi32(a, IDX1, b);
    return _mm512_add_epi32(ab0, ab1);


inline __m512i _mm512_dpbusd_epi32_bw_exact(__m512i i32, __m512i u8, __m512i i8)

    __m512i u8_i16lo = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(u8, 0));
    __m512i i8_i16lo = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(i8, 0));
    __m512i i32lo = _mm512_madd_epi16(u8_i16lo, i8_i16lo);
    __m512i u8_i16hi = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(u8, 1));
    __m512i i8_i16hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(i8, 1));
    __m512i i32hi = _mm512_madd_epi16(u8_i16hi, i8_i16hi);
    return _mm512_add_epi32(i32, _mm512_hadd_epi32(i32lo, i32hi));

【讨论】:

以上是关于_mm512_dpbusd_epi32 AVX-512VNNI 指令的 AVX-512BW 仿真的主要内容,如果未能解决你的问题,请参考以下文章

*_dpbusd_epi32 或 *_maddubs_epi16 在 ARM 上等效?

反转 __m512i 寄存器中的值

AVX512 缺少内在的 _mm512_round_ps

AVX-512:_mm512_load 与标准指针转换?

如何用 gcc 或 clang 模拟 _mm256_loadu_epi32?

_mm512_i64gather_pd() 的内存访问错误