如何使用 SSE 更有效地将 A*B^T 或 A^T*B^T(T 表示转置)矩阵相乘?

Posted

技术标签:

【中文标题】如何使用 SSE 更有效地将 A*B^T 或 A^T*B^T(T 表示转置)矩阵相乘?【英文标题】:How do I more efficiently multiply A*B^T or A^T*B^T (T for transpose) matrices using SSE? 【发布时间】:2013-05-14 15:27:05 【问题描述】:

我一直在用这个打败自己。我有一个基于 SSE 的算法,用于将矩阵 A 乘以矩阵 B。我还需要实现 A、B 或两者都被转置的操作。我做了一个简单的实现,下面表示的 4x4 矩阵代码(我认为这是非常标准的 SSE 操作),但 A*B^T 操作所需的时间大约是 A*B 的两倍。 ATLAS 实现为 A*B 返回相似的值,并且乘以转置的结果几乎相同,这表明有一种有效的方法可以做到这一点。

MM-乘法:

m1 = (mat1.m_>>2)<<2;
n2 = (mat2.n_>>2)<<2;
n  = (mat1.n_>>2)<<2;

for (k=0; k<n; k+=4) 
  for (i=0; i<m1; i+=4) 
    // fetch: get 4x4 matrix from mat1
    // row-major storage, so get 4 rows
    Float* a0 = mat1.el_[i]+k;
    Float* a1 = mat1.el_[i+1]+k;
    Float* a2 = mat1.el_[i+2]+k;
    Float* a3 = mat1.el_[i+3]+k;

    for (j=0; j<n2; j+=4) 
      // fetch: get 4x4 matrix from mat2
      // row-major storage, so get 4 rows
      Float* b0 = mat2.el_[k]+j;
      Float* b1 = mat2.el_[k+1]+j;
      Float* b2 = mat2.el_[k+2]+j;
      Float* b3 = mat2.el_[k+3]+j;

      __m128 b0r = _mm_loadu_ps(b0);
      __m128 b1r = _mm_loadu_ps(b1);
      __m128 b2r = _mm_loadu_ps(b2);
      __m128 b3r = _mm_loadu_ps(b3);

        // first row of result += first row of mat1 * 4x4 of mat2
        __m128 cX1 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a0+0), b0r), _mm_mul_ps(_mm_load_ps1(a0+1), b1r));
        __m128 cX2 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a0+2), b2r), _mm_mul_ps(_mm_load_ps1(a0+3), b3r));
        Float* c0 = this->el_[i]+j;
        _mm_storeu_ps(c0, _mm_add_ps(_mm_add_ps(cX1, cX2), _mm_loadu_ps(c0)));
      

       // second row of result += second row of mat1 * 4x4 of mat2
        __m128 cX1 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a1+0), b0r), _mm_mul_ps(_mm_load_ps1(a1+1), b1r));
        __m128 cX2 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a1+2), b2r), _mm_mul_ps(_mm_load_ps1(a1+3), b3r));
        Float* c1 = this->el_[i+1]+j;
        _mm_storeu_ps(c1, _mm_add_ps(_mm_add_ps(cX1, cX2), _mm_loadu_ps(c1)));
      

       // third row of result += third row of mat1 * 4x4 of mat2
        __m128 cX1 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a2+0), b0r), _mm_mul_ps(_mm_load_ps1(a2+1), b1r));
        __m128 cX2 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a2+2), b2r), _mm_mul_ps(_mm_load_ps1(a2+3), b3r));
        Float* c2 = this->el_[i+2]+j;
        _mm_storeu_ps(c2, _mm_add_ps(_mm_add_ps(cX1, cX2), _mm_loadu_ps(c2)));
      

       // fourth row of result += fourth row of mat1 * 4x4 of mat2
        __m128 cX1 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a3+0), b0r), _mm_mul_ps(_mm_load_ps1(a3+1), b1r));
        __m128 cX2 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a3+2), b2r), _mm_mul_ps(_mm_load_ps1(a3+3), b3r));
        Float* c3 = this->el_[i+3]+j;
        _mm_storeu_ps(c3, _mm_add_ps(_mm_add_ps(cX1, cX2), _mm_loadu_ps(c3)));
      
  
// Code omitted to handle remaining rows and columns

对于 MT 乘法(矩阵乘以转置矩阵),我使用以下命令将 b0r 存储到 b3r 并适当地更改循环变量:

__m128 b0r = _mm_set_ps(b3[0], b2[0], b1[0], b0[0]);
__m128 b1r = _mm_set_ps(b3[1], b2[1], b1[1], b0[1]);
__m128 b2r = _mm_set_ps(b3[2], b2[2], b1[2], b0[2]);
__m128 b3r = _mm_set_ps(b3[3], b2[3], b1[3], b0[3]);

我怀疑速度变慢的部分原因是一次拉入一行与每次必须存储 4 个值才能获取该列之间的差异,但我觉得这是另一种解决方法,拉入行B 的列,然后乘以 As 的列,只会将成本转移到存储 4 列结果。

我还尝试将 B 的行作为行拉入,然后使用 _MM_TRANSPOSE4_PS(b0r, b1r, b2r, b3r); 进行转置(我认为该宏中可能有一些额外的优化),但没有真正的改进。

从表面上看,我觉得这应该更快......所涉及的点积将是一行接一行,这似乎本质上更有效,但试图直接做点积只会导致不得不做存储结果也是一样的。

我在这里错过了什么?

补充:为了澄清,我试图不转置矩阵。我更愿意沿着它们进行迭代。据我所知,问题在于 _mm_set_ps 命令比 _mm_load_ps 慢得多。

我还尝试了一种变体,我存储了 A 矩阵的 4 行,然后用 4 个乘法指令和 3 个hadds 替换了包含 1 个加载、4 个乘法和 2 个加法的 4 个大括号段,但很少利用。时间保持不变(是的,我尝试使用调试语句来验证代码在我的测试编译中是否发生了更改。当然,在分析之前删除了该调试语句):

      // first row of result += first row of mat1 * 4x4 of mat2
      __m128 cX1 = _mm_hadd_ps(_mm_mul_ps(a0r, b0r), _mm_mul_ps(a0r, b1r));
      __m128 cX2 = _mm_hadd_ps(_mm_mul_ps(a0r, b2r), _mm_mul_ps(a0r, b3r));
      Float* c0 = this->el_[i]+j;
      _mm_storeu_ps(c0, _mm_add_ps(_mm_hadd_ps(cX1, cX2), _mm_loadu_ps(c0)));
    

     // second row of result += second row of mat1 * 4x4 of mat2
      __m128 cX1 = _mm_hadd_ps(_mm_mul_ps(a1r, b0r), _mm_mul_ps(a1r, b1r));
      __m128 cX2 = _mm_hadd_ps(_mm_mul_ps(a1r, b2r), _mm_mul_ps(a1r, b3r));
      Float* c0 = this->el_[i+1]+j;
      _mm_storeu_ps(c0, _mm_add_ps(_mm_hadd_ps(cX1, cX2), _mm_loadu_ps(c0)));
    

     // third row of result += third row of mat1 * 4x4 of mat2
      __m128 cX1 = _mm_hadd_ps(_mm_mul_ps(a2r, b0r), _mm_mul_ps(a2r, b1r));
      __m128 cX2 = _mm_hadd_ps(_mm_mul_ps(a2r, b2r), _mm_mul_ps(a2r, b3r));
      Float* c0 = this->el_[i+2]+j;
      _mm_storeu_ps(c0, _mm_add_ps(_mm_hadd_ps(cX1, cX2), _mm_loadu_ps(c0)));
    

     // fourth row of result += fourth row of mat1 * 4x4 of mat2
      __m128 cX1 = _mm_hadd_ps(_mm_mul_ps(a3r, b0r), _mm_mul_ps(a3r, b1r));
      __m128 cX2 = _mm_hadd_ps(_mm_mul_ps(a3r, b2r), _mm_mul_ps(a3r, b3r));
      Float* c0 = this->el_[i+3]+j;
      _mm_storeu_ps(c0, _mm_add_ps(_mm_hadd_ps(cX1, cX2), _mm_loadu_ps(c0)));
    

更新: 对,将a0r 的行加载到a3r 移动到花括号中,以避免寄存器抖动也失败了。

【问题讨论】:

您是否意识到您不需要真正转置矩阵而只需以不同的方式访问它? 看起来您实际上是在转置矩阵?当然会慢一些。 实际上,我正在尝试通过以不同顺序访问项目来就地进行操作。我尝试了就地转置并得到了大致相同的时间。 不要使用 hadd 指令。坚持垂直添加。 实际上,我认为我错了。看我的回答。 【参考方案1】:

我认为这是水平添加有用的少数情况。您想要 C = AB^T 但 B 没有作为转置存储在内存中。那就是问题所在。它的存储类似于 AoS 而不是 SoA。在这种情况下,我认为采用 B 的转置并进行垂直添加比使用水平添加要慢。至少对于矩阵向量Efficient 4x4 matrix vector multiplication with SSE: horizontal add and dot product - what's the point? 是这样。在下面的代码中,函数m4x4 是非 SSE 4x4 矩阵乘积,m4x4_vec 使用 SSE, m4x4T 在没有 SSE 的情况下执行 C=AB^T,而 m4x4T_vec 执行 C=AB^T 使用 SSE。最后一个是我认为你想要的那个。

注意:对于较大的矩阵,我不会使用此方法。在这种情况下,首先进行转置并使用垂直添加会更快(使用 SSE/AVX,您会做一些更复杂的事情,您可以使用 SSE/AVX 宽度转置条带)。这是因为转置为 O(n^2),矩阵乘积为 O(n^3),因此对于大型矩阵,转置无关紧要。但是,对于 4x4,转置非常重要,因此水平相加胜出。

编辑: 我误解了你想要什么。你想要 C = (AB)^T。这应该和 (AB) 一样快,并且代码几乎相同,您基本上只需交换 A 和 B 的角色。 我们可以这样写数学:

C = A*B in Einstein notation is C_i,j = A_i,k * B_k,j.  
Since (A*B)^T = B^T*A^T we can write 
C = (A*B)^T in Einstein notation is C_i,j = B^T_i,k * A^T_k,j = A_j,k * B_k,i

如果你比较这两者,唯一的变化是我们交换了 j 和 i 的角色。我在这个答案的末尾放了一些代码来做到这一点。

#include "stdio.h"
#include <nmmintrin.h>    

void m4x4(const float *A, const float *B, float *C) 
    for(int i=0; i<4; i++) 
        for(int j=0; j<4; j++) 
            float sum = 0.0f;
            for(int k=0; k<4; k++) 
                sum += A[i*4+k]*B[k*4+j];
            
            C[i*4 + j] = sum;
        
    


void m4x4T(const float *A, const float *B, float *C) 
    for(int i=0; i<4; i++) 
        for(int j=0; j<4; j++) 
            float sum = 0.0f;
            for(int k=0; k<4; k++) 
                sum += A[i*4+k]*B[j*4+k];
            
            C[i*4 + j] = sum;
        
    


void m4x4_vec(const float *A, const float *B, float *C) 
    __m128 Brow[4], Mrow[4];
    for(int i=0; i<4; i++) 
        Brow[i] = _mm_load_ps(&B[4*i]);
    

    for(int i=0; i<4; i++) 
        Mrow[i] = _mm_set1_ps(0.0f);
        for(int j=0; j<4; j++) 
            __m128 a = _mm_set1_ps(A[4*i +j]);
            Mrow[i] = _mm_add_ps(Mrow[i], _mm_mul_ps(a, Brow[j]));
        
    
    for(int i=0; i<4; i++) 
        _mm_store_ps(&C[4*i], Mrow[i]);
    


void m4x4T_vec(const float *A, const float *B, float *C) 
    __m128 Arow[4], Brow[4], Mrow[4];
    for(int i=0; i<4; i++) 
        Arow[i] = _mm_load_ps(&A[4*i]);
        Brow[i] = _mm_load_ps(&B[4*i]);
    

    for(int i=0; i<4; i++) 
        __m128 prod[4];
        for(int j=0; j<4; j++) 
            prod[j] =  _mm_mul_ps(Arow[i], Brow[j]);
        
        Mrow[i] = _mm_hadd_ps(_mm_hadd_ps(prod[0], prod[1]), _mm_hadd_ps(prod[2], prod[3]));    
    
    for(int i=0; i<4; i++) 
        _mm_store_ps(&C[4*i], Mrow[i]);
    



float compare_4x4(const float* A, const float*B) 
    float diff = 0.0f;
    for(int i=0; i<4; i++) 
        for(int j=0; j<4; j++) 
            diff += A[i*4 +j] - B[i*4+j];
            printf("A %f, B %f\n", A[i*4 +j], B[i*4 +j]);
        
    
    return diff;    


int main() 
    float *A = (float*)_mm_malloc(sizeof(float)*16,16);
    float *B = (float*)_mm_malloc(sizeof(float)*16,16);
    float *C1 = (float*)_mm_malloc(sizeof(float)*16,16);
    float *C2 = (float*)_mm_malloc(sizeof(float)*16,16);

    for(int i=0; i<4; i++) 
        for(int j=0; j<4; j++) 
            A[i*4 +j] = i*4+j;
            B[i*4 +j] = i*4+j;
            C1[i*4 +j] = 0.0f;
            C2[i*4 +j] = 0.0f;
        
    
    m4x4T(A, B, C1);
    m4x4T_vec(A, B, C2);
    printf("compare %f\n", compare_4x4(C1,C2));


编辑:

这是执行 C = (AB)^T 的标量和 SSE 函数。它们应该和它们的 AB 版本一样快。

void m4x4TT(const float *A, const float *B, float *C) 
    for(int i=0; i<4; i++) 
        for(int j=0; j<4; j++) 
            float sum = 0.0f;
            for(int k=0; k<4; k++) 
                sum += A[j*4+k]*B[k*4+i];
            
            C[i*4 + j] = sum;
        
    


void m4x4TT_vec(const float *A, const float *B, float *C) 
    __m128 Arow[4], Crow[4];
    for(int i=0; i<4; i++) 
        Arow[i] = _mm_load_ps(&A[4*i]);
    

    for(int i=0; i<4; i++) 
        Crow[i] = _mm_set1_ps(0.0f);
        for(int j=0; j<4; j++) 
            __m128 a = _mm_set1_ps(B[4*i +j]);
            Crow[i] = _mm_add_ps(Crow[i], _mm_mul_ps(a, Arow[j]));
        
    

    for(int i=0; i<4; i++) 
        _mm_store_ps(&C[4*i], Crow[i]);
    

【讨论】:

我目前正在探索进行转置。我的方法是遍历 4x4 矩阵,拉入 4 行,使用_MM_TRANSPOSE4_PS 转置它们,然后将它们存储在转置位置,然后处理外围行,然后处理单个值(很像乘法算法我上面用过)。感谢您提供示例代码。 太棒了!请让我知道你发现了什么。我的猜测是,使用 _MM_TRANSPOSE4_PS(它会做一堆洗牌)和垂直添加会比只使用 hadd 慢,但我可能是错的。 顺便说一句,我没有尝试对上面的代码进行进一步优化,例如循环展开,如果有帮助,我不会感到惊讶。 首先进行转置确实可以将速度提高到大致相同的值!即使我将调用替换为对 mulMM 的调用(!),仍然会追踪一些奇怪的东西,其中转置乘以转置需要更长的时间,但这是肯定的进步。谢谢。 我想我误解了你的意思。你写了 AB^T。所以我以为你的意思是 B 的转置。但你真的是说 (AB)^T 吗?在这种情况下,我认为您可以使用垂直添加而不进行任何转置,因此它应该与 A*B 一样快。【参考方案2】:

一些可能有帮助的建议:

不要使用未对齐的内存(那些 _mm_loadu* 很慢)。 您没有按顺序访问内存,这会杀死缓存。在实际访问该内存之前尝试转置矩阵,这将使 CPU 尽可能多地获取和使用缓存。这样就不需要下面的__m128 b0r = _mm_set_ps(b3[0], b2[0], b1[0], b0[0]); // and b1r, etc..。这个想法是按顺序获取整个 4 个组件。如果您需要在调用 SSE 代码之前重新组织内存,请这样做。 您正在内部循环中加载:_mm_load_ps1(a0+0) (对于 a1、a2 和 a3 相同),但对于内部循环中的所有迭代都是恒定的。您可以将这些值加载到外部并节省一些周期。密切关注您可以从以前的迭代中重复使用的内容。 个人资料。使用 Intel VTune 或类似的东西,它会告诉你瓶颈在哪里。

【讨论】:

一个迂腐的观点。 _mm_loadu_ps 内在函数只在未对齐的内存上很慢。如果你在对齐的内存上使用它,它基本上和 _mm_load_ps 一样快。 由于我是内存对齐的新手,在我进入并尝试更改现有代码之前,像这样对齐值如何影响内存消耗? 不确定你的意思,但是 SSE 代码要求数据是 16 字节对齐的,否则 CPU 需要花费更多的周期来修复解析内存地址。只是将矩阵分配为 16 个字节对齐的问题。 我误解了情况。长话短说,我认为这是留下空白,因为我为相同的花车写了错误的数字。

以上是关于如何使用 SSE 更有效地将 A*B^T 或 A^T*B^T(T 表示转置)矩阵相乘?的主要内容,如果未能解决你的问题,请参考以下文章

如何更有效地将词汇存储在数组中?

如何有效地将具有一定周期性的列表拆分为多个列表?

有效地将聚合列激发到 Set

SSE:服务器发送事件,使用长链接进行通讯

有效地将切片插入另一个切片

SSE 比较内在 - 如何从比较中获得 1 或 0?