尽管 Python 简单,但循环效率低下
Posted
技术标签:
【中文标题】尽管 Python 简单,但循环效率低下【英文标题】:Python inefficient loop despite its simplicity 【发布时间】:2021-07-06 11:13:04 【问题描述】:我尝试运行这个小代码,它只需要随机点(这里是 50k,接近我实际拥有的点)并返回随机选择的每个点的第 10 个最近点。
但不幸的是,这是(真的!)很长,因为肯定是循环。
由于我是“代码优化”方面的新手,有什么技巧可以让这种糊状变得更快吗? (在 Python 规模上更快,我知道我不是在 C++ 中编码)。
这是一个可重现的示例,其数据大小接近我所拥有的:
import time
import numpy as np
from numpy import random
from scipy.spatial import distance
# USEFUL FUNCTION
start_time = time.time()
def closest_node(node, nodes):
nodes = np.asarray(nodes)
deltas = nodes - node
dist_2 = np.einsum("ij,ij->i", deltas, deltas)
ndx = dist_2.argsort()
return data[ndx[:10]]
# REPRODUCIBLE DATA
mean = np.array([0.0, 0.0, 0.0])
cov = np.array([[1.0, -0.5, 0.8], [-0.5, 1.1, 0.0], [0.8, 0.0, 1.0]])
data = np.random.multivariate_normal(mean, cov, 500000)
# START RUNNING
points = data[np.random.choice(data.shape[0], int(np.round(0.1 * len(data), 0)))]
print(len(points))
for w in points:
closest_node(w, data)
print("--- %s seconds ---" % (time.time() - start_time))
【问题讨论】:
您好,欢迎来到 SO!第一个问题很好。它包含得很好,您有一个工作示例。但是,我不清楚你到底在问什么。是为什么代码很慢? 您的代码结合了两种算法复杂性:第一个是循环w in points
,显然在 len(points) 中是线性的。但是第二个是dist_2.argsort
,通常是 O(N log(N)) ... 但在你的情况下 N=len(data)=500000;这是昂贵的部分!
【参考方案1】:
在 500000 个元素的数组上运行 argsort
的每个循环所花费的时间是巨大的。我能想到的唯一改进是使用可以返回最小 10 个元素而无需对整个数组进行完全排序的东西。
A fast way to find the largest N elements in an numpy array
所以不是
ndx = dist_2.argsort()
return data[ndx[:10]]
应该是
ndx = np.argpartition(dist_2, 10)[:10]
return data[ndx[:10]]
我只在 500 分上进行了基准测试,因为在我的 PC 上运行它已经花费了相当长的时间。
N=500
Using argsort: 25.625439167022705 seconds
Using argpartition: 6.637120485305786 seconds
【讨论】:
【参考方案2】:您最好通过分析器分析最慢的点:How do I find out what parts of my code are inefficient in Python
乍一看可能会发生的一件事是,您应该尽可能多地移动到循环之外。如果要通过 np.asarray() 转换点,最好在循环之前对所有点执行一次,然后在函数中使用结果,而不是在每次循环运行时执行 np.asarray()。
【讨论】:
np.asarray
如果它的参数已经是 numpy.ndarray
应该是空操作,这里似乎就是这种情况以上是关于尽管 Python 简单,但循环效率低下的主要内容,如果未能解决你的问题,请参考以下文章