关于 decoder_outputs[:,t,:] = decoder_output_t torch.topk, torch.max(),torch.argmax()的演示

Posted ZSYL

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了关于 decoder_outputs[:,t,:] = decoder_output_t torch.topk, torch.max(),torch.argmax()的演示相关的知识,希望对你有一定的参考价值。

1. decoder_outputs[:,t,:] = decoder_output_t

decoder_outputs 形状 [batch_size, seq_len, vocab_size]
decoder_output_t 形状[batch_size, vocab_size]

示例代码:

import torch
 
a = torch.zeros((2, 3, 5))
print(a.size())
print(a)
 
b = torch.randn((2, 5))
print(b.size())
print(b)
 
a[:, 0, :] = b
print(a.size())
print(a)

运行结果:

2. torch.topk, torch.max(), torch.argmax()

value, index = torch.topk(decoder_output_t , k = 1)
decoder_output_t [batch_size, vocab_size]

示例代码:

import torch
 
a = torch.randn((3, 5))
print(a.size())
print(a)
 
values, index = torch.topk(a, k=1)
print(values)
print(index)
print(index.size())
 
values, index = torch.max(a, dim=-1)
print(values)
print(index)
print(index.size())
 
index = torch.argmax(a, dim=-1)
print(index)
print(index.size())
 
index = a.argmax(dim=-1)
print(index)
print(index.size())

运行结果:

3. unsqueeze()

若使用 teacher forcing ,将采用下次真实值作为下个time step的输入

# 注意unsqueeze 相当于浅拷贝,不会对原张量进行修改
decoder_input = target[:,t].unsqueeze(-1)
target 形状 [batch_size, seq_len]
decoder_input 要求形状[batch_size, 1]

示例代码:

import torch
 
a = torch.randn((3, 5))
print(a.size())
print(a)
 
b = a[:, 3]
print(b.size())
print(b)
c = b.unsqueeze(-1)
print(c.size())
print(c)

运行结果:


参考link

以上是关于关于 decoder_outputs[:,t,:] = decoder_output_t torch.topk, torch.max(),torch.argmax()的演示的主要内容,如果未能解决你的问题,请参考以下文章

关于 Where 子句的 T-Sql 多重标准

关于tcpl习题4-14定义宏swap(t,x,y)

关于 AT&T 汇编语法 (%esp,1)

关于T/G/M/K

需要关于两种语言 S* 和 T* 的递归定义的帮助,其中 S=aa,b 和 T=w1,w2,w3,w4

关于删除树中指定节点的实例分析