Speed difference between np.random.shuffle and np.random.permutation
Posted by 爆米花好美啊
Observation
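We found that np.random.permutation is much faster than np.random.shuffle for shuffling the rows of a 2-D array. In the benchmark below, on a 50,000 × 2 array, the permutation-based approaches take about 1 to 1.4 ms per call, while np.random.shuffle takes about 46 ms.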
# Run in IPython or Jupyter (%timeit is an IPython magic command)
import numpy as np

x = np.random.rand(50000, 2)
# 933 µs
%timeit x.take(np.random.permutation(x.shape[0]), axis=0)
# 1.41 ms
%timeit x[np.random.permutation(x.shape[0])]
# 1.41 ms
%timeit np.random.permutation(x)
# 46.3 ms
%timeit np.random.shuffle(x)
Output, in the same order:
933 µs ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.41 ms ± 5.87 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.41 ms ± 4.22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
46.3 ms ± 413 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
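%timeit only works inside IPython or Jupyter. To reproduce the comparison as a plain script, here is a sketch that uses the standard-library timeit module instead. The repeat and number settings are arbitrary choices. Absolute timings will vary with the machine and NumPy version, so only the relative gap is meaningful.

```python
import timeit

import numpy as np

# Same data as the IPython benchmark: 50,000 rows, 2 columns.
x = np.random.rand(50000, 2)

# Label -> zero-argument callable that shuffles the rows of x.
candidates = {
    "x.take(permutation(n), axis=0)": lambda: x.take(np.random.permutation(x.shape[0]), axis=0),
    "x[permutation(n)]": lambda: x[np.random.permutation(x.shape[0])],
    "np.random.permutation(x)": lambda: np.random.permutation(x),
    "np.random.shuffle(x)": lambda: np.random.shuffle(x),  # in place, returns None
}

timings = {}
for label, fn in candidates.items():
    # Best of 3 repeats of 10 calls each, converted to seconds per call.
    timings[label] = min(timeit.repeat(fn, number=10, repeat=3)) / 10
    print(f"{label:32s} {timings[label] * 1e3:8.3f} ms")
```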
Why
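np.random.shuffle modifies the array in place. For a multidimensional array, it therefore allocates a bounce buffer, buf = np.empty_like(x[0]), to use as swap space. It then swaps rows one pair at a time in a for loop. With 50,000 rows, that is 49,999 iterations. Each iteration performs three separate row copies through NumPy indexing, and this per-row overhead is what makes it slow: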
# https://github.com/numpy/numpy/blob/18f2385b29bdd62701a1a82d7bf33fd87430a05e/numpy/random/mtrand/mtrand.pyx#L4841
# Shuffling and permutations:
def shuffle(self, object x):
    """
    shuffle(x)
    Modify a sequence in-place by shuffling its contents.
    This function only shuffles the array along the first axis of a
    multi-dimensional array. The order of sub-arrays is changed but
    their contents remains the same.
    """
    # ......
    elif isinstance(x, np.ndarray) and x.ndim > 1 and x.size:
        # Multidimensional ndarrays require a bounce buffer.
        buf = np.empty_like(x[0])
        with self.lock:
            for i in reversed(range(1, n)):
                j = rk_interval(i, self.internal_state)
                buf[...] = x[j]
                x[j] = x[i]
                x[i] = buf
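To make the cost concrete, here is a pure-NumPy sketch of the multidimensional branch above. It is a hypothetical re-implementation, not NumPy's own code. rng.integers(0, i + 1) from the newer Generator API stands in for the internal rk_interval(i, state), which likewise draws j from 0..i inclusive. The algorithm therefore matches, but the random stream does not. The second part of the block shows why the bounce buffer is needed at all.

```python
import numpy as np


def shuffle_rows_inplace(x, rng):
    """Hypothetical sketch of np.random.shuffle's multidimensional branch:
    an in-place Fisher-Yates shuffle over the first axis.

    rng.integers(0, i + 1) stands in for rk_interval(i, state); both pick j
    uniformly from 0..i inclusive, but results will not match
    np.random.shuffle for the same seed.
    """
    n = x.shape[0]
    buf = np.empty_like(x[0])  # bounce buffer holding one row
    for i in reversed(range(1, n)):
        j = rng.integers(0, i + 1)
        # Three row copies per swap, each a separate NumPy indexing call.
        buf[...] = x[j]
        x[j] = x[i]
        x[i] = buf


rng = np.random.default_rng(0)
x = np.arange(10).reshape(5, 2)
shuffle_rows_inplace(x, rng)
print(x)  # the same five rows, in a new order

# Why the buffer is needed: y[0] and y[1] are views, so a tuple swap first
# copies row 1 into row 0 and then copies that same data back into row 1.
y = np.arange(4).reshape(2, 2)
y[0], y[1] = y[1], y[0]
print(y)  # both rows are now [2, 3]; the original row 0 is lost
```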
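np.random.permutation, by contrast, returns a new array. It only has to shuffle a 1-D index array, which shuffle handles on a fast path, as the "ensure fast path" comment in the source notes. It then gathers all rows at once with arr[idx], which is why it is much faster: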
# https://github.com/numpy/numpy/blob/18f2385b29bdd62701a1a82d7bf33fd87430a05e/numpy/random/mtrand/mtrand.pyx#L4917
def permutation(self, object x):
    """
    permutation(x)
    Randomly permute a sequence, or return a permuted range.
    If `x` is a multi-dimensional array, it is only shuffled along its
    first index.
    """
    # ......
    # Shuffle index array, dtype to ensure fast path
    idx = np.arange(arr.shape[0], dtype=np.intp)
    self.shuffle(idx)
    return arr[idx]
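Here is a minimal sketch of the same index-then-gather strategy. It is written against the Generator API (np.random.default_rng) rather than the legacy np.random functions, and permute_rows is a hypothetical helper name. Only the small 1-D index array is shuffled, and the rows are moved once by a single fancy-indexing gather.

```python
import numpy as np


def permute_rows(x, rng):
    """Hypothetical helper: return a row-shuffled copy of x by shuffling
    a 1-D index array and then gathering the rows in one step."""
    idx = np.arange(x.shape[0], dtype=np.intp)
    rng.shuffle(idx)  # 1-D integer array: cheap to shuffle
    return x[idx]     # one fancy-indexing gather builds the new array


rng = np.random.default_rng(0)
x = np.arange(10).reshape(5, 2)
y = permute_rows(x, rng)
print(x)  # unchanged
print(y)  # the same rows, in a new order

# If the result must land in the original array, write the gather back.
# The right-hand side z[...] with an index array is a copy, so this is safe.
z = np.arange(10).reshape(5, 2)
z[...] = z[rng.permutation(z.shape[0])]
print(z)
```

Writing the gather back, as with z, gives shuffle-like in-place semantics. It should cost roughly the same as the permutation approach plus one extra copy, with no per-row Python-level loop. In the benchmark above, x.take(idx, axis=0) performed the same kind of gather and was the fastest variant measured.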