对于纯 numpy 代码,使用 numba 的收益在哪里?
Posted
技术标签:
【中文标题】对于纯 numpy 代码,使用 numba 的收益在哪里?【英文标题】:Where are the gains using numba coming from for pure numpy code? 【发布时间】:2017-11-28 21:01:37 【问题描述】:我想了解使用 Numba 在 for 循环中加速纯 numpy
代码时的收益来自哪里。是否有任何分析工具可以让您查看jitted
函数?
演示代码(如下)只是使用非常基本的矩阵乘法来为计算机提供工作。观察到的收益来自:
-
更快的
loop
,
在编译过程中被jit
拦截的numpy
函数重铸,或者
jit
的开销更少,因为 numpy 通过包装函数将执行外包给低级库,例如 LINPACK
%matplotlib inline
import numpy as np
from numba import jit
import pandas as pd
#Dimensions of Matrices
i = 100
j = 100
def pure_python(N,i,j):
for n in range(N):
a = np.random.rand(i,j)
b = np.random.rand(i,j)
c = np.dot(a,b)
@jit(nopython=True)
def jit_python(N,i,j):
for n in range(N):
a = np.random.rand(i,j)
b = np.random.rand(i,j)
c = np.dot(a,b)
time_python = []
time_jit = []
N = [1,10,100,500,1000,2000]
for n in N:
time = %timeit -oq pure_python(n,i,j)
time_python.append(time.average)
time = %timeit -oq jit_python(n,i,j)
time_jit.append(time.average)
df = pd.DataFrame('pure_python' : time_python, 'jit_python' : time_jit, index=N)
df.index.name = 'Iterations'
df[["pure_python", "jit_python"]].plot()
生成以下图表。
【问题讨论】:
我认为 Numba 可以识别np.random.rand
和 np.dot
。 (如果没有,我认为它不会让您在 nopython 模式下使用它们。)
确实如此。根据文档,numba
支持它们。 numba.pydata.org/numba-doc/dev/reference/numpysupported.html。我主要好奇代码拦截是如何工作的,以及这是否是上例中的收益来源。
您能添加一些设置信息吗?在 Win 64、python 3.5、numba 0.33 上,我只有适度的加速(10-15%)
当然。我在 Linux Mint 18、Linux Kernel 4.4.0-45-generic、python 3.5、numba 0.30.1、Intel Xeon CPU E5-1620 @ 3.6Ghz x 4
据我所知,答案是 1) 和 2)。 numba
将函数编译为 c
代码。因此,它显着加快了循环解析,并以显着的python
开销加速了numpy
函数(通常通过剥离该开销并强制显式数据排序 - 即没有axis
关键字,没有einsum
,没有@大多数构造函数上的 987654341@ 参数(random.rand
是一个例外)...所有这些事情都可以在现在更快的 for
循环中显式完成)
【参考方案1】:
TL:DR 随机和循环得到加速,但矩阵乘法除了小矩阵大小之外没有。在较小的矩阵/循环大小下,似乎有可能与 python 开销有关的显着加速。在 N 较大时,矩阵乘法开始占主导地位,而 jit 的帮助不大
函数定义,为简单起见使用方阵。
from IPython.display import display
import numpy as np
from numba import jit
import pandas as pd
#Dimensions of Matrices
N = 1000
def py_rand(i, j):
a = np.random.rand(i, j)
jit_rand = jit(nopython=True)(py_rand)
def py_matmul(a, b):
c = np.dot(a, b)
jit_matmul = jit(nopython=True)(py_matmul)
def py_loop(N, val):
count = 0
for i in range(N):
count += val
jit_loop = jit(nopython=True)(py_loop)
def pure_python(N,i,j):
for n in range(N):
a = np.random.rand(i,j)
b = np.random.rand(i,j)
c = np.dot(a,a)
jit_func = jit(nopython=True)(pure_python)
时间:
df = pd.DataFrame(columns=['Func', 'jit', 'N', 'Time'])
def meantime(f, *args, **kwargs):
t = %timeit -oq -n5 f(*args, **kwargs)
return t.average
for N in [10, 100, 1000, 2000]:
a = np.random.randn(N, N)
b = np.random.randn(N, N)
df = df.append('Func': 'jit_rand', 'N': N, 'Time': meantime(jit_rand, N, N), ignore_index=True)
df = df.append('Func': 'py_rand', 'N': N, 'Time': meantime(py_rand, N, N), ignore_index=True)
df = df.append('Func': 'jit_matmul', 'N': N, 'Time': meantime(jit_matmul, a, b), ignore_index=True)
df = df.append('Func': 'py_matmul', 'N': N, 'Time': meantime(py_matmul, a, b), ignore_index=True)
df = df.append('Func': 'jit_loop', 'N': N, 'Time': meantime(jit_loop, N, 2.0), ignore_index=True)
df = df.append('Func': 'py_loop', 'N': N, 'Time': meantime(py_loop, N, 2.0), ignore_index=True)
df = df.append('Func': 'jit_func', 'N': N, 'Time': meantime(jit_func, 5, N, N), ignore_index=True)
df = df.append('Func': 'py_func', 'N': N, 'Time': meantime(pure_python, 5, N, N), ignore_index=True)
df['jit'] = df['Func'].str.contains('jit')
df['Func'] = df['Func'].apply(lambda s: s.split('_')[1])
df.set_index('Func')
display(df)
结果:
Func jit N Time
0 rand True 10 1.030686e-06
1 rand False 10 1.115149e-05
2 matmul True 10 2.250371e-06
3 matmul False 10 2.199343e-06
4 loop True 10 2.706000e-07
5 loop False 10 7.274286e-07
6 func True 10 1.217046e-05
7 func False 10 2.495837e-05
8 rand True 100 5.199217e-05
9 rand False 100 8.149794e-05
10 matmul True 100 7.848071e-05
11 matmul False 100 2.130794e-05
12 loop True 100 2.728571e-07
13 loop False 100 3.003743e-06
14 func True 100 6.739634e-04
15 func False 100 1.146594e-03
16 rand True 1000 5.644258e-03
17 rand False 1000 8.012790e-03
18 matmul True 1000 1.476098e-02
19 matmul False 1000 1.613211e-02
20 loop True 1000 2.846572e-07
21 loop False 1000 3.539849e-05
22 func True 1000 1.256926e-01
23 func False 1000 1.581177e-01
24 rand True 2000 2.061612e-02
25 rand False 2000 3.204709e-02
26 matmul True 2000 9.866484e-02
27 matmul False 2000 1.007234e-01
28 loop True 2000 3.011143e-07
29 loop False 2000 7.477454e-05
30 func True 2000 1.033560e+00
31 func False 2000 1.199969e+00
看起来 numba 正在优化循环,所以我不会费心将它包含在比较中
情节:
def jit_speedup(d):
py_time = d[d['jit'] == False]['Time'].mean()
jit_time = d[d['jit'] == True]['Time'].mean()
return py_time / jit_time
import seaborn as sns
result = df.groupby(['Func', 'N']).apply(jit_speedup).reset_index().rename(columns=0: 'Jit Speedup')
result = result[result['Func'] != 'loop']
sns.factorplot(data=result, x='N', y='Jit Speedup', hue='Func')
因此,对于 5 次重复的循环,jit 可以相当稳定地加快速度,直到矩阵乘法变得足够昂贵,以使其他开销相比之下变得微不足道。
【讨论】:
您可能对修复def py_loop():
代码感兴趣,因为原来的代码主要是一个O( 1 )
缩放(从概念上讲,可能是一个被忽视的快捷方式/屏蔽变量错误,独立于 N
),它使实验偏向于返回总是只是 ~ 163 - 294 [ns]
处理持续时间(这证实了只是“恒定”的普通呼叫签名处理开销 + JMP / RET 持续时间,不是任何N
-times 循环代码执行)。
不,你必须返回一个值 - 一个“那里”产生的值,否则 JIT 编译器分析看不出有任何明显的原因使硅循环如此多次通过“静默”代码,如果它什么都不返回... 更好地设计numba.jit(...)
test-payload 更仔细,否则你再次诉诸不测试任何合理的O /P.
我最初对此很担心,但由于没有一个函数具有恒定的时间缩放(它们都没有返回),我认为 numba jit 实际上并没有完全优化代码。我可能错了。我不想将时间与返回、类型转换等混为一谈。以上是关于对于纯 numpy 代码,使用 numba 的收益在哪里?的主要内容,如果未能解决你的问题,请参考以下文章