A demonstration of decoder_outputs[:,t,:] = decoder_output_t, torch.topk, torch.max(), and torch.argmax()
Posted by ZSYL
A demonstration of a few special PyTorch operations that come up when writing a seq2seq decoder loop.
1. decoder_outputs[:,t,:] = decoder_output_t
decoder_outputs has shape [batch_size, seq_len, vocab_size]
decoder_output_t has shape [batch_size, vocab_size]
The slice assignment writes one decoding step's output into position t of the pre-allocated buffer.
Example code:
import torch

# Pre-allocated output buffer: [batch_size=2, seq_len=3, vocab_size=5]
a = torch.zeros((2, 3, 5))
print(a.size())
print(a)

# One decoding step's output: [batch_size=2, vocab_size=5]
b = torch.randn((2, 5))
print(b.size())
print(b)

# Write b into time step 0; a's shape is unchanged
a[:, 0, :] = b
print(a.size())
print(a)
Output: a keeps its shape torch.Size([2, 3, 5]); a[:, 0, :] now holds b's values while the other time steps remain zero.
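As a quick check (a minimal sketch reusing the tensors above, not from the original post): the slice assignment copies b's values into a's own storage, so a does not alias b afterwards.
import torch

a = torch.zeros((2, 3, 5))
b = torch.randn((2, 5))
a[:, 0, :] = b

# The assignment copied b's values into a's storage
print(torch.equal(a[:, 0, :], b))  # True

# a does not alias b: mutating b afterwards leaves a unchanged
b.fill_(7.0)
print(torch.equal(a[:, 0, :], b))  # False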
2. torch.topk, torch.max(), torch.argmax()
value, index = torch.topk(decoder_output_t, k=1)
decoder_output_t has shape [batch_size, vocab_size]. torch.topk returns the k largest values along the last dimension together with their indices; with k=1 this selects the most likely token for each sample in the batch.
Example code:
import torch

# Simulated per-step decoder output: [batch_size=3, vocab_size=5]
a = torch.randn((3, 5))
print(a.size())
print(a)

# topk keeps the reduced dimension: values and index have shape [3, 1]
values, index = torch.topk(a, k=1)
print(values)
print(index)
print(index.size())

# max along the last dim drops it: values and index have shape [3]
values, index = torch.max(a, dim=-1)
print(values)
print(index)
print(index.size())

# argmax returns only the indices, shape [3]
index = torch.argmax(a, dim=-1)
print(index)
print(index.size())

# Tensor.argmax is the method form of torch.argmax
index = a.argmax(dim=-1)
print(index)
print(index.size())
Output: torch.topk with k=1 keeps the reduced dimension, so values and index have shape torch.Size([3, 1]); torch.max and argmax along dim=-1 drop it, giving shape torch.Size([3]). The indices agree across all three.
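This shape difference matters when the prediction is fed back as the next decoder input. A minimal sketch (decoder_output_t and decoder_input follow this post's naming; the sizes are made up):
import torch

decoder_output_t = torch.randn((4, 10))  # hypothetical [batch_size=4, vocab_size=10]

# topk keeps the k dimension, so index is already [batch_size, 1]
value, index = torch.topk(decoder_output_t, k=1)
decoder_input = index

# argmax drops the reduced dimension, so unsqueeze it back to [batch_size, 1]
decoder_input_2 = decoder_output_t.argmax(dim=-1).unsqueeze(-1)

print(torch.equal(decoder_input, decoder_input_2))  # True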
3. unsqueeze()
When teacher forcing is used, the ground-truth token is taken as the input for the next time step:
# Note: unsqueeze returns a view of the original tensor; it does not modify the tensor in place
decoder_input = target[:, t].unsqueeze(-1)
target has shape [batch_size, seq_len]; decoder_input must have shape [batch_size, 1].
Example code:
import torch

# A tensor standing in for target: [batch_size=3, seq_len=5]
a = torch.randn((3, 5))
print(a.size())
print(a)

# Indexing a single time step drops that dimension: b has shape [3]
b = a[:, 3]
print(b.size())
print(b)

# unsqueeze(-1) adds a trailing dimension back: c has shape [3, 1]
c = b.unsqueeze(-1)
print(c.size())
print(c)
Output: b has shape torch.Size([3]); after unsqueeze(-1), c has shape torch.Size([3, 1]).
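Putting the three pieces together: below is a minimal, hypothetical decoder loop (the embedding, linear layer, and all sizes are stand-ins invented for this sketch, not the original post's model) showing where each operation lands.
import torch

batch_size, seq_len, vocab_size, hidden = 2, 4, 10, 8
use_teacher_forcing = True

# Hypothetical stand-ins for a real decoder and its training targets
embed = torch.nn.Embedding(vocab_size, hidden)
fc = torch.nn.Linear(hidden, vocab_size)
target = torch.randint(0, vocab_size, (batch_size, seq_len))

# 1. Pre-allocate the output buffer: [batch_size, seq_len, vocab_size]
decoder_outputs = torch.zeros(batch_size, seq_len, vocab_size)
decoder_input = torch.zeros(batch_size, 1, dtype=torch.long)  # e.g. an SOS token

for t in range(seq_len):
    step_in = embed(decoder_input).squeeze(1)  # [batch_size, hidden]
    decoder_output_t = fc(step_in)             # [batch_size, vocab_size]

    # 1. Store this step's output at position t
    decoder_outputs[:, t, :] = decoder_output_t

    if use_teacher_forcing:
        # 3. Feed the ground truth: [batch_size] -> [batch_size, 1]
        decoder_input = target[:, t].unsqueeze(-1)
    else:
        # 2. Feed the model's own best guess; index is already [batch_size, 1]
        value, index = torch.topk(decoder_output_t, k=1)
        decoder_input = index

print(decoder_outputs.size())  # torch.Size([2, 4, 10])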