将 SSE 矩阵向量乘法代码转换为 AVX

Posted

技术标签:

【中文标题】将 SSE 矩阵向量乘法代码转换为 AVX【英文标题】:Convert SSE matrix-vector multiplication code to AVX 【发布时间】:2015-11-21 16:44:53 【问题描述】:

我正在尝试将我的 SSE 函数转换为 AVX。该函数执行向量矩阵乘法,这是我的工作 SSE 代码:

void multiply_matrix_by_vector_SSE(float* m, float* v, float* result, unsigned const int vector_dims)

    size_t i, j;
    for (i = 0; i < vector_dims; ++i)
    
        __m128 acc = _mm_setzero_ps();
        for (j = 0; j < vector_dims; j += 4)
        
            __m128 vec = _mm_load_ps(&v[j]);
            __m128 mat = _mm_load_ps(&m[j + vector_dims * i]);
            //acc = _mm_add_ps(acc, _mm_mul_ps(mat, vec));
            acc = _mm_fmadd_ps(mat, vec, acc);
        
        acc = _mm_hadd_ps(acc, acc);
        acc = _mm_hadd_ps(acc, acc);
        _mm_store_ss(&result[i], acc);
    

这就是我对 AVX 的看法:

void multiply_matrix_by_vector_AVX(float* m, float* v, float* result, unsigned const int vector_dims)

    size_t i, j;

    for (i = 0; i < vector_dims; ++i)
    
        __m256 acc = _mm256_setzero_ps();
        for (j = 0; j < vector_dims; j += 8)
        
            __m256 vec = _mm256_load_ps(&v[j]);
            __m256 mat = _mm256_load_ps(&m[j + vector_dims * i]);
            acc = _mm256_fmadd_ps(mat, vec, acc);
        
        acc = _mm256_hadd_ps(acc, acc);
        acc = _mm256_hadd_ps(acc, acc);
        acc = _mm256_hadd_ps(acc, acc);
        acc = _mm256_hadd_ps(acc, acc);

        _mm256_store_ps(&result[i], acc);
    

但是,AVX 代码崩溃 (Access violation reading location 0xFFFFFFFFFFFFFFFF)。


谁能帮我让我的 AVX 功能正常工作?

PS:我在函数中传递的矩阵和向量的大小始终是 8 的倍数。此外,我传递给我的 SSE 函数的数组是 16 位对齐的 (__declspec(align(16))float* = generate_matrix(256);),我传递给我的数组AVX 函数是 32 位对齐的 (__declspec(align(32))float* = generate_matrix(256););

【问题讨论】:

【参考方案1】:

不幸的是,使用这样的水平添加并不能轻易扩展到 256 位,因为指令(和大多数其他指令)是“单向的”——它的行为就像两个 haddps 并行,一个在上半部分,一个在上半部分下半部分,没有混合,所以下半部分和上半部分不会加在一起。

当然,它仍然不是打包结果,并且打包存储有一个对齐存储写入某个未对齐的地址并且会失败(该错误有点奇怪,但无论如何)。

无论如何,让我们修复水平总和:(未测试)

// this part still works
acc = _mm256_hadd_ps(acc, acc);
acc = _mm256_hadd_ps(acc, acc);
// this is new
__m128 acc1 = _mm256_extractf128_ps(acc, 0);
__m128 acc2 = _mm256_extractf128_ps(acc, 1);
acc1 = _mm_add_ss(acc1, acc2);
// do scalar store, obviously
_mm_store_ss(&result[i], acc1);

顺便说一句,内循环需要 10 个独立的链(和 10 个累加器)才能最大限度地提高 Haswell 的吞吐量。

【讨论】:

以上是关于将 SSE 矩阵向量乘法代码转换为 AVX的主要内容,如果未能解决你的问题,请参考以下文章

AVX/SSE 将浮点符号掩码转换为 __m128i

AVX 内在澄清,4x4 矩阵乘法奇数

使用 SSE 错误 __m128 到 *float 转换的矩阵乘法?

使用 AVX 的平铺矩阵乘法

用于灰度到 ARGB 转换的 C++ SSE2 或 AVX2 内在函数

如何使用 avx 指令将 float 向量转换为 short int?