Strassen算法

Posted Zpfly_2008

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Strassen算法相关的知识,希望对你有一定的参考价值。

如题,该算法是来自德国的牛逼的数学家strassen搞出来的,因为把n*n矩阵之间的乘法复杂度降低到n^(lg7)(lg的底是2),一开始想当然地认为朴素的做法是n^3,哪里还能有复杂度更低的做法,但是牛逼的strassen先生简直刷新了我的线性代数观和算法观

思路:基本的思路网上有,此处不再赘述,下面说说怎么实现strassen算法(数据处理)

m*n的矩阵和n*p的矩阵相乘,得到m*p的矩阵,因为每次要二分,遇到奇数的做法就是在行尾(列尾)加全零行(列),因为加入全零行(列)是不会影响就算结果的,从而使得可以二分。我为了省事直接在输入两个矩阵后就直接把它们统一扩展成了2^y*2^y的矩阵,其中,y=2^「lg(Max{m,p})」  (「」是向上取整,输入法里没找到合适的符号....;另外,Max(m,n)和Max(n,p)是相等的,取哪个都一样)

代码:

# include <iostream.h>
# include "..\Sort\IO_tools.cpp"
void strassen(int** &C,int** &A,int** &B,int arow,int acol,int brow,int bcol,int size);
void stra_plus(int** &C,int** &A,int** &B,int crow,int ccol,int arow,int acol,int brow,int bcol,int size,int symbol);
void main(){
   int** A=NULL;int** B=NULL;int** C=NULL;
   int A_row,A_col,B_row,B_col,size;
   cout<<"size of A<row,col>:"<<endl;
   cin>>A_row>>A_col;
   size=input2A(A,A_row,A_col);//万能的传引用,绝对正确!
   cout<<"size of B<row,col>:"<<endl;
   cin>>B_row>>B_col;
   input2A(B,B_row,B_col);//万能的传引用,绝对正确!
   strassen(C,A,B,0,0,0,0,size);
   output2A(C,A_row,B_col);
   //stra_plus(C,A,B,0,0,0,0,0,0,size,1);
   //output2A(C,size,size);
}
void strassen(int** &C,int** &A,int** &B,int arow,int acol,int brow,int bcol,int size){
	C=(int**)new int* [size];
	for(int i=0;i<size;i++){
	   C[i]=new int[size];
	}
	/*
	对于size>1的要进一步拆分(其实strassen算法递归计算时要申请这么多内存,size不够大时反而降低了效率,
	故而size达到下限时可以采用朴素的矩阵乘法计算方法而不必继续调用strassen算法,此处出于偷懒就省点事儿把size下限设为1)
	*/
	if(size>1){
		//S(1-10)初始化
	   int** S1=NULL;int** S2=NULL;int** S3=NULL;int** S4=NULL;int** S5=NULL;int** S6=NULL;int** S7=NULL;int** S8=NULL;int** S9=NULL;int** S10=NULL;
	   stra_plus(S1,B,B,0,0,brow,bcol+size/2,brow+size/2,bcol+size/2,size/2,-1);
	   stra_plus(S2,A,A,0,0,arow,acol,arow,acol+size/2,size/2,1);
	   stra_plus(S3,A,A,0,0,arow+size/2,acol,arow+size/2,acol+size/2,size/2,1);
	   stra_plus(S4,B,B,0,0,brow+size/2,bcol,brow,bcol,size/2,-1);
	   stra_plus(S5,A,A,0,0,arow,acol,arow+size/2,acol+size/2,size/2,1);
	   stra_plus(S6,B,B,0,0,brow,bcol,brow+size/2,bcol+size/2,size/2,1);
	   stra_plus(S7,A,A,0,0,arow,acol+size/2,arow+size/2,acol+size/2,size/2,-1);
	   stra_plus(S8,B,B,0,0,brow+size/2,bcol,brow+size/2,bcol+size/2,size/2,1);
	   stra_plus(S9,A,A,0,0,arow,acol,arow+size/2,acol,size/2,-1);
	   stra_plus(S10,B,B,0,0,brow,bcol,brow,bcol+size/2,size/2,1);
       //P(1-7)初始化
	   int** P1=NULL;int** P2=NULL;int** P3=NULL;int** P4=NULL;int** P5=NULL;int** P6=NULL;int** P7=NULL;
	   strassen(P1,A,S1,arow,acol,0,0,size/2);
	   strassen(P2,S2,B,0,0,brow+size/2,bcol+size/2,size/2);
	   strassen(P3,S3,B,0,0,brow,bcol,size/2);
	   strassen(P4,A,S4,arow+size/2,acol+size/2,0,0,size/2);
	   strassen(P5,S5,S6,0,0,0,0,size/2);
	   strassen(P6,S7,S8,0,0,0,0,size/2);
	   strassen(P7,S9,S10,0,0,0,0,size/2);
	   //计算结果C(依次是C11,C12,C21,C22)
	   stra_plus(C,P4,P5,0,0,0,0,0,0,size/2,1);stra_plus(C,C,P2,0,0,0,0,0,0,size/2,-1);stra_plus(C,C,P6,0,0,0,0,0,0,size/2,1);
	   stra_plus(C,P1,P2,0,size/2,0,0,0,0,size/2,1);
	   stra_plus(C,P3,P4,size/2,0,0,0,0,0,size/2,1);
	   stra_plus(C,P5,P1,size/2,size/2,0,0,0,0,size/2,1);stra_plus(C,C,P3,size/2,size/2,size/2,size/2,0,0,size/2,-1);stra_plus(C,C,P7,size/2,size/2,size/2,size/2,0,0,size/2,-1);
	}
	/*到达下限*/
	else{
	    C[0][0]=A[arow][acol]*B[brow][bcol];
	}
}
//参与运算的是A,B,C的size*size的(子)矩阵,<arow,acol>是A的参与运算的子矩阵的左上角坐标,<brow,bcol>同理,C是保存结果的
void stra_plus(int** &C,int** &A,int** &B,int crow,int ccol,int arow,int acol,int brow,int bcol,int size,int symbol){
	if(C==NULL){
       C=(int**)new int* [size];
	   for(int i=0;i<size;i++){
	      C[i]=new int[size];
	   }
	}
    for(int i=0;i<size;i++){
		for(int j=0;j<size;j++){
			  C[i+crow][j+ccol]=A[i+arow][j+acol]+symbol*B[i+brow][j+bcol];
		}
	}
}

  

# include <iostream.h>
# include <stdlib.h>
# include <math.h>
int get_Upper_2Pow(int row,int col);
void inputA(int A[],int n){
	int i=n;
	cout<<"Input Array:";
	while(i--){
	   cin>>A[n-i-1];
	}
}
void outputA(int A[],int n){
   int i=n;
   cout<<"output Array:";
   while(i--){
      cout<<A[n-i-1]<<" ";   
   }
   cout<<endl;
}
int input2A(int** &A,int row,int col){
   int size=get_Upper_2Pow(row,col);
   A=(int**)new int* [size];
   for(int i=0;i<size;i++){
	   A[i]=new int[size];
   }
   cout<<"Input 2th Array:"<<endl;
   for(int r=0;r<size;r++){
	   for(int c=0;c<size;c++){
	       A[r][c]=0;
	   }
   }
   for(r=0;r<row;r++){
	   for(int c=0;c<col;c++){
	      cin>>A[r][c];
	   }
   }
   return size;
}
void output2A(int** A,int row,int col){
	cout<<"output:"<<endl;
	for(int r=0;r<row;r++){
		for(int c=0;c<col;c++){
		    cout<<A[r][c]<<"  ";
		}
		cout<<endl;
	}
}
int get_Upper_2Pow(int row,int col){
  for(int i=0;pow(2,i)<row||pow(2,i)<col;i++);
  return (int)pow(2,i);
}

  

以上是关于Strassen算法的主要内容,如果未能解决你的问题,请参考以下文章

击败 Strassen 算法的算法

Strassen算法

Strassen算法不是最快的吗?

实施 Strassen 矩阵乘法算法的问题

矩阵乘法的Strassen算法及时间复杂度

使用 Strassen 算法将 2 个数字与 n 位相乘的算法