索引为零维 cp.array 的 cp.array 慢切片(基于 cp.argmin 结果)
Posted
技术标签:
【中文标题】索引为零维 cp.array 的 cp.array 慢切片(基于 cp.argmin 结果)【英文标题】:Slow slicing of cp.array with index being zero dimensional cp.array (based on cp.argmin result) 【发布时间】:2021-01-20 11:00:43 【问题描述】:我有一些代码,我需要根据应用于较小 cp.array 的 cp.argmin 的结果对较大的 cp.array 进行切片。 (参见下面的最小代码示例)
问题是,cp.argmin 返回一个零维 cp.array,而使用 :
运算符进行切片显然需要整数。
import time
import cupy as cp
original = cp.empty((10000, 10000))
nrows, ncols = 1000, 1000
to_modify = cp.empty((nrows, ncols))
start_time = time.time()
for i in range(10000):
argmin = cp.argmin(to_modify)
argmin = int(argmin)
row_idx, col_idx = (argmin // ncols, argmin % ncols)
sliced = original[row_idx : row_idx + nrows, col_idx : col_idx + ncols]
to_modify += sliced
print(time.time() - start_time)
当我分析上面的代码时(我使用 py-spy),最慢的行(大约 90% 的时间)是转换为 argmin 的 int,但如果我删除它,sliced = original[ ... ]
行将成为最慢的行,因为演员阵容似乎隐含地发生了。
有没有办法以高效的方式解决我的问题,避免切片时对 :
运算符的需求?
【问题讨论】:
【参考方案1】:不幸的是,我没有找到使用 cupy 的可行解决方案。 相反,我开始使用 numba。虽然 numba 需要以编写 cuda 内核的形式进行更多的手动工作,但它也提供了更多的控制。 使用 numba.cuda.device_array()、numba.cuda.to_device() 和 numba.cuda.copy_to_host() 还可以控制数组在 cpu 和 gpu 之间来回复制的时间。 最难的部分是实现 argmin,这需要一个 reduce 操作:
@cuda.jit
def find_argmin(data, argmin, tmp_min, tmp_idx):
x = cuda.grid(1)
shape_x, shape_y = data.shape[0], data.shape[1]
num_items = shape_x * shape_y
num_threads = tmp_min.shape[0]
num_items_per_thread = num_items // num_threads
min_val = data[0, 0]
min_ix, min_iy = 0, 0
for idx in range(num_items_per_thread):
idx = x * num_items_per_thread + idx
if idx < num_items:
ix, iy = idx // shape_y, idx % shape_y
current = data[ix, iy]
if current < min_val:
min_val = current
min_ix = ix
min_iy = iy
tmp_idx[x, 0] = min_ix
tmp_idx[x, 1] = min_iy
tmp_min[x] = min_val
cuda.syncthreads()
# find minimum in temporary array
if x == 0:
min_val = tmp_min[0]
min_x, min_y = tmp_idx[0, 0], tmp_idx[0, 1]
for idx in range(num_threads):
if tmp_min[idx] < min_val:
min_val = tmp_min[idx]
min_x, min_y = tmp_idx[idx, 0], tmp_idx[idx, 1]
argmin[0] = min_x
argmin[1] = min_y
其余的都是简单的元素操作。
【讨论】:
以上是关于索引为零维 cp.array 的 cp.array 慢切片(基于 cp.argmin 结果)的主要内容,如果未能解决你的问题,请参考以下文章