Strassen优化矩阵乘法(复杂度O(n^lg7))

Posted Saurus

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Strassen优化矩阵乘法(复杂度O(n^lg7))相关的知识,希望对你有一定的参考价值。

按照算法导论写的

还没有测试复杂度到底怎么样

不过这个真的很卡内存,挖个坑,以后写空间优化

还有Matthew Anderson, Siddharth Barman写了一个关于矩阵乘法的论文

《The Coppersmith-Winograd Matrix Multiplication Algorithm》

提出了矩阵乘法的O(n^2.37)算法,有时间再膜吧orz

#include <iostream>
#include <cstring>
#include <cstdio>
#include <iomanip>
using namespace std;
const int maxn = 30;
struct Matrix
{
    double v[maxn][maxn];
    int n, m;
    Matrix() { memset(v, 0, sizeof(v));}
    Matrix operator +(const Matrix& B)
    {
        Matrix C; C.n = n; C.m = m;
        for(int i = 0; i < n; i++)
            for(int j = 0; j < n; j++)
                C.v[i][j] = v[i][j] + B.v[i][j];
        return C;
    }
    Matrix operator -(const Matrix& B)
    {
        Matrix C; C.n = n; C.m = m;
        for(int i = 0; i < n; i++)
            for(int j = 0; j < n; j++)
                C.v[i][j] = v[i][j] - B.v[i][j];
        return C;
    }
    Matrix operator *(const Matrix &B)
    {
        Matrix C; C.n = n; C.m = B.m;
        for(int i = 0; i < n; i++)
            for(int j = 0; j < m; j++)
            {
                if(v[i][j] == 0) continue; //矩阵常数优化
                for(int k = 0; k < m; k++)
                    C.v[i][k] += v[i][j]*B.v[j][k];
            }
        return C;
    }
    void prepare()  //将矩阵转换成2^k的形式,便于分治
    {
        int _n = 1;
        while(_n < n) _n <<= 1;
        while(_n < m) _n <<= 1;
        for(int i = 0; i < n; i++)
            for(int j = m; j < _n; j++)
                v[i][j] = 0;
        for(int i = n; i < _n; i++)
            for(int j = 0; j < _n; j++)
                v[i][j] = 0;
        n = m = _n;
    }
    void read()
    {
        cin>>n>>m;
        for(int i = 0; i < n; i++)
            for(int j = 0; j < m; j++)
                cin>>v[i][j];
    }
    Matrix get(int i1, int j1, int i2, int j2)
    {
        Matrix C; C.n = i2-i1+1; C.m = j2-j1+1;
        for(int i = i1-1; i < i2; i++)
            for(int j = j1-1; j < j2; j++)
                C.v[i-i1+1][j-j1+1] = v[i][j];
        return C;
    }
    void give(Matrix &B, int i1, int j1, int i2, int j2)
    {
        for(int i = i1-1; i < i2; i++)
            for(int j = j1-1; j < j2; j++)
                v[i][j] = B.v[i-i1+1][j-j1+1];
    }
    void print()
    {
        for(int i = 0; i < n; i++)
        {
            for(int j = 0; j < m; j++)
                cout<<setw(6)<<v[i][j];
            cout<<endl;
        }

    }
}A, B;

Matrix Strassen(Matrix &X, Matrix &Y)  //分治+利用多次矩阵相加代替矩阵相乘优化,复杂度O(n^2.81)
{
    if(X.n == 1) return X*Y;
    int n = X.n;
    Matrix A[2][2], B[2][2], S[10], P[7];
    A[0][0] = X.get(1, 1, n/2, n/2);   A[0][1] = X.get(1, n/2+1, n/2, n);
    A[1][0] = X.get(n/2+1, 1, n, n/2); A[1][1] = X.get(n/2+1, n/2+1, n, n);
    B[0][0] = Y.get(1, 1, n/2, n/2);   B[0][1] = Y.get(1, n/2+1, n/2, n);
    B[1][0] = Y.get(n/2+1, 1, n, n/2); B[1][1] = Y.get(n/2+1, n/2+1, n, n);
    //for(int i = 0; i < 2; i++) { for(int j = 0; j < 2; j++) A[i][j].print(); cout<<endl; }
    //for(int i = 0; i < 2; i++) { for(int j = 0; j < 2; j++) B[i][j].print(); cout<<endl; }
    S[0] = B[0][1] - B[1][1]; S[1] = A[0][0] + A[0][1];
    S[2] = A[1][0] + A[1][1]; S[3] = B[1][0] - B[0][0]; S[4] = A[0][0] + A[1][1];
    S[5] = B[0][0] + B[1][1]; S[6] = A[0][1] - A[1][1];
    S[7] = B[1][0] + B[1][1]; S[8] = A[0][0] - A[1][0]; S[9] = B[0][0] + B[0][1];
    P[0] = Strassen(A[0][0], S[0]); P[1] = Strassen(S[1], B[1][1]);
    P[2] = Strassen(S[2], B[0][0]); P[3] = Strassen(A[1][1], S[3]);
    P[4] = Strassen(S[4], S[5]);    P[5] = Strassen(S[6], S[7]);    P[6] = Strassen(S[8], S[9]);
    //for(int i = 0; i < 7; i++) P[i].print(); cout<<endl;
    B[0][0] = P[4] + P[3] - P[1] + P[5];    B[0][1] = P[0] + P[1];
    B[1][0] = P[2] + P[3];                  B[1][1] = P[4] + P[0] - P[2] - P[6];
    //for(int i = 0; i < 2; i++) { for(int j = 0; j < 2; j++) B[i][j].print();  }
    X.give(B[0][0], 1, 1, n/2, n/2);    X.give(B[0][1], 1, n/2+1, n/2, n);
    X.give(B[1][0], n/2+1, 1, n, n/2);  X.give(B[1][1], n/2+1, n/2+1, n, n);
    return X;
}



int main()
{
    Matrix C;
    A.read(); B.read();
    int n = A.n, m = B.m;
    A.prepare(); B.prepare();
    C = Strassen(A, B); C.n = n; C.m = m; C.print();
}

 

以上是关于Strassen优化矩阵乘法(复杂度O(n^lg7))的主要内容,如果未能解决你的问题,请参考以下文章

矩阵乘法的Strassen算法及时间复杂度

Strassen algorithm(O(n^lg7))

关于strassen矩阵乘法的矩阵大小不是2^k的形式时,时间复杂度是否还是比朴素算法好的看法

Strassen算法及其python实现

矩阵乘法 strassen

Strassen矩阵乘法之思考