Strassen算法不是最快的吗?

Posted

技术标签:

【中文标题】Strassen算法不是最快的吗?【英文标题】:Strassen algorithm not the fastest? 【发布时间】:2011-06-06 13:52:15 【问题描述】:

我从某个地方复制了 strassen 的算法,然后执行了它。这是输出

n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms

其中strassen1 是动态方法,strassen2 用于缓存,classical 是旧矩阵乘法。这意味着我们古老而简单的经典是最好的。这是真的还是我在某个地方错了?这是Java代码。

import java.util.Random;

class TestIntMatrixMultiplication 

    public static void main (String...args) throws Exception 
        final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
        final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
        final Random random = new Random(seed);

        int[][] a, b, c;

        a = new int[n][n];
        b = new int[n][n];
        c = new int[n][n];

        for(int i=0; i<n; i++) 
            for(int j=0; j<n; j++) 
                a[i][j] = random.nextInt(100);
                b[i][j] = random.nextInt(100);
            
        



        System.out.println("n = " + n);

        if (a.length < 64) 
            System.out.println("A");
            dumpMatrix(a);
            System.out.println("B");
            dumpMatrix(b);
            System.out.println("classic");
            Classical.mult(c, a, b);
            dumpMatrix(c);
            System.out.println("strassen");
            strassen2.mult(c, a, b);
            dumpMatrix(c);

            return;
        

        for (int i = 0; i <3; ++i) 
            timeMultiplies1(a, b, c);
            if (n <= 256)
                timeMultiplies2( a, b, c);
            timeMultiplies3( a, b, c);
        
    

    static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) 
        final long start = System.currentTimeMillis();
        Classical.mult(c, a, b);
        final long finish = System.currentTimeMillis();
        System.out.println("classical took " + (finish - start) + "ms");
    
    static void timeMultiplies2(int[][] a, int[][] b, int[][] c) 
        final long start = System.currentTimeMillis();
        strassen1.mult(c, a, b);
        final long finish = System.currentTimeMillis();
        System.out.println("strassen 1 took " + (finish - start) + "ms");
    
    static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) 
        final long start = System.currentTimeMillis();
        strassen2.mult(c, a, b);
        final long finish = System.currentTimeMillis();
        System.out.println("strassen2 took " + (finish - start) + "ms");
    

    static void dumpMatrix (int[][] m) 
        for (int[] row : m) 
            System.out.print("[\t");
            for (int val : row) 
                System.out.print(val);
                System.out.print('\t');
            
            System.out.println(']');
        
    


class strassen1

    public String getName () 
        return "Strassen(dynamic)";
    

    public static int[][] mult (int[][] c, int[][] a, int[][] b) 
        return strassenMatrixMultiplication(a, b);
    

    public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) 
        int n = A.length;

        int [][] result = new int[n][n];

        if(n == 1) 
            result[0][0] = A[0][0] * B[0][0];
         else 
            int [][] A11 = new int[n/2][n/2];
            int [][] A12 = new int[n/2][n/2];
            int [][] A21 = new int[n/2][n/2];
            int [][] A22 = new int[n/2][n/2];

            int [][] B11 = new int[n/2][n/2];
            int [][] B12 = new int[n/2][n/2];
            int [][] B21 = new int[n/2][n/2];
            int [][] B22 = new int[n/2][n/2];

            divideArray(A, A11, 0 , 0);
            divideArray(A, A12, 0 , n/2);
            divideArray(A, A21, n/2, 0);
            divideArray(A, A22, n/2, n/2);

            divideArray(B, B11, 0 , 0);
            divideArray(B, B12, 0 , n/2);
            divideArray(B, B21, n/2, 0);
            divideArray(B, B22, n/2, n/2);

            int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
            int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
            int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
            int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
            int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
            int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
            int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));

            int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
            int [][] C12 = addMatrices(P3, P5);
            int [][] C21 = addMatrices(P2, P4);
            int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);

            copySubArray(C11, result, 0 , 0);
            copySubArray(C12, result, 0 , n/2);
            copySubArray(C21, result, n/2, 0);
            copySubArray(C22, result, n/2, n/2);
        

        return result;
    

    public static int [][] addMatrices(int [][] A, int [][] B) 
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++)
        for(int j=0; j<n; j++)
        result[i][j] = A[i][j] + B[i][j];

        return result;
    

    public static int [][] subtractMatrices(int [][] A, int [][] B) 
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                result[i][j] = A[i][j] - B[i][j];

        return result;
    

    public static void divideArray(int[][] parent, int[][] child, int iB, int jB) 
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                child[i1][j1] = parent[i2][j2];
    

    public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) 
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                parent[i2][j2] = child[i1][j1];
    

class strassen2

    public String getName () 
        return "Strassen(cached)";
    

    static int [][] p1;
    static int [][] p2;
    static int [][] p3;
    static int [][] p4;
    static int [][] p5;
    static int [][] p6;
    static int [][] p7;
    static int [][] t0;
    static int [][] t1;

    public static int[][] mult (int[][] c, int[][] a, int[][] b) 
        final int n = c.length;

        if (p1 == null || p1.length < n) 
            p1 = new int[n/2][n-1];
            p2 = new int[n/2][n-1];
            p3 = new int[n/2][n-1];
            p4 = new int[n/2][n-1];
            p5 = new int[n/2][n-1];
            p6 = new int[n/2][n-1];
            p7 = new int[n/2][n-1];
            t0 = new int[n/2][n-1];
            t1 = new int[n/2][n-1];
        

        mult(c, a, b, 0, 0, n, 0);

        return c;
    

    public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) 
        if(n == 1) 
            c[i0][j0] = a[i0][j0] * b[i0][j0];
         else 
            final int nBy2 = n/2;

            final int i1 = i0 + nBy2;
            final int j1 = j0 + nBy2;

            // offset applied to 'p' j index so recursive calls don't overwrite data
            final int jp0 = offs;
            final int jp1 = nBy2 + offs;

            // P1 <- (A11 + A22)(B11 + B22)
            //  T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
            for (int i = 0; i < nBy2; ++i) 
                for (int j = 0; j < nBy2; ++j) 
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
                
            

            mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P2 <- (A21 + A22)B11
            //  T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
            for (int i = 0; i < nBy2; ++i) 
                for (int j = 0; j < nBy2; ++j) 
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0];
                    
            

            mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P3 <- A11(B12 - B22)
            //  T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
            for (int i = 0; i < nBy2; ++i) 
                for (int j = 0; j < nBy2; ++j) 
                    t0[i + i0][j + jp0] = a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
                
            

            mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P4 <- A22(B21 - B11)
            //  T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
            for (int i = 0; i < nBy2; ++i) 
                for (int j = 0; j < nBy2; ++j) 
                    t0[i + i0][j + jp0] = a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
                
            

            mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P5 <- (A11 + A12) B22
            //  T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
            for (int i = 0; i < nBy2; ++i) 
                for (int j = 0; j < nBy2; ++j) 
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j1];
                
            

            mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P6 <- (A21 - A11)(B11 - B12)
            //  T0 <- (A21 - A11), T1 <- (B11 - B12), P6 <- T0 * T1
            for (int i = 0; i < nBy2; ++i) 
                for (int j = 0; j < nBy2; ++j) 
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] - b[i + i0][j + j1];
                
            

            mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P7 <- (A12 - A22)(B21 + B22)
            //  T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
            for (int i = 0; i < nBy2; ++i) 
                for (int j = 0; j < nBy2; ++j) 
                    t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
                
            

            mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // combine
            for (int i = 0; i < nBy2; ++i) 
                for (int j = 0; j < nBy2; ++j) 
                    // C11 = P1 + P4 - P5 + P7;
                    c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
                    // C12 = P3 + P5;
                    c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
                    // C21 = P2 + P4;
                    c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
                    // C22 = P1 + P3 - P2 + P6;
                    c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
                
            
        
    

    void dumpInternal () 
        System.out.println("P1");
        TestIntMatrixMultiplication.dumpMatrix(p1);
        System.out.println("P2");
        TestIntMatrixMultiplication.dumpMatrix(p2);
        System.out.println("P3");
        TestIntMatrixMultiplication.dumpMatrix(p3);
        System.out.println("P4");
        TestIntMatrixMultiplication.dumpMatrix(p4);
        System.out.println("P5");
        TestIntMatrixMultiplication.dumpMatrix(p5);
        System.out.println("P6");
        TestIntMatrixMultiplication.dumpMatrix(p6);
        System.out.println("P7");
        TestIntMatrixMultiplication.dumpMatrix(p7);
        System.out.println("T0");
        TestIntMatrixMultiplication.dumpMatrix(t0);
        System.out.println("T1");
        TestIntMatrixMultiplication.dumpMatrix(t1);
    



class Classical
    public String getName () 
        return "classic";
    

    public static int[][] mult (int[][] c, int[][] a, int[][] b) 
        int n = a.length;

        for(int i=0; i<n; i++) 
            final int[] a_i = a[i];
            final int[] c_i = c[i];

            for(int j=0; j<n; j++) 
                int sum = 0;

                for(int k=0; k<n; k++) 
                    sum += a_i[k] * b[k][j];
                

                c_i[j] = sum;
            
        

        return c;
    

【问题讨论】:

你在某处错了。至少假设“n = 256”涵盖了所有 n,假设“来自某处”是最好的实现,并且可能在很多其他方面。 首先,这不是在 java 中对算法进行基准测试的正确方法。首先在这里阅读this 并检查这是否不会改变结果。 【参考方案1】:

我看到的问题:

1)您的 Strassen 乘法一直在动态分配内存。这会影响性能。

2) 您的 Strassen 乘法应该切换到小尺寸的传统乘法,而不是一直递归下去(尽管这种优化会使您的测试无效)。

3)您的矩阵大小可能太小而看不出差异。

您应该对几种不同的尺寸进行比较。大概是256、512、1024、2048、4096、8192……然后画出时代,看趋势。如果矩阵大小都是 2 的幂,您可能会希望矩阵大小为对数。

Strassen 仅对大 N 更快。多大将在很大程度上取决于实现。你为经典所做的只是一个基本的实现,在现代机器上也不是最优的。

【讨论】:

我大体上同意。但是 2) 是任何分而治之算法最重要的方面之一(例如,将简单的合并排序与对其叶子使用快速排序/选择排序的合并排序进行比较)所以我不明白为什么这会使测试无效?跨度> 如果测试是为了显示 Strassen 更快,那么您希望始终使用它。只有在您知道交叉点之后,才应根据大小进行切换。不过,对于大 N,它应该仍然更快,所以从这个意义上说,它不会使任何东西失效。 当然,切换不会使结果无效。如果 A 和 B 的混合比纯 A 更快,那么 B 是一种改进。考虑三个元素的合并排序与选择排序。当然选择排序会获胜:两者都进行 3 次比较,但选择排序更简单。这并不意味着合并排序在大型数组上并不好。【参考方案2】:

抛开实施问题不谈,我认为您误解了算法的性能。就像 phkahler 所说,您对算法性能的期望有点偏离。分治算法适用于大输入,因为它们递归地将问题分解为可以更快解决的子问题。

但是,与此拆分操作相关的开销可能会导致算法对于小型甚至中型输入的运行速度(有时会慢得多)。通常,像 Strassen 这样的算法的理论分析将包括所谓的“断点”计算。这是输入大小,其中拆分的开销比简单的技术更可取。

您的代码需要包含对在断点处切换到简单技术的输入大小的检查。

【讨论】:

【参考方案3】:

写下 Strassen 算法对 2 x 2 矩阵的作用。计算操作。这个数字绝对是荒谬的。将 Strassen 方法用于 2x2 矩阵是愚蠢的。对于 3 x 3 或 4 x 4 的矩阵也是如此,而且可能还有很大的提升空间。

【讨论】:

以上是关于Strassen算法不是最快的吗?的主要内容,如果未能解决你的问题,请参考以下文章

矩阵乘法 - 分而治与斯特拉森,分而治之更快?

检查两个文件是不是相等的最快哈希算法是啥?

检测图中是不是存在负循环的最快算法

判断一个数组是不是至少有一个重复项的最快算法

什么是疯狂大整数除法的最快算法?

最快的 Perlin-Like 3D 噪声算法?