Cython 优化 numpy 数组求和的关键部分
Posted
技术标签:
【中文标题】Cython 优化 numpy 数组求和的关键部分【英文标题】:Cython optimize the critical part of a numpy array summation 【发布时间】:2014-02-06 22:57:13 【问题描述】:令 L 是一个列表 L = [A_1, A_2, ..., A_n]
,每个 A_i
都是长度为 1024 的 numpy.int32
数组。
(大多数时候 1000
经过一些分析,我发现最耗时的操作是求和:
def summation():
# L is a global variable, modified outside of this function
b = numpy.zeros(1024, numpy.int32)
for a in L:
b += a
return b
PS:我不认为我可以定义大小为 1024 x n
的二维数组,因为 n
不固定:一些元素被动态删除/添加到 L,所以 len(L) = n
在期间可以在 1000 和 4000 之间变化运行时间。
我可以通过使用 Cython 获得显着的改进吗?
如果是这样,我应该如何对这个小功能进行cython-recode(我不应该添加一些cdef
打字吗?)
或者你能看到一些其他可能的改进吗?
【问题讨论】:
【参考方案1】:这是 Cython 代码,确保 L 中的每个数组都是 C_CONTIGUOUS:
import cython
import numpy as np
cimport numpy as np
@cython.boundscheck(False)
@cython.wraparound(False)
def sum_list(list a):
cdef int* x
cdef int* b
cdef int i, j
cdef int count
count = len(a[0])
res = np.zeros_like(a[0])
b = <int *>((<np.ndarray>res).data)
for j in range(len(a)):
x = <int *>((<np.ndarray>a[j]).data)
for i in range(count):
b[i] += x[i]
return res
一台我的 PC 大约快 4 倍。
【讨论】:
非常感谢!这对我帮助很大 !如果 a[0]、a[1] 等是int16 numpy arrays
,我希望结果 res
仍然是 int32 numpy array
,你知道我该如何修改此代码吗?
如果输入数组是int16
,而输出仍然是int32
,我将cdef int* x
替换为cdef short* x
,将x = <int *>((<np.ndarray>a[j]).data)
替换为<short *>
。你认为这是最好的方法吗?以上是关于Cython 优化 numpy 数组求和的关键部分的主要内容,如果未能解决你的问题,请参考以下文章
Cython:从参考获得时,Numpy 数组缺少两个第一个元素