如何加速这个用 Python 编写的程序?

Posted

技术标签:

【中文标题】如何加速这个用 Python 编写的程序?【英文标题】:How can I speed up this program written in Python? 【发布时间】:2021-11-08 16:04:04 【问题描述】:

以下程序是Python中Marching Square问题的解决方案:

from typing import List

def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold):
    """Encode which of the four corner samples reach ``threshold``.

    A corner's bit is set when its sample is greater than or equal to
    ``threshold``. Corner A contributes 1, B contributes 2, C contributes 4
    and D contributes 8.

    Returns:
        The combined bit mask, an int between 0 and 15 inclusive.
    """
    corner_samples = (Point_A_data, Point_B_data, Point_C_data, Point_D_data)
    case_id = 0
    # Bit position follows the corner order A, B, C, D.
    for bit_index, sample in enumerate(corner_samples):
        if sample >= threshold:
            case_id |= 1 << bit_index
    return case_id


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    """Build the contour segment(s) for one grid cell at one threshold.

    ``Point_A`` .. ``Point_D`` are the cell corners, indexed as ``[0]`` for
    the first coordinate and ``[1]`` for the second; ``a`` .. ``d`` are the
    samples at those corners. The case id returned by ``GetCaseId`` selects
    which segments are produced. Segment endpoints are taken from corner
    coordinates and from the averages of two corner coordinates.

    Returns:
        A list of ``(pX, pY, qX, qY)`` tuples. The list is empty for case
        ids 0 and 15, holds two segments for case ids 5 and 10, and one
        segment for every other case id.
    """
    case_id = GetCaseId(a, b, c, d, threshold)

    if case_id == 0 or case_id == 15:
        return []

    segments = []

    # First group: these three case families are mutually exclusive.
    if case_id in (1, 10, 14):
        segments.append((
            (Point_A[0] + Point_B[0]) / 2,
            Point_B[1],
            Point_D[0],
            (Point_A[1] + Point_D[1]) / 2,
        ))
    elif case_id in (2, 5, 13):
        segments.append((
            (Point_A[0] + Point_B[0]) / 2,
            Point_A[1],
            Point_C[0],
            (Point_A[1] + Point_D[1]) / 2,
        ))
    elif case_id in (3, 12):
        segments.append((
            Point_A[0],
            (Point_A[1] + Point_D[1]) / 2,
            Point_C[0],
            (Point_B[1] + Point_C[1]) / 2,
        ))

    # Second group: case ids 10 and 5 also matched above, so they end up
    # with a second segment appended after the first one.
    if case_id in (4, 10, 11):
        segments.append((
            (Point_C[0] + Point_D[0]) / 2,
            Point_D[1],
            Point_B[0],
            (Point_B[1] + Point_C[1]) / 2,
        ))
    elif case_id in (6, 9):
        segments.append((
            (Point_A[0] + Point_B[0]) / 2,
            Point_A[1],
            (Point_C[0] + Point_D[0]) / 2,
            Point_C[1],
        ))
    elif case_id in (5, 7, 8):
        segments.append((
            (Point_C[0] + Point_D[0]) / 2,
            Point_C[1],
            Point_A[0],
            (Point_A[1] + Point_D[1]) / 2,
        ))

    return segments


def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    """Collect contour segments for every grid cell and every threshold.

    Args:
        x_int_list: Coordinates along the column axis; its length is the
            expected number of columns in ``data_2d_list``.
        y_int_list: Coordinates along the row axis; its length is the
            expected number of rows in ``data_2d_list``.
        data_2d_list: Samples indexed as ``data_2d_list[row][col]``.
        threshold_list: Threshold values; each cell is processed once per
            threshold.

    Returns:
        A one-element list whose single item is the list of all
        ``(pX, pY, qX, qY)`` segments returned by ``GetLines``, in
        row-major cell order and, within a cell, in threshold order.

    Raises:
        AssertionError: If the number of rows does not match
            ``len(y_int_list)`` or the first row's length does not match
            ``len(x_int_list)``.
    """
    Height = len(y_int_list)  # rows
    Width = len(x_int_list)  # cols

    # Check the row count first so an empty grid never indexes row 0.
    if len(data_2d_list) != Height or (Height > 0 and len(data_2d_list[0]) != Width):
        raise AssertionError(
            "data_2d_list must have len(y_int_list) rows and len(x_int_list) columns"
        )

    linesList = []

    for j in range(Height - 1):  # rows
        row_j = data_2d_list[j]
        row_j1 = data_2d_list[j + 1]
        y_j = y_int_list[j]
        y_j1 = y_int_list[j + 1]

        for i in range(Width - 1):  # cols
            x_i = x_int_list[i]
            x_i1 = x_int_list[i + 1]

            point_a_data_float = row_j1[i]
            point_b_data_float = row_j1[i + 1]
            point_c_data_float = row_j[i + 1]
            point_d_data_float = row_j[i]

            point_A = [x_i, y_j1]
            point_B = [x_i1, y_j1]
            point_C = [x_i1, y_j]
            point_D = [x_i, y_j]

            for threshold_item in threshold_list:
                # extend() grows the accumulator in place; rebuilding it with
                # ``+`` copied every segment collected so far on each call.
                linesList.extend(
                    GetLines(point_A, point_B, point_C, point_D,
                             point_a_data_float, point_b_data_float,
                             point_c_data_float, point_d_data_float,
                             threshold_item)
                )

    return [linesList]

此源代码的问题是 - 生成输出需要很长时间。

即使用以下驱动程序:

# Driver: sample a sine-product field on a square grid, extract contour
# segments at several thresholds and write them to an SVG file.
import math  # needed for math.sin below

import drawSvg as draw_svg

N_int = 800
N2_float = N_int / 8
x_int_vector = list(range(N_int))
y_int_vector = list(range(N_int))

# NOTE: despite its name, this matrix has N_int rows and N_int columns.
matrix_256x256 = [
    [(math.sin(i / N2_float) * math.sin(j / N2_float)) for i in range(N_int)]
    for j in range(N_int)
]

fill = "#2591a3"  # not referenced anywhere below
drawing = draw_svg.Drawing(N_int, N_int, displayInline=False)

threshold_float_list = [0.2, 0.4, 0.6, 0.8]
collection = marching_square(x_int_vector, y_int_vector, matrix_256x256, threshold_float_list)
for line_set in collection:
    for line in line_set:
        # Each line is a (pX, pY, qX, qY) tuple.
        drawing.append(draw_svg.Line(line[0], line[1], line[2], line[3], stroke='red'))
drawing.saveSvg('example.svg')

代码在实际使用中变得非常缓慢。

如何加快代码速度?

注意marching_square()的签名不得更改。

【问题讨论】:

针对可正常运行的代码的性能改进建议,应发布在 codereview.stackexchange.com/tour 。 首先,您应该对其进行性能分析,找出慢在哪里。 @AKX,慢在 GetLines() 和 GetCaseId() 中。 【参考方案1】:

获得了约 10 倍的加速

    - 去掉了反复拼接列表的做法(这是最大的瓶颈),改用把“列表的列表”一次性展平的技巧(this trick to concatenate list of lists)
    - 对 GetCaseId 应用 numba,这是第二大瓶颈
from typing import List
import numba
import functools
import operator

@numba.jit(nopython=True)
def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold):
    """Encode which of the four corner samples reach ``threshold``.

    Each corner whose sample is greater than or equal to ``threshold``
    contributes one bit: A -> 1, B -> 2, C -> 4, D -> 8.

    Returns:
        The combined bit mask, a value between 0 and 15 inclusive.
    """
    bit_a = 1 if Point_A_data >= threshold else 0
    bit_b = 2 if Point_B_data >= threshold else 0
    bit_c = 4 if Point_C_data >= threshold else 0
    bit_d = 8 if Point_D_data >= threshold else 0
    return bit_a | bit_b | bit_c | bit_d


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    """Build the contour segment(s) for one grid cell at one threshold.

    ``Point_A`` .. ``Point_D`` are the cell corners, indexed as ``[0]`` and
    ``[1]``; ``a`` .. ``d`` are the samples at those corners. The case id
    returned by ``GetCaseId`` selects which segments are produced.

    Returns:
        ``None`` for case ids 0 and 15. Otherwise a list of
        ``(pX, pY, qX, qY)`` tuples: two segments for case ids 5 and 10,
        one segment for every other case id.
    """
    case_id = GetCaseId(a, b, c, d, threshold)

    if case_id in (0, 15):
        return None

    # Case families that only ever yield a single segment return directly.
    if case_id in (3, 12):
        left_mid = (Point_A[1] + Point_D[1]) / 2
        right_mid = (Point_B[1] + Point_C[1]) / 2
        return [(Point_A[0], left_mid, Point_C[0], right_mid)]

    if case_id in (6, 9):
        first_mid = (Point_A[0] + Point_B[0]) / 2
        second_mid = (Point_C[0] + Point_D[0]) / 2
        return [(first_mid, Point_A[1], second_mid, Point_C[1])]

    found = []

    if case_id in (1, 14, 10):
        start_x = (Point_A[0] + Point_B[0]) / 2
        end_y = (Point_A[1] + Point_D[1]) / 2
        found.append((start_x, Point_B[1], Point_D[0], end_y))
    elif case_id in (2, 13, 5):
        start_x = (Point_A[0] + Point_B[0]) / 2
        end_y = (Point_A[1] + Point_D[1]) / 2
        found.append((start_x, Point_A[1], Point_C[0], end_y))

    # Case ids 10 and 5 matched above as well and receive a second segment.
    if case_id in (4, 11, 10):
        start_x = (Point_C[0] + Point_D[0]) / 2
        end_y = (Point_B[1] + Point_C[1]) / 2
        found.append((start_x, Point_D[1], Point_B[0], end_y))
    elif case_id in (7, 8, 5):
        start_x = (Point_C[0] + Point_D[0]) / 2
        end_y = (Point_A[1] + Point_D[1]) / 2
        found.append((start_x, Point_C[1], Point_A[0], end_y))

    return found


def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    """Collect contour segments for every grid cell and every threshold.

    Args:
        x_int_list: Coordinates along the column axis; its length must equal
            the length of the first row of ``data_2d_list``.
        y_int_list: Coordinates along the row axis; its length must equal
            the number of rows in ``data_2d_list``.
        data_2d_list: Samples indexed as ``data_2d_list[row][col]``.
        threshold_list: Threshold values; each cell is processed once per
            threshold.

    Returns:
        A single flat list of the segments returned by ``GetLines`` (the
        list is not wrapped in an outer list).

    Raises:
        AssertionError: If the grid dimensions do not match the coordinate
            list lengths.
    """
    Height = len(y_int_list)  # rows
    Width = len(x_int_list)  # cols

    if Width != len(data_2d_list[0]) or Height != len(data_2d_list):
        raise AssertionError

    # One entry per (cell, threshold) call that produced segments.
    segment_groups = []

    for row in range(Height - 1):
        next_row = row + 1
        for col in range(Width - 1):
            next_col = col + 1

            sample_a = data_2d_list[next_row][col]
            sample_b = data_2d_list[next_row][next_col]
            sample_c = data_2d_list[row][next_col]
            sample_d = data_2d_list[row][col]

            corner_a = [x_int_list[col], y_int_list[next_row]]
            corner_b = [x_int_list[next_col], y_int_list[next_row]]
            corner_c = [x_int_list[next_col], y_int_list[row]]
            corner_d = [x_int_list[col], y_int_list[row]]

            for level in threshold_list:
                cell_segments = GetLines(corner_a, corner_b, corner_c, corner_d,
                                         sample_a, sample_b, sample_c, sample_d,
                                         level)
                # Skip both None and empty results.
                if cell_segments:
                    segment_groups.append(cell_segments)

    # Flatten the list of per-call lists into one list in a single pass.
    return functools.reduce(operator.iconcat, segment_groups, [])

【讨论】:

以上是关于如何加速这个用 Python 编写的程序?的主要内容,如果未能解决你的问题,请参考以下文章

8个用Python进行机器学习建模项目的实用建议,让新手小白精准避坑

如何让Python程序轻松加速,正确方法详解

10 个用纯 Javascript 实现的好用插件

怎么用VB6.0编写病毒(要摸版)

分享几个用 Python 给图片添加水印的方法,简单实用

使用 C++ 加速 Python