Strassen's algorithm bug in a Python implementation
【Posted】: 2018-09-02 13:20:20 【Question】: I get different outputs for the same matrix multiplication from my Strassen's algorithm implementation and from a naive nested for-loop implementation in Python 3.
Code:
def new_matrix(r, c):
    """Create a new matrix filled with zeros."""
    matrix = [[0 for row in range(r)] for col in range(c)]
    return matrix

def direct_multiply(x, y):
    if len(x[0]) != len(y):
        return "Multiplication is not possible!"
    else:
        p_matrix = new_matrix(len(x), len(y[0]))
        for i in range(len(x)):
            for j in range(len(y[0])):
                for k in range(len(y)):
                    p_matrix[i][j] += x[i][k] * y[k][i]
        return p_matrix

def split(matrix):
    """Split matrix into quarters."""
    a = b = c = d = matrix
    while len(a) > len(matrix)/2:
        a = a[:len(a)//2]
        b = b[:len(b)//2]
        c = c[len(c)//2:]
        d = d[len(d)//2:]
    while len(a[0]) > len(matrix[0])//2:
        for i in range(len(a[0])//2):
            a[i] = a[i][:len(a[i])//2]
            b[i] = b[i][len(b[i])//2:]
            c[i] = c[i][:len(c[i])//2]
            d[i] = d[i][len(d[i])//2:]
    return a, b, c, d

def add_matrix(a, b):
    if type(a) == int:
        d = a + b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] + b[i][j])
            d.append(c)
    return d

def subtract_matrix(a, b):
    if type(a) == int:
        d = a - b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] - b[i][j])
            d.append(c)
    return d

def strassen(x, y, n):
    # base case: 1x1 matrix
    if n == 1:
        z = [[0]]
        z[0][0] = x[0][0] * y[0][0]
        return z
    else:
        # split matrices into quarters
        a, b, c, d = split(x)
        e, f, g, h = split(y)
        # p1 = a*(f-h)
        p1 = strassen(a, subtract_matrix(f, h), n/2)
        # p2 = (a+b)*h
        p2 = strassen(add_matrix(a, b), h, n/2)
        # p3 = (c+d)*e
        p3 = strassen(add_matrix(c, d), e, n/2)
        # p4 = d*(g-e)
        p4 = strassen(d, subtract_matrix(g, e), n/2)
        # p5 = (a+d)*(e+h)
        p5 = strassen(add_matrix(a, d), add_matrix(e, h), n/2)
        # p6 = (b-d)*(g+h)
        p6 = strassen(subtract_matrix(b, d), add_matrix(g, h), n/2)
        # p7 = (a-c)*(e+f)
        p7 = strassen(subtract_matrix(a, c), add_matrix(e, f), n/2)
        z11 = add_matrix(subtract_matrix(add_matrix(p5, p4), p2), p6)
        z12 = add_matrix(p1, p2)
        z21 = add_matrix(p3, p4)
        z22 = add_matrix(subtract_matrix(subtract_matrix(p5, p3), p7), p1)
        z = new_matrix(len(z11)*2, len(z11)*2)
        for i in range(len(z11)):
            for j in range(len(z11)):
                z[i][j] = z11[i][j]
                z[i][j+len(z11)] = z12[i][j]
                z[i+len(z11)][j] = z21[i][j]
                z[i+len(z11)][j+len(z11)] = z22[i][j]
        return z

a = [[11,11,11,11],[22,22,22,22],[33,33,33,33],[44,44,44,44]]
b = [[101,181,119,113],[22,22,22,22],[33,33,33,33],[44,44,44,44]]
print(f"a = {a}")
print(f"b = {b}")
print(f"Using Strassen's algorithm:\na*b = {strassen(a, b, 4)}")
print(f"Using naive algorithm:\na*b = {direct_multiply(a, b)}")
Output:
$ python3 strassen.py
a = [[11, 11, 11, 11], [22, 22, 22, 22], [33, 33, 33, 33], [44, 44, 44, 44]]
b = [[101, 181, 119, 113], [22, 22, 22, 22], [33, 33, 33, 33], [44, 44, 44, 44]]
Using Strassen's algorithm:
a*b = [[2200, 3080, 2398, 2332], [4400, 6160, 4796, 4664], [6600, 9240, 7194, 6996], [8800, 12320, 9592, 9328]]
Using naive algorithm:
a*b = [[2200, 2200, 2200, 2200], [6160, 6160, 6160, 6160], [7194, 7194, 7194, 7194], [9328, 9328, 9328, 9328]]
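Since the two printouts disagree, it helps to know which one to trust. As a minimal sketch (not part of the original post), the seven products p1 to p7 and the four recombinations used in strassen() can be checked on scalar 2x2 inputs against the textbook definition of the matrix product. The helper names `strassen_2x2` and `naive_2x2` are hypothetical and chosen only for this illustration.

```python
def strassen_2x2(x, y):
    """Multiply two 2x2 matrices using Strassen's seven products.

    Block names follow the question: x = [[a, b], [c, d]], y = [[e, f], [g, h]].
    """
    (a, b), (c, d) = x
    (e, f), (g, h) = y
    p1 = a * (f - h)
    p2 = (a + b) * h
    p3 = (c + d) * e
    p4 = d * (g - e)
    p5 = (a + d) * (e + h)
    p6 = (b - d) * (g + h)
    p7 = (a - c) * (e + f)
    # Same recombination as z11, z12, z21, z22 in strassen().
    return [[p5 + p4 - p2 + p6, p1 + p2],
            [p3 + p4, p5 + p1 - p3 - p7]]


def naive_2x2(x, y):
    """Textbook product: entry (i, j) is row i of x dotted with column j of y."""
    return [[sum(x[i][k] * y[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
print(strassen_2x2(x, y))  # [[19, 22], [43, 50]]
print(naive_2x2(x, y))     # [[19, 22], [43, 50]]
```

If the recombination agrees with the definition, the Strassen printout is the more plausible of the two, which points suspicion at direct_multiply.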
Can anyone help?
【Discussion】:
Have you stepped through the code with pdb (or some other Python debugger) to find where the results start to diverge from what Strassen's algorithm gives on paper?
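Short of stepping through with pdb, a cheap way to locate the divergence is to compare the two results entry by entry. A minimal sketch, assuming both results are lists of lists of the same shape; `first_mismatch` is a hypothetical helper, not part of the question's code, and the two matrices are the outputs printed in the question.

```python
def first_mismatch(p, q):
    """Return (i, j, p[i][j], q[i][j]) for the first entry where p and q differ.

    Returns None if every compared entry matches. Assumes p and q have the
    same shape; zip() silently stops at the shorter row or matrix otherwise.
    """
    for i, (row_p, row_q) in enumerate(zip(p, q)):
        for j, (u, v) in enumerate(zip(row_p, row_q)):
            if u != v:
                return i, j, u, v
    return None


# The two results printed in the question.
strassen_out = [[2200, 3080, 2398, 2332], [4400, 6160, 4796, 4664],
                [6600, 9240, 7194, 6996], [8800, 12320, 9592, 9328]]
naive_out = [[2200, 2200, 2200, 2200], [6160, 6160, 6160, 6160],
             [7194, 7194, 7194, 7194], [9328, 9328, 9328, 9328]]

print(first_mismatch(strassen_out, naive_out))  # (0, 1, 3080, 2200)
# Compare only the diagonal entries of both results.
print([strassen_out[i][i] == naive_out[i][i] for i in range(4)])  # [True, True, True, True]
```

The diagonals agree and every row of the naive result simply repeats its diagonal value, which suggests the naive loop reads column i of y where it should read column j.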
p_matrix[i][j] += x[i][k] * y[k][i] should be p_matrix[i][j] += x[i][k] * y[k][j]
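【Answer 1】:
In your direct_multiply function it should be
p_matrix[i][j] += x[i][k] * y[k][j]
That way you multiply the kth element of row i by the kth element of column j and accumulate the products, which is exactly how matrix multiplication is defined.
Output: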
a = [[11, 11, 11, 11], [22, 22, 22, 22], [33, 33, 33, 33], [44, 44, 44, 44]]
b = [[101, 181, 119, 113], [22, 22, 22, 22], [33, 33, 33, 33], [44, 44, 44, 44]]
Using Strassen's algorithm:
a*b = [[2200, 3080, 2398, 2332], [4400, 6160, 4796, 4664], [6600, 9240, 7194, 6996], [8800, 12320, 9592, 9328]]
Using naive algorithm:
a*b = [[2200, 3080, 2398, 2332], [4400, 6160, 4796, 4664], [6600, 9240, 7194, 6996], [8800, 12320, 9592, 9328]]
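For reference, here is a self-contained sketch of direct_multiply with the fix applied. It deviates from the question's code in two labeled ways: it raises ValueError instead of returning a string (a design choice, not something the answer requires), and it allocates the result as rows x cols directly, because the question's new_matrix(r, c) returns c rows of r zeros. That swap is harmless for the square 4x4 inputs here but would break non-square products.

```python
def direct_multiply(x, y):
    """Naive O(n^3) product of an m x n matrix x and an n x p matrix y."""
    if len(x[0]) != len(y):
        # Deviation from the question: raise instead of returning a string.
        raise ValueError("Multiplication is not possible!")
    rows, cols, inner = len(x), len(y[0]), len(y)
    # Result has `rows` lists, each holding `cols` zeros.
    p_matrix = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            for k in range(inner):
                # kth element of row i of x times kth element of column j of y
                p_matrix[i][j] += x[i][k] * y[k][j]
    return p_matrix


a = [[11, 11, 11, 11], [22, 22, 22, 22], [33, 33, 33, 33], [44, 44, 44, 44]]
b = [[101, 181, 119, 113], [22, 22, 22, 22], [33, 33, 33, 33], [44, 44, 44, 44]]
print(direct_multiply(a, b))
# [[2200, 3080, 2398, 2332], [4400, 6160, 4796, 4664],
#  [6600, 9240, 7194, 6996], [8800, 12320, 9592, 9328]]

# A non-square case (2x3 times 3x2):
print(direct_multiply([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]]))
# [[58, 64], [139, 154]]
```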