有没有在 nn.LSTMCell 上使用 torch 0.4.0 的 nn.LayerNorm 的示例?
Posted
技术标签:
【中文标题】nn.LSTMCell 的火炬 0.4.0 nn.LayerNorm 示例的任何示例?【英文标题】:Any example of torch 0.4.0 nn.LayerNorm example for nn.LSTMCell? 【发布时间】:2018-10-13 07:01:53 【问题描述】:在 pytorch 0.4.0 版本中,有一个 nn.LayerNorm 模块。
我想在我的 LSTM 网络中实现这一层,虽然我在 LSTM 网络上找不到任何实现示例。
pytorch 贡献者暗示这个 nn.LayerNorm 仅适用于 nn.LSTMCell。
如果我能获得任何在 nn.LSTMCell 或任何 torch LSTM 网络上实现 nn.LayerNorm 的 git repo 或代码,那将是一个很大的帮助。
提前致谢
【问题讨论】:
【参考方案1】:我也在寻找解决方案。这是来自https://github.com/pytorch/pytorch/issues/11335的示例 感谢@jinserk
class LayerNormLSTMCell(nn.LSTMCell):
    """LSTM cell with layer normalization applied inside the recurrence.

    Three ``nn.LayerNorm`` modules are added on top of the parameters
    inherited from ``nn.LSTMCell``:

    * ``ln_ih`` normalizes the input-to-hidden projection (``4 * hidden_size``),
    * ``ln_hh`` normalizes the hidden-to-hidden projection (``4 * hidden_size``),
    * ``ln_ho`` normalizes the new cell state before the output ``tanh``.

    NOTE: the ``4 * hidden_size`` gate axis is sliced here as
    ``[input, forget, output | candidate]``; weights are therefore laid out
    according to this slicing, which is specific to this class.

    Args:
        input_size: number of features in each input row.
        hidden_size: number of features in the hidden and cell states.
        bias: forwarded to ``nn.LSTMCell``; when ``False`` the linear
            projections are computed without bias terms.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__(input_size, hidden_size, bias)
        self.ln_ih = nn.LayerNorm(4 * hidden_size)
        self.ln_hh = nn.LayerNorm(4 * hidden_size)
        self.ln_ho = nn.LayerNorm(hidden_size)

    def _check_input(self, input):
        """Raise ``RuntimeError`` if ``input``'s feature size is not ``input_size``."""
        if input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    input.size(1), self.input_size))

    def _check_hidden(self, input, hx, hidden_label=''):
        """Raise ``RuntimeError`` if state ``hx`` does not match ``input``/``hidden_size``."""
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input.size(0), hidden_label, hx.size(0)))
        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hx.size(1), self.hidden_size))

    def forward(self, input, hidden=None):
        """Advance the cell by one time step.

        Args:
            input: 2-D tensor indexed as ``(batch, input_size)``.
            hidden: optional ``(hx, cx)`` pair, each ``(batch, hidden_size)``.
                When ``None`` both states start as zeros created on the same
                device/dtype as ``input``.

        Returns:
            ``(hy, cy)``: the new hidden state and the new cell state.

        Raises:
            RuntimeError: if the feature or batch sizes are inconsistent.
        """
        # Shape validation is done inline instead of through the inherited
        # check_forward_input / check_forward_hidden helpers, because those
        # helpers are not available on RNNCellBase in every PyTorch release.
        self._check_input(input)
        if hidden is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            cx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
        else:
            hx, cx = hidden
            self._check_hidden(input, hx, '[0]')
            self._check_hidden(input, cx, '[1]')

        # Normalize each projection separately, then sum the pre-activations.
        gates = self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih)) \
            + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh))

        # First 3*H columns are the sigmoid gates, the remaining H columns
        # are the tanh candidate.
        i, f, o = gates[:, :(3 * self.hidden_size)].sigmoid().chunk(3, 1)
        g = gates[:, (3 * self.hidden_size):].tanh()

        cy = (f * cx) + (i * g)
        hy = o * self.ln_ho(cy).tanh()
        return hy, cy
class LayerNormLSTM(nn.Module):
    """Multi-layer, optionally bidirectional LSTM built from ``LayerNormLSTMCell``.

    ``hidden0`` holds one cell per layer for the forward direction and, when
    ``bidirectional`` is true, ``hidden1`` holds one cell per layer for the
    backward direction. Layers above the first consume the concatenated
    outputs of both directions of the layer below.

    Args:
        input_size: feature size of the input sequence.
        hidden_size: feature size of every hidden/cell state.
        num_layers: number of stacked layers.
        bias: forwarded to every ``LayerNormLSTMCell``.
        bidirectional: when true, also run each layer back-to-front.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, bidirectional=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        num_directions = 2 if bidirectional else 1
        self.hidden0 = nn.ModuleList([
            LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                              hidden_size=hidden_size, bias=bias)
            for layer in range(num_layers)
        ])

        if self.bidirectional:
            self.hidden1 = nn.ModuleList([
                LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                                  hidden_size=hidden_size, bias=bias)
                for layer in range(num_layers)
            ])

    @staticmethod
    def _run_cell(cell, steps, h, c):
        """Feed ``steps`` through ``cell`` in order; return per-step outputs and final state."""
        outputs = []
        for x in steps:
            h, c = cell(x, (h, c))
            outputs.append(h)
        return outputs, h, c

    def forward(self, input, hidden=None):
        """Run the whole sequence through all layers.

        Args:
            input: 3-D tensor indexed as ``(seq_len, batch, features)``;
                only this time-major layout is supported.
            hidden: optional ``(hx, cx)`` pair indexed by
                ``layer * num_directions + direction`` on the first axis.
                Zeros are used when ``None``.

        Returns:
            ``(y, (hy, cy))`` where ``y`` stacks the top layer's output for
            every time step (both directions concatenated on the feature
            axis when bidirectional), and ``hy``/``cy`` stack the final
            hidden/cell state of every layer and direction. For the backward
            direction the final state is the one produced at time step 0.
        """
        seq_len, batch_size, _ = input.size()  # supports TxNxH only
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            hx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False)
            cx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False)
        else:
            hx, cx = hidden

        # Per-step outputs and final states are collected in fresh lists.
        # (A nested list built with `[[None] * k] * seq_len` would make every
        # time step share one inner list, so all steps would report the
        # same, last-written state.)
        final_h = []
        final_c = []
        xs = input

        if self.bidirectional:
            for l, (layer0, layer1) in enumerate(zip(self.hidden0, self.hidden1)):
                l0, l1 = 2 * l, 2 * l + 1
                fwd, h0, c0 = self._run_cell(layer0, xs, hx[l0], cx[l0])
                # The backward cell consumes the sequence back-to-front; its
                # outputs are flipped back so index t lines up with time t.
                bwd, h1, c1 = self._run_cell(layer1, reversed(xs), hx[l1], cx[l1])
                bwd.reverse()
                xs = [torch.cat((hf, hb), dim=1) for hf, hb in zip(fwd, bwd)]
                final_h.extend((h0, h1))
                final_c.extend((c0, c1))
        else:
            for l, layer in enumerate(self.hidden0):
                xs, h, c = self._run_cell(layer, xs, hx[l], cx[l])
                final_h.append(h)
                final_c.append(c)

        y = torch.stack(xs)
        hy = torch.stack(final_h)
        cy = torch.stack(final_c)
        return y, (hy, cy)
【讨论】:
以上是关于nn.LSTMCell 的火炬 0.4.0 nn.LayerNorm 示例的任何示例?的主要内容,如果未能解决你的问题,请参考以下文章