Ran out of RAM while training LSTM
【Posted】: 2020-03-08 14:26:29
【Question】: I'm a beginner with RNNs, so I wrote an LSTM architecture using PyTorch, but whenever I reach the third epoch I run out of RAM. I'm already using a DataLoader, and I tried detaching the gradients from the input tensors, but that didn't solve the problem.
Here is my training loop:
writer = SummaryWriter()
criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
optimizer = optim.Adam(lstm.parameters(), lr=1e-5)
gradient_clip = clip_grad_norm_(lstm.parameters(), max_norm=5)

num_epochs = 20
epoch_loss = -1.0
loss = -1

t = trange(num_epochs, desc="Epoch loss", leave=True)
for epoch in t:
    trainLoader = iter(DataLoader(dataset, batch_size=batch_size))
    tt = trange(len(trainLoader) - 1, desc="Batch loss", leave=True)
    for i in tt:
        text, embedding = next(trainLoader)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y = lstm.forward(embedding.transpose(1, 0))
        labels = text.transpose(0, 1)[1:].transpose(0, 1).flatten()
        loss = criterion(y.reshape(-1, y.shape[-1]), labels)
        tt.set_description("Batch loss : %.4f" % loss)
        tt.refresh()
        loss.backward(retain_graph=True)
        optimizer.step()
        epoch_loss += loss

    epoch_loss = epoch_loss / (len(trainLoader) - 1)

    # Saving model
    save_date = datetime.now().strftime("%d%m%Y-%H:%M:%S")
    PATH = './save/lstm_model_' + save_date
    torch.save(lstm, PATH)

    # Updating progression bar
    t.set_description("Epoch loss : %.4f" % epoch_loss)
    t.refresh()

    # Plotting gradients histograms in Tensorboard
    writer.add_scalar('Text_generation_Loss/train', epoch_loss, epoch)
    for tag, parm in lstm.named_parameters():
        with torch.no_grad():
            writer.add_histogram(tag, parm.grad.data.cpu().numpy(), epoch)
    writer.flush()

print('Finished Training')
writer.close()
And here is the LSTM class I built:
class LSTM(nn.Module):
    def __init__(self, in_size: int, hidden_size: int):
        super().__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        # forget gate
        self.W_fi = nn.Linear(in_size, hidden_size)
        self.W_fh = nn.Linear(hidden_size, hidden_size, bias=False)
        # input gate
        self.W_ii = nn.Linear(in_size, hidden_size)
        self.W_ih = nn.Linear(hidden_size, hidden_size, bias=False)
        # candidate cell state
        self.W_Ci = nn.Linear(in_size, hidden_size)
        self.W_Ch = nn.Linear(hidden_size, hidden_size, bias=False)
        # output gate
        self.W_oi = nn.Linear(in_size, hidden_size)
        self.W_oh = nn.Linear(hidden_size, hidden_size, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def one_step(self, x, h, C):
        # one LSTM time step: gates, cell state update, new hidden state
        f_t = self.sigmoid(self.W_fi(x) + self.W_fh(h))
        i_t = self.sigmoid(self.W_ii(x) + self.W_ih(h))
        g_t = self.tanh(self.W_Ci(x) + self.W_Ch(h))
        C_t = torch.mul(f_t, C) + torch.mul(i_t, g_t)
        o_t = self.sigmoid(self.W_oi(x) + self.W_oh(h))
        h_t = torch.mul(o_t, self.tanh(C_t))
        return h_t, C_t

    def forward(self, X):
        # X: (seq_len, batch, in_size); returns the stacked hidden states
        h_out = []
        h = -torch.ones(X.shape[1], self.hidden_size)
        C = -torch.ones(X.shape[1], self.hidden_size)
        h_t, C_t = self.one_step(X[0], h, C)
        h_out.append(h_t)
        for i in range(1, X.shape[0] - 1):
            h_t, C_t = self.one_step(X[i], h_t, C_t)
            h_out.append(h_t)
        h_out = torch.cat(h_out)
        return h_out  # h_out.reshape(-1, batch_size, num_embeddings)
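For clarity, here is roughly how I call it; the sizes below are just placeholder examples, not my real dimensions:

import torch

lstm = LSTM(in_size=128, hidden_size=256)   # placeholder sizes
X = torch.randn(10, 32, 128)                # (seq_len, batch, in_size)
h_out = lstm(X)                             # hidden states of steps 0 .. seq_len - 2, stacked
print(h_out.shape)                          # torch.Size([288, 256]), i.e. ((seq_len - 1) * batch, hidden_size)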
I've already searched for similar cases, but I couldn't find a solution.
【Comments】:
Are you asking us to help you get more memory? ... No, I'm just asking whether I'm doing something wrong in PyTorch, since I'm still a beginner.

【Answer 1】: I don't know whether this will help anyone, but I solved the problem. I may not have made the task clear: the goal is text generation. Originally I was embedding the sentences with a torch.nn.Embedding defined outside of my LSTM. The solution was to make it a layer of my network, because the embedding is not pretrained and should be learned as well.
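Concretely, the change looks roughly like this; the class name, sizes, and the output projection below are placeholders for illustration, not my actual code:

import torch.nn as nn

class GeneratorLSTM(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, hidden_size: int):
        super().__init__()
        # The embedding is now a layer of the model, so it appears in
        # model.parameters() and is trained together with the LSTM weights.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = LSTM(embedding_dim, hidden_size)    # the custom LSTM class above
        self.proj = nn.Linear(hidden_size, vocab_size)  # assumed output projection

    def forward(self, tokens):
        # tokens: (seq_len, batch) tensor of token indices
        x = self.embedding(tokens)   # (seq_len, batch, embedding_dim)
        h = self.lstm(x)             # hidden states from the custom LSTM
        return self.proj(h)          # logits over the vocabulary

With this, a single optimizer over model.parameters() also updates the embedding.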
【Discussion】: