pytorch 中的 torch.nn.gru 函数的输入是啥?

Posted

技术标签:

【中文标题】pytorch 中的 torch.nn.gru 函数的输入是啥?【英文标题】:what is the inputs to a torch.nn.gru function in pytorch?pytorch 中的 torch.nn.gru 函数的输入是什么? 【发布时间】:2020-03-23 22:15:07 【问题描述】:

我正在使用 gru 函数来实现 RNN。这个 RNN (GRU) 在一些 CNN 层之后使用。有人可以告诉我这里 GRU 函数的输入是什么吗?特别是,隐藏的大小是固定的吗?

self.gru = torch.nn.GRU(
            input_size=input_size,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            bidirectional=True) 

根据我的理解,输入大小将是特征的数量,而 GRU 的隐藏大小总是固定为 128?有人可以纠正我。或提供他们的反馈

【问题讨论】:

【参考方案1】:

首先,GRU 不是一个函数,而是一个类,您正在调用它的构造函数。你在这里创建了一个 GRU 类的实例,它是一个层(或 pytorch 中的 Module)。

input_size 必须与前一个 CNN 层的out_channels 匹配。

您看到的所有参数都不是固定的。只需在此处输入另一个值,它将是其他值,即将 128 替换为您喜欢的任何值。

即使它被称为hidden_size,对于 GRU,此参数也决定了输出特征。换句话说,如果您在 GRU 之后还有另一层,则该层的 input_size(或 in_featuresin_channels 或其他任何名称)必须与 GRU 的 hidden_size 匹配。

另外,看看documentation。这会准确地告诉您传递给构造函数的参数的用途。此外,它会告诉您在实际使用层后(通过self.gru(...))的预期输入是什么,以及该调用的输出是什么。

【讨论】:

以上是关于pytorch 中的 torch.nn.gru 函数的输入是啥?的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch系列-54]:循环神经网络 - torch.nn.GRU()参数详解

PyTorch笔记 - GRU(Gated Recurrent Unit)

PyTorch笔记 - GRU(Gated Recurrent Unit)网络结构

PyTorch笔记 - GRU(Gated Recurrent Unit)网络结构

pytorch中的顺序容器——torch.nn.Sequential

Pytorch模型量化实践并以ResNet18模型量化为例(附代码)