RNNLanguageModel 的 forward 方法的作用是啥?
Posted
技术标签:
【中文标题】RNNLanguageModel 的 forward 方法的作用是啥?【英文标题】:what is the role of RNNLanguageModel's forward method?RNNLanguageModel 的 forward 方法的作用是什么? 【发布时间】:2021-01-29 00:36:14 【问题描述】:我正在阅读有关使用 AllenNlp 框架的基于字符的神经网络的教程,目标是构建一个可以完成句子的模型。在我想训练我的模型之后,有一个实例构建步骤。我有下面的代码,我无法理解转发功能的作用,任何人都可以帮忙吗?有人可以举个例子吗
class RNNLanguageModel(Model):
def __init__(self,
embedder: TextFieldEmbedder,
hidden_size: int,
max_len: int,
vocab: Vocabulary) -> None:
super().__init__(vocab)
self.embedder = embedder
# initialize a Seq2Seq encoder, LSTM
self.rnn = PytorchSeq2SeqWrapper(
torch.nn.LSTM(EMBEDDING_SIZE, HIDDEN_SIZE, batch_first=True))
self.hidden2out = torch.nn.Linear(in_features=self.rnn.get_output_dim(), out_features=vocab.get_vocab_size('tokens'))
self.hidden_size = hidden_size
self.max_len = max_len
def forward(self, input_tokens, output_tokens):
'''
This is the main process of the Model where the actual computation happens.
Each Instance is fed to the forward method.
It takes dicts of tensors as input, with same keys as the fields in your Instance (input_tokens, output_tokens)
It outputs the results of predicted tokens and the evaluation metrics as a dictionary.
'''
mask = get_text_field_mask(input_tokens)
embeddings = self.embedder(input_tokens)
rnn_hidden = self.rnn(embeddings, mask)
out_logits = self.hidden2out(rnn_hidden)
loss = sequence_cross_entropy_with_logits(out_logits, output_tokens['tokens'], mask)
return 'loss': loss
【问题讨论】:
【参考方案1】:forward()
方法是我们实现模型“前向传递”的地方。这决定了输入(您的数据)如何流经您的模型以产生输出和损失值。
forward()
方法必须由继承自 PyTorch Module
的任何类实现,例如 AllenNLP 的 Model
类。
AllenNLP 最终只是 PyTorch 的一个更高级别的包装器,所以如果您对此感到困惑,我建议您先熟悉一下 PyTorch:https://pytorch.org/tutorials/
【讨论】:
以上是关于RNNLanguageModel 的 forward 方法的作用是啥?的主要内容,如果未能解决你的问题,请参考以下文章
[原创]java WEB学习笔记16:JSP指令(page,include),JSP标签(forwar,include,param)