AttributeError:当向 Pytorch LSTM 网络提供输入时,“元组”对象没有属性“dim”
Posted
技术标签:
【中文标题】AttributeError:当向 Pytorch LSTM 网络提供输入时,“元组”对象没有属性“dim”【英文标题】:AttributeError: 'tuple' object has no attribute 'dim', when feeding input to Pytorch LSTM network 【发布时间】:2019-04-01 14:56:36 【问题描述】:我正在尝试运行以下代码:
import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn
class LSTM(nn.Module):
def __init__(self, input_shape, n_actions):
super(LSTM, self).__init__()
self.lstm = nn.LSTM(input_shape, 12)
self.hidden2tag = nn.Linear(12, n_actions)
def forward(self, x):
out = self.lstm(x)
out = self.hidden2tag(out)
return out
state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]
device = torch.device("cuda")
net = LSTM(5, 3).to(device)
state_v = torch.FloatTensor(state).to(device)
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())
这会返回此错误:
Traceback (most recent call last):
File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/Question***.py", line 26, in <module>
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/Question***.py", line 15, in forward
out = self.hidden2tag(out)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 55, in forward
return F.linear(input, self.weight, self.bias)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/functional.py", line 1022, in linear
if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'
有人知道如何解决这个问题吗? (摆脱作为元组的张量,以便将其输入 LSTM 网络)
【问题讨论】:
【参考方案1】:pytorch LSTM 返回一个元组。因此您会收到此错误,因为您的线性层 self.hidden2tag
无法处理此元组。
所以改变:
out = self.lstm(x)
到
out, states = self.lstm(x)
这将通过拆分元组来修复您的错误,以便 out
只是您的输出张量。
out
然后存储隐藏状态,而states
是另一个包含最后隐藏状态和单元状态的元组。
您也可以在这里查看:https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM
由于max()
也会返回一个元组,所以最后一行会出现另一个错误。但这应该很容易解决,而且是不同的错误:)
【讨论】:
【参考方案2】:首先在一个 numpy 数组中转换你的状态:
state = np.array(state)
PyTorch 的 API 中可能缺少 np.asarray
。
【讨论】:
以上是关于AttributeError:当向 Pytorch LSTM 网络提供输入时,“元组”对象没有属性“dim”的主要内容,如果未能解决你的问题,请参考以下文章
AttributeError:“str”对象在 pytorch 中没有属性“dim”
pytorch版本问题:AttributeError: 'module' object has no attribute '_rebuild_tensor_v2'(示例
Pytorch Text AttributeError:“BucketIterator”对象没有属性
Unpickling 保存的 pytorch 模型会引发 AttributeError: Can't get attribute 'Net' on <module '__main__' 尽管内联
pytorch tensorboardX可视化问题:AttributeError: 'torch._C.Value' object has no attribute 'debu
不降低PyTorch版本解决AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘