如何加速这个用 Python 编写的程序?
Posted
技术标签:
【中文标题】如何加速这个用 Python 编写的程序?【英文标题】:How can I speed up this program written in Python? 【发布时间】:2021-11-08 16:04:04 【问题描述】:以下程序是Python中Marching Square问题的解决方案:
from typing import List
def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold: float) -> int:
    """Return the 4-bit case id of one cell for a given threshold.

    Each corner value that is greater than or equal to ``threshold`` sets
    one bit of the result: A -> 1, B -> 2, C -> 4, D -> 8.

    Args:
        Point_A_data: Data value at corner A.
        Point_B_data: Data value at corner B.
        Point_C_data: Data value at corner C.
        Point_D_data: Data value at corner D.
        threshold: Level the corner values are compared against.

    Returns:
        An integer in ``0..15``; 0 means no corner reached the threshold,
        15 means all four did.
    """
    caseId = 0
    if Point_A_data >= threshold:
        caseId |= 1
    if Point_B_data >= threshold:
        caseId |= 2
    if Point_C_data >= threshold:
        caseId |= 4
    if Point_D_data >= threshold:
        caseId |= 8
    return caseId
# For each 4-bit case id, the indices (into the ``segments`` tuple built in
# GetLines) of the segments to emit, in output order. Cases 0 and 15 emit
# nothing; cases 5 and 10 (two opposite corners set) emit two segments.
_SEGMENTS_BY_CASE = (
    (),      # 0
    (0,),    # 1
    (1,),    # 2
    (2,),    # 3
    (3,),    # 4
    (1, 5),  # 5
    (4,),    # 6
    (5,),    # 7
    (5,),    # 8
    (4,),    # 9
    (0, 3),  # 10
    (3,),    # 11
    (2,),    # 12
    (1,),    # 13
    (0,),    # 14
    (),      # 15
)


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    """Return the contour segments of one cell for one threshold.

    Args:
        Point_A, Point_B, Point_C, Point_D: ``[x, y]`` coordinates of the
            four cell corners.
        a, b, c, d: Data values at corners A, B, C and D.
        threshold: Contour level.

    Returns:
        A list of ``(pX, pY, qX, qY)`` tuples: empty for case ids 0 and 15,
        two segments for case ids 5 and 10, one segment otherwise. Segment
        endpoints are built from corner coordinates and coordinate
        midpoints; they are not interpolated from the data values.
    """
    # Same bit layout as GetCaseId (A=1, B=2, C=4, D=8), computed inline so
    # this per-cell, per-threshold function does not pay for an extra call.
    caseId = 0
    if a >= threshold:
        caseId |= 1
    if b >= threshold:
        caseId |= 2
    if c >= threshold:
        caseId |= 4
    if d >= threshold:
        caseId |= 8

    picks = _SEGMENTS_BY_CASE[caseId]
    if not picks:
        return []

    ax, ay = Point_A[0], Point_A[1]
    bx, by = Point_B[0], Point_B[1]
    cx, cy = Point_C[0], Point_C[1]
    dx, dy = Point_D[0], Point_D[1]

    mid_ab_x = (ax + bx) / 2
    mid_cd_x = (cx + dx) / 2
    mid_ad_y = (ay + dy) / 2
    mid_bc_y = (by + cy) / 2

    # NOTE(review): some endpoints take a coordinate from a neighbouring
    # corner (e.g. ``by`` next to the A-B x-midpoint). That is equivalent
    # only when A/B share y, C/D share y, A/D share x and B/C share x, as
    # is the case for the cells marching_square builds.
    segments = (
        (mid_ab_x, by, dx, mid_ad_y),  # 0: cases 1, 14, 10
        (mid_ab_x, ay, cx, mid_ad_y),  # 1: cases 2, 13, 5
        (ax, mid_ad_y, cx, mid_bc_y),  # 2: cases 3, 12
        (mid_cd_x, dy, bx, mid_bc_y),  # 3: cases 4, 11, 10
        (mid_ab_x, ay, mid_cd_x, cy),  # 4: cases 6, 9
        (mid_cd_x, cy, ax, mid_ad_y),  # 5: cases 7, 8, 5
    )
    return [segments[k] for k in picks]
def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    """Collect the contour segments of every grid cell for every threshold.

    Args:
        x_int_list: x coordinate of each grid column.
        y_int_list: y coordinate of each grid row.
        data_2d_list: Data values indexed as ``data_2d_list[row][col]``;
            must have ``len(y_int_list)`` rows, and its first row must have
            ``len(x_int_list)`` entries.
        threshold_list: Contour levels evaluated for every cell.

    Returns:
        A one-element list whose single item is the list of
        ``(pX, pY, qX, qY)`` tuples, ordered row by row, cell by cell, and
        threshold by threshold within a cell.

    Raises:
        AssertionError: If the dimensions of ``data_2d_list`` do not match
            the coordinate lists.
    """
    Height = len(y_int_list)  # rows
    Width = len(x_int_list)  # cols
    # Only look at the first row when there is one, so an empty grid yields
    # an empty result instead of an IndexError.
    if Height != len(data_2d_list) or (Height and Width != len(data_2d_list[0])):
        raise AssertionError(
            "data_2d_list must have len(y_int_list) rows and len(x_int_list) columns"
        )

    linesList = []
    for j in range(Height - 1):  # rows
        # Hoist everything that only depends on the row out of the column loop.
        row_j = data_2d_list[j]
        row_j1 = data_2d_list[j + 1]
        y_j = y_int_list[j]
        y_j1 = y_int_list[j + 1]
        for i in range(Width - 1):  # cols
            x_i = x_int_list[i]
            x_i1 = x_int_list[i + 1]
            point_A = [x_i, y_j1]
            point_B = [x_i1, y_j1]
            point_C = [x_i1, y_j]
            point_D = [x_i, y_j]
            point_a_data_float = row_j1[i]
            point_b_data_float = row_j1[i + 1]
            point_c_data_float = row_j[i + 1]
            point_d_data_float = row_j[i]
            for threshold_item in threshold_list:
                # extend() grows the accumulator in place; rebuilding it with
                # ``+`` for every cell copies the whole list each time.
                linesList.extend(
                    GetLines(point_A, point_B, point_C, point_D,
                             point_a_data_float, point_b_data_float,
                             point_c_data_float, point_d_data_float,
                             threshold_item)
                )
    return [linesList]
此源代码的问题是 - 生成输出需要很长时间。
即使用以下驱动程序:
import math

import drawSvg as draw_svg

# Driver: contour a sampled sin(x) * sin(y) surface and save it as an SVG.
# marching_square() is expected to be defined or imported before this runs.
N_int = 800
N2_float = N_int / 8
x_int_vector = list(range(N_int))
y_int_vector = list(range(N_int))

# Evaluate each sine once (N_int calls) instead of once per grid entry.
sine_values = [math.sin(k / N2_float) for k in range(N_int)]
# Despite its name this is an N_int x N_int grid:
# matrix_256x256[j][i] == sin(i / N2_float) * sin(j / N2_float).
matrix_256x256 = [[sin_i * sin_j for sin_i in sine_values] for sin_j in sine_values]

fill = "#2591a3"
drawing = draw_svg.Drawing(N_int, N_int, displayInline=False)
threshold_float_list = [0.2, 0.4, 0.6, 0.8]
collection = marching_square(x_int_vector, y_int_vector, matrix_256x256, threshold_float_list)
for line_set in collection:
    for line in line_set:
        # Each line is a (pX, pY, qX, qY) tuple.
        drawing.append(draw_svg.Line(line[0], line[1], line[2], line[3], stroke='red'))
drawing.saveSvg('example.svg')
代码在实际使用中变得非常缓慢。
如何加快代码速度?
注意:marching_square() 的签名不得更改。
【问题讨论】:
关于改进可运行代码性能的建议,更适合发布到 codereview.stackexchange.com(参见 codereview.stackexchange.com/tour)。
首先,您应该对代码做性能分析,找出它慢在哪里。
@AKX,慢在 GetLines() 和 GetCaseId()。
【参考方案1】:
通过以下两处改动获得了约 10 倍的加速:
- 去掉了作为最大瓶颈的列表拼接(改用 this trick to concatenate list of lists,最后一次性展平列表的列表)
- 对第二大瓶颈 GetCaseId 应用 numba
from typing import List
import numba
import functools
import operator
@numba.jit(nopython=True)
def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold: float) -> int:
    """Return the 4-bit case id of one cell for a given threshold.

    Compiled with numba in nopython mode. Each corner value that is greater
    than or equal to ``threshold`` sets one bit of the result:
    A -> 1, B -> 2, C -> 4, D -> 8.

    Args:
        Point_A_data: Data value at corner A.
        Point_B_data: Data value at corner B.
        Point_C_data: Data value at corner C.
        Point_D_data: Data value at corner D.
        threshold: Level the corner values are compared against.

    Returns:
        An integer in ``0..15``; 0 means no corner reached the threshold,
        15 means all four did.
    """
    caseId = 0
    if Point_A_data >= threshold:
        caseId |= 1
    if Point_B_data >= threshold:
        caseId |= 2
    if Point_C_data >= threshold:
        caseId |= 4
    if Point_D_data >= threshold:
        caseId |= 8
    return caseId
# For each 4-bit case id, the indices (into the ``segments`` tuple built in
# GetLines) of the segments to emit, in output order. Cases 0 and 15 emit
# nothing; cases 5 and 10 (two opposite corners set) emit two segments.
_SEGMENTS_BY_CASE = (
    (),      # 0
    (0,),    # 1
    (1,),    # 2
    (2,),    # 3
    (3,),    # 4
    (1, 5),  # 5
    (4,),    # 6
    (5,),    # 7
    (5,),    # 8
    (4,),    # 9
    (0, 3),  # 10
    (3,),    # 11
    (2,),    # 12
    (1,),    # 13
    (0,),    # 14
    (),      # 15
)


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    """Return the contour segments of one cell for one threshold.

    Args:
        Point_A, Point_B, Point_C, Point_D: ``[x, y]`` coordinates of the
            four cell corners.
        a, b, c, d: Data values at corners A, B, C and D.
        threshold: Contour level.

    Returns:
        ``None`` when the case id is 0 or 15 (no segment in this cell);
        otherwise a list of ``(pX, pY, qX, qY)`` tuples, with two segments
        for case ids 5 and 10 and one segment for every other case id.
        Segment endpoints are built from corner coordinates and coordinate
        midpoints; they are not interpolated from the data values.
    """
    caseId = GetCaseId(a, b, c, d, threshold)
    picks = _SEGMENTS_BY_CASE[caseId]
    if not picks:
        # Kept as None (not an empty list): callers test the result for
        # truthiness before using it.
        return None

    ax, ay = Point_A[0], Point_A[1]
    bx, by = Point_B[0], Point_B[1]
    cx, cy = Point_C[0], Point_C[1]
    dx, dy = Point_D[0], Point_D[1]

    mid_ab_x = (ax + bx) / 2
    mid_cd_x = (cx + dx) / 2
    mid_ad_y = (ay + dy) / 2
    mid_bc_y = (by + cy) / 2

    # NOTE(review): some endpoints take a coordinate from a neighbouring
    # corner (e.g. ``by`` next to the A-B x-midpoint). That is equivalent
    # only when A/B share y, C/D share y, A/D share x and B/C share x, as
    # is the case for the cells marching_square builds.
    segments = (
        (mid_ab_x, by, dx, mid_ad_y),  # 0: cases 1, 14, 10
        (mid_ab_x, ay, cx, mid_ad_y),  # 1: cases 2, 13, 5
        (ax, mid_ad_y, cx, mid_bc_y),  # 2: cases 3, 12
        (mid_cd_x, dy, bx, mid_bc_y),  # 3: cases 4, 11, 10
        (mid_ab_x, ay, mid_cd_x, cy),  # 4: cases 6, 9
        (mid_cd_x, cy, ax, mid_ad_y),  # 5: cases 7, 8, 5
    )
    return [segments[k] for k in picks]
def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    """Collect the contour segments of every grid cell for every threshold.

    Args:
        x_int_list: x coordinate of each grid column.
        y_int_list: y coordinate of each grid row.
        data_2d_list: Data values indexed as ``data_2d_list[row][col]``;
            must have ``len(y_int_list)`` rows, and its first row must have
            ``len(x_int_list)`` entries.
        threshold_list: Contour levels evaluated for every cell.

    Returns:
        A one-element list whose single item is the flat list of
        ``(pX, pY, qX, qY)`` tuples, ordered row by row, cell by cell, and
        threshold by threshold within a cell.

    Raises:
        AssertionError: If the dimensions of ``data_2d_list`` do not match
            the coordinate lists.
    """
    Height = len(y_int_list)  # rows
    Width = len(x_int_list)  # cols
    # Only look at the first row when there is one, so an empty grid yields
    # an empty result instead of an IndexError.
    if Height != len(data_2d_list) or (Height and Width != len(data_2d_list[0])):
        raise AssertionError(
            "data_2d_list must have len(y_int_list) rows and len(x_int_list) columns"
        )

    linesList = []
    for j in range(Height - 1):  # rows
        # Hoist everything that only depends on the row out of the column loop.
        row_j = data_2d_list[j]
        row_j1 = data_2d_list[j + 1]
        y_j = y_int_list[j]
        y_j1 = y_int_list[j + 1]
        for i in range(Width - 1):  # cols
            x_i = x_int_list[i]
            x_i1 = x_int_list[i + 1]
            point_A = [x_i, y_j1]
            point_B = [x_i1, y_j1]
            point_C = [x_i1, y_j]
            point_D = [x_i, y_j]
            point_a_data_float = row_j1[i]
            point_b_data_float = row_j1[i + 1]
            point_c_data_float = row_j[i + 1]
            point_d_data_float = row_j[i]
            for threshold_item in threshold_list:
                cell_lines = GetLines(point_A, point_B, point_C, point_D,
                                      point_a_data_float, point_b_data_float,
                                      point_c_data_float, point_d_data_float,
                                      threshold_item)
                # GetLines may return None or an empty list for cells
                # without a contour; extend in place with real segments only.
                if cell_lines:
                    linesList.extend(cell_lines)
    # Wrap the flat segment list in an outer list: a caller that loops over
    # line sets and then reads line[0]..line[3] of every line needs this
    # shape; a bare flat list would hand it floats instead of 4-tuples.
    return [linesList]
【讨论】:
以上是关于如何加速这个用 Python 编写的程序?的主要内容,如果未能解决你的问题,请参考以下文章