Strassen优化矩阵乘法(复杂度O(n^lg7))
Posted Saurus
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Strassen优化矩阵乘法(复杂度O(n^lg7))相关的知识,希望对你有一定的参考价值。
按照算法导论写的
还没有测试复杂度到底怎么样
不过这个真的很卡内存,挖个坑,以后写空间优化
还有Matthew Anderson, Siddharth Barman写了一个关于矩阵乘法的论文
《The Coppersmith-Winograd Matrix Multiplication Algorithm》
提出了矩阵乘法的O(n^2.37)算法,有时间再膜吧orz
#include <iostream> #include <cstring> #include <cstdio> #include <iomanip> using namespace std; const int maxn = 30; struct Matrix { double v[maxn][maxn]; int n, m; Matrix() { memset(v, 0, sizeof(v));} Matrix operator +(const Matrix& B) { Matrix C; C.n = n; C.m = m; for(int i = 0; i < n; i++) for(int j = 0; j < n; j++) C.v[i][j] = v[i][j] + B.v[i][j]; return C; } Matrix operator -(const Matrix& B) { Matrix C; C.n = n; C.m = m; for(int i = 0; i < n; i++) for(int j = 0; j < n; j++) C.v[i][j] = v[i][j] - B.v[i][j]; return C; } Matrix operator *(const Matrix &B) { Matrix C; C.n = n; C.m = B.m; for(int i = 0; i < n; i++) for(int j = 0; j < m; j++) { if(v[i][j] == 0) continue; //矩阵常数优化 for(int k = 0; k < m; k++) C.v[i][k] += v[i][j]*B.v[j][k]; } return C; } void prepare() //将矩阵转换成2^k的形式,便于分治 { int _n = 1; while(_n < n) _n <<= 1; while(_n < m) _n <<= 1; for(int i = 0; i < n; i++) for(int j = m; j < _n; j++) v[i][j] = 0; for(int i = n; i < _n; i++) for(int j = 0; j < _n; j++) v[i][j] = 0; n = m = _n; } void read() { cin>>n>>m; for(int i = 0; i < n; i++) for(int j = 0; j < m; j++) cin>>v[i][j]; } Matrix get(int i1, int j1, int i2, int j2) { Matrix C; C.n = i2-i1+1; C.m = j2-j1+1; for(int i = i1-1; i < i2; i++) for(int j = j1-1; j < j2; j++) C.v[i-i1+1][j-j1+1] = v[i][j]; return C; } void give(Matrix &B, int i1, int j1, int i2, int j2) { for(int i = i1-1; i < i2; i++) for(int j = j1-1; j < j2; j++) v[i][j] = B.v[i-i1+1][j-j1+1]; } void print() { for(int i = 0; i < n; i++) { for(int j = 0; j < m; j++) cout<<setw(6)<<v[i][j]; cout<<endl; } } }A, B; Matrix Strassen(Matrix &X, Matrix &Y) //分治+利用多次矩阵相加代替矩阵相乘优化,复杂度O(n^2.81) { if(X.n == 1) return X*Y; int n = X.n; Matrix A[2][2], B[2][2], S[10], P[7]; A[0][0] = X.get(1, 1, n/2, n/2); A[0][1] = X.get(1, n/2+1, n/2, n); A[1][0] = X.get(n/2+1, 1, n, n/2); A[1][1] = X.get(n/2+1, n/2+1, n, n); B[0][0] = Y.get(1, 1, n/2, n/2); B[0][1] = Y.get(1, n/2+1, n/2, n); B[1][0] = Y.get(n/2+1, 1, n, n/2); B[1][1] = Y.get(n/2+1, n/2+1, n, n); //for(int i = 0; i < 2; i++) { for(int j = 0; j < 2; j++) A[i][j].print(); cout<<endl; } //for(int i = 0; i < 2; i++) { for(int j = 0; j < 2; j++) B[i][j].print(); cout<<endl; } S[0] = B[0][1] - B[1][1]; S[1] = A[0][0] + A[0][1]; S[2] = A[1][0] + A[1][1]; S[3] = B[1][0] - B[0][0]; S[4] = A[0][0] + A[1][1]; S[5] = B[0][0] + B[1][1]; S[6] = A[0][1] - A[1][1]; S[7] = B[1][0] + B[1][1]; S[8] = A[0][0] - A[1][0]; S[9] = B[0][0] + B[0][1]; P[0] = Strassen(A[0][0], S[0]); P[1] = Strassen(S[1], B[1][1]); P[2] = Strassen(S[2], B[0][0]); P[3] = Strassen(A[1][1], S[3]); P[4] = Strassen(S[4], S[5]); P[5] = Strassen(S[6], S[7]); P[6] = Strassen(S[8], S[9]); //for(int i = 0; i < 7; i++) P[i].print(); cout<<endl; B[0][0] = P[4] + P[3] - P[1] + P[5]; B[0][1] = P[0] + P[1]; B[1][0] = P[2] + P[3]; B[1][1] = P[4] + P[0] - P[2] - P[6]; //for(int i = 0; i < 2; i++) { for(int j = 0; j < 2; j++) B[i][j].print(); } X.give(B[0][0], 1, 1, n/2, n/2); X.give(B[0][1], 1, n/2+1, n/2, n); X.give(B[1][0], n/2+1, 1, n, n/2); X.give(B[1][1], n/2+1, n/2+1, n, n); return X; } int main() { Matrix C; A.read(); B.read(); int n = A.n, m = B.m; A.prepare(); B.prepare(); C = Strassen(A, B); C.n = n; C.m = m; C.print(); }
以上是关于Strassen优化矩阵乘法(复杂度O(n^lg7))的主要内容,如果未能解决你的问题,请参考以下文章