Implementation and performance of column-major (Fortran-like) reshape in PyTorch
Posted by cyoahs
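Background
NumPy's `reshape` function has an `order` parameter. The default, `order='C'`, reads and places elements in row-major (C-like) order, meaning the last index changes fastest. With `order='F'`, elements are read and placed in column-major (Fortran-like) order, meaning the first index changes fastest. The official documentation gives the following example: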
```python
>>> import numpy as np
>>> a = np.arange(6).reshape((3, 2))
>>> a
array([[0, 1],
       [2, 3],
       [4, 5]])
>>> np.reshape(a, (2, 3)) # C-like index ordering
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.reshape(np.ravel(a), (2, 3)) # equivalent to C ravel then C reshape
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.reshape(a, (2, 3), order='F') # Fortran-like index ordering
array([[0, 4, 3],
       [2, 1, 5]])
>>> np.reshape(np.ravel(a, order='F'), (2, 3), order='F')
array([[0, 4, 3],
       [2, 1, 5]])
```
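The Fortran-order result can also be produced using only C-order operations. Reversing all axes (`.T` on an ndarray) turns a column-major traversal into a row-major one, so an F-order reshape equals: reverse the axes, C-reshape to the reversed shape, then reverse the axes again. This is the idea the PyTorch workaround below relies on. A minimal NumPy sketch; `reshape_fortran_np` is a hypothetical helper name, not a NumPy function:

```python
import numpy as np

def reshape_fortran_np(a, shape):
    # Hypothetical helper: F-order reshape built from C-order operations only.
    # a.T reverses every axis, which turns a column-major walk into a row-major one.
    return np.reshape(a.T, tuple(reversed(shape))).T

a = np.arange(6).reshape((3, 2))
print(reshape_fortran_np(a, (2, 3)))
# [[0 4 3]
#  [2 1 5]]

# Cross-check against NumPy's own order='F' on a 3-D array
b = np.arange(24).reshape(2, 3, 4)
assert np.array_equal(reshape_fortran_np(b, (4, 6)), np.reshape(b, (4, 6), order="F"))
```

The same trick works with `-1` in the target shape, because reversing the shape just moves the `-1` to another position.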
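PyTorch solution
In PyTorch, `torch.reshape()` takes only a tensor and a target shape and always uses row-major (C-style) ordering. To get a column-major reshape, it has to be combined with `permute()`. The following solution was given on Stack Overflow: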
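```python
def reshape_fortran(x, shape):
    # Reverse all axes so that a C-order walk visits elements in F order
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    # C-reshape to the reversed target shape, then reverse the axes back
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
```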
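A quick way to check that the function is correct is to compare it with NumPy's `order='F'` on a CPU tensor. A sketch, assuming NumPy is installed; the list of test shapes is an arbitrary choice:

```python
import numpy as np
import torch

def reshape_fortran(x, shape):
    # Same permute-based function as above
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

a = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
t = torch.from_numpy(a)  # CPU tensor sharing memory with a

for shape in [(4, 6), (6, 4), (2, 12), (3, 2, 4), (4, 3, -1)]:
    expected = np.reshape(a, shape, order="F")
    got = reshape_fortran(t, shape).numpy()
    assert np.array_equal(got, expected), shape
print("all shapes match numpy order='F'")
```

The returned tensor is generally non-contiguous (it ends with a `permute()`), which is fine for most operations; call `.contiguous()` if a later step needs contiguous memory.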
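Performance test
However, the author of that answer suspected that `permute()` might still create a copy of the tensor internally, which would hurt efficiency. I therefore benchmarked this approach and compared it with NumPy's built-in function. Test environment: i9-10900X / RTX 2080 Ti.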
Test code:
```python
import numpy as np
import torch
import time

dim1 = 40
dim2 = 50
dim3 = 5

def reshape_fortran(x, shape):
    # Reverse all axes, C-reshape to the reversed shape, then reverse the axes back
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

torch.cuda.set_device(0)
device = torch.device('cuda')

# 100 random 40x50 float64 tensors on the GPU: x for warm-up, xx for timing
x = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]
xx = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]

# warm-up
for i in range(100):
    y = x[i].reshape([dim2, dim1])

# C (row-major) reshape
# note: CUDA ops run asynchronously and no torch.cuda.synchronize() is called before reading the clock
t0 = time.time()
for i in range(100):
    y = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()

# Fortran (column-major) reshape via permute()
for i in range(100):
    yy = reshape_fortran(xx[i], [dim2, dim3, -1])
t2 = time.time()

print(f'torch built-in reshape: {(t1 - t0)/100} s')
print(f'torch permute reshape: {(t2 - t1)/100} s')

# same comparison with NumPy on the CPU
x = [np.random.rand(dim1, dim2) for _ in range(100)]
xx = [np.random.rand(dim1, dim2) for _ in range(100)]

# warm-up
for i in range(100):
    y = x[i].reshape([dim2, dim3, -1])

t0 = time.time()
for i in range(100):
    yy = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()
for i in range(100):
    yyy = xx[i].reshape([dim2, dim3, -1], order='F')
t2 = time.time()

print(f'numpy C reshape: {(t1 - t0)/100} s')
print(f'numpy F reshape: {(t2 - t1)/100} s')
```
Results:
```
torch built-in reshape: 9.72747802734375e-07 s
torch permute reshape: 1.1897087097167968e-05 s
numpy C reshape: 3.0517578125e-07 s
numpy F reshape: 2.474784851074219e-06 s
```
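The similar ratios are consistent with where the extra work happens. `permute()` by itself only returns a view with rearranged strides; the data copy happens in the following `reshape()`, which cannot express this column-major layout as a view of the original storage. NumPy's `order='F'` path has to copy for the same reason, while both row-major reshapes of a contiguous array are free views. A small CPU sketch checking this with `Tensor.data_ptr()` and `np.shares_memory`, using the benchmark's 40×50 shape (the `(-1, 5, 50)` target is what `reshape_fortran` passes to `reshape()` for `[50, 5, -1]`):

```python
import numpy as np
import torch

x = torch.rand(40, 50, dtype=torch.float64)

# permute() is a view: same storage, only the strides change
xp = x.permute(1, 0)
print(xp.data_ptr() == x.data_ptr(), xp.is_contiguous())  # True False

# C-order reshape of a contiguous tensor is also a view
print(x.reshape(50, 5, -1).data_ptr() == x.data_ptr())  # True

# reshape() of the permuted, non-contiguous view has to copy into new memory
print(xp.reshape(-1, 5, 50).data_ptr() == x.data_ptr())  # False

# NumPy behaves the same way: C reshape is a view, F reshape copies here
a = x.numpy()
print(np.shares_memory(a, a.reshape(50, 5, -1)))             # True
print(np.shares_memory(a, a.reshape(50, 5, -1, order="F")))  # False
```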
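In this test, the `permute()`-based method in PyTorch took roughly 10 times as long as the built-in row-major `reshape()` (about 12× from the numbers above). In NumPy, however, the column-major reshape was also roughly 10 times slower than the row-major one (about 8×). It is therefore reasonable to consider the `permute()`-based transformation in PyTorch efficient; it needs no further optimization.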
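One caveat on the GPU numbers: CUDA kernels are launched asynchronously, and the timed loops above read `time.time()` without calling `torch.cuda.synchronize()`, so the PyTorch figures may mostly reflect Python and kernel-launch overhead rather than the copy itself. Below is a sketch of a synchronized variant. It falls back to the CPU when no GPU is available, warms up with the same operation it times, and uses `time.perf_counter()`; `sync` and `time_per_call` are hypothetical helper names, and the resulting timings will depend on the hardware:

```python
import time
import torch

def reshape_fortran(x, shape):
    # Same permute-based function as above
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

def sync(device):
    # Wait for all queued CUDA work before reading the clock
    if device.type == "cuda":
        torch.cuda.synchronize(device)

def time_per_call(fn, tensors, device):
    for t in tensors:  # warm-up with the operation being timed
        fn(t)
    sync(device)
    t0 = time.perf_counter()
    for t in tensors:
        fn(t)
    sync(device)
    return (time.perf_counter() - t0) / len(tensors)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
xs = [torch.rand(40, 50, dtype=torch.float64, device=device) for _ in range(100)]

c_time = time_per_call(lambda t: t.reshape(50, 5, -1), xs, device)
f_time = time_per_call(lambda t: reshape_fortran(t, [50, 5, -1]), xs, device)
print(f"C-order reshape:         {c_time:.3e} s/call")
print(f"permute-based F reshape: {f_time:.3e} s/call")
```

For more careful measurements, `torch.utils.benchmark.Timer` handles warm-up and CUDA synchronization itself.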
References
NumPy documentation (`numpy.reshape`)
Original Stack Overflow question