使用带有 numpy 矩阵的 Strassen 算法的输出矩阵不正确

Posted

技术标签:

【中文标题】使用带有 numpy 矩阵的 Strassen 算法的输出矩阵不正确【英文标题】:Incorrect output matrix using Strassen's algorithm with numpy matrices 【发布时间】:2017-11-05 20:14:04 【问题描述】:

我正在尝试使用 Python 3 和 numpy 矩阵来实现 CLRS 中描述的 Strassen 矩阵乘法算法。

问题是输出矩阵 C 作为零矩阵返回,而不是正确的乘积。我不确定为什么我的实现不起作用,但怀疑这与每次递归调用创建 C 矩阵有关。对于我做错了什么以及如何解决它,我将不胜感激。

谢谢!

import numpy as np

def strassen(A,B):
    n = A.shape[0]
    C = np.zeros((n*n), dtype=np.int).reshape(n,n)
    if n == 1:
        C[0][0] = A[0][0] * B[0][0]

    else:
        k = int(n/2) 

        A11,A21,A12,A22 = A[:k,:k], A[k:, :k], A[:k, k:], A[k:, k:]
        B11,B21,B12,B22 = B[:k,:k], B[k:, :k], B[:k, k:], B[k:, k:]
        C11,C21,C12,C22 = C[:k,:k], C[k:, :k], C[:k, k:], C[k:, k:]

        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10= B11 + B12

        P1 = strassen(A11, S1)
        P2 = strassen(S2, B22)
        P3 = strassen(S3, B11)
        P4 = strassen(A22, S4)
        P5 = strassen(S5, S6)
        P6 = strassen(S7, S8)
        P7 = strassen(S9, S10)

        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7


    return C

【问题讨论】:

好吧,在 C = np.zeros((n*n), dtype=np.int).reshape(n,n) 行中创建它之后,您永远不会修改 martix C C的切片不引用原件吗?我的印象是 numpy 切片是视图。如何修改初始矩阵? 删除 C11,C21,C12,C22 = C[:k,:k] ...,将 C11 = P5 + P4 - P2 + P6 更改为 C[:k,:k] = P5 + P4 - P2 + P6,并对以下行进行类似操作 C11 只是一个名字。第一次使用它来命名C 的切片,第二次使用它来命名表达式的结果。就像shell,1)ln A C 2)ln -f B C——你认为文件A现在有文件B的内容吗? Python 没有C 意义上的变量,它有在计算表达式时实例化的对象。你可以给它一个名字,很多名字,没有名字。如果对象是 mutable(google 'python mutable immutable objects'),您可以使用索引 (a[3]=0) 切片 (P[:k,:k]=...) 或点 (a=MyClass(); a.xy=(2,4)) 对其进行修改。很多时候,您评估一个表达式的副作用并且没有命名创建的对象,例如,plt.plot(x,y) 返回一个您通常丢弃的复杂数据结构,因为您想要的只是屏幕上显示的图形。 【参考方案1】:

好的,我通过简单地用新值更新切片 C[:k,:k] 而不是创建新变量 C11、C12 ..ect 来实现它。 因为这样做会创建一个新矩阵,而不是对原始矩阵 C 的引用。

【讨论】:

以上是关于使用带有 numpy 矩阵的 Strassen 算法的输出矩阵不正确的主要内容,如果未能解决你的问题,请参考以下文章

实施 Strassen 矩阵乘法算法的问题

矩阵乘法的Strassen算法及时间复杂度

Strassen算法

Algorithms - Strassen's algorithm for matrix multiplication 矩阵乘法 Strassen 算法

Algorithms - Strassen's algorithm for matrix multiplication 矩阵乘法 Strassen 算法

矩阵乘法 strassen