如何在 Pytorch 的“nn.Sequential”中展平输入
Posted
技术标签:
【中文标题】如何在 Pytorch 的“nn.Sequential”中展平输入【英文标题】:how to flatten input in `nn.Sequential` in Pytorch 【发布时间】:2019-05-25 23:30:54 【问题描述】:如何在nn.Sequential
中展平输入
Model = nn.Sequential(x.view(x.shape[0],-1),
nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim=1))
【问题讨论】:
【参考方案1】:您可以如下创建一个新模块/类,并在使用其他模块时按顺序使用它(调用Flatten()
)。
class Flatten(torch.nn.Module):
def forward(self, x):
batch_size = x.shape[0]
return x.view(batch_size, -1)
参考:https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983
编辑:Flatten
现在是火炬的一部分。见https://pytorch.org/docs/stable/nn.html?highlight=flatten#torch.nn.Flatten
【讨论】:
或者直接在forward
方法中调用out = x.view(batch_size, -1)
。
@DanielMöller 再看问题,OP 想用nn.Sequential
来做
知道了,你的答案很完美。【参考方案2】:
定义为flatten
method
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
速度与view()
相当,但reshape
更快。
import torch.nn as nn
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
print(t, t.shape)
#https://pytorch.org/docs/master/torch.html#torch.flatten
f = torch.flatten(t, start_dim=1, end_dim=-1)
print(f, f.shape)
#https://pytorch.org/docs/master/torch.html#torch.view
f = t.view(t.size(0), -1)
print(f, f.shape)
#https://pytorch.org/docs/master/torch.html#torch.reshape
f = t.reshape(t.size(0), -1)
print(f, f.shape)
速度检查
# flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
如果我们使用上面的类
flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
%timeit f=flatten(t)
5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
这个结果表明创建一个类会比较慢。这就是为什么将张量向前展平更快的原因。我认为这是他们没有推广nn.Flatten
的主要原因。
所以我的建议是使用内锋来提高速度。像这样的:
out = inp.reshape(inp.size(0), -1)
【讨论】:
【参考方案3】:你可以如下修改你的代码,
Model = nn.Sequential(nn.Flatten(0, -1),
nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim=1))
【讨论】:
以上是关于如何在 Pytorch 的“nn.Sequential”中展平输入的主要内容,如果未能解决你的问题,请参考以下文章