每天讲解一点PyTorch nn.Embedding

Posted cv.exp

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了每天讲解一点PyTorch nn.Embedding相关的知识,希望对你有一定的参考价值。

今天我们讲解创建词嵌入模型
nn.Embedding(5, 6)
代表5个词(总共有多少个词)6维度(每个词的编码维度),5x6矩阵

>>> import torch
>>> import torch.nn as nn
>>> 
>>> embeds = nn.Embedding(5, 6)
>>> embeds
Embedding(5, 6)
>>> 
>>> import torch
>>> import torch.nn as nn
>>> 
>>> embeds = nn.Embedding(5, 6)
>>> embeds
Embedding(5, 6)
>>> # I like cv
>>> # 1 2 3
>>> word = [[1,2,3],[0,0,3]]
>>> x = torch.LongTensor(word)
>>> x
tensor([[1, 2, 3],
        [0, 0, 3]])

>>> embed_ = embeds(torch.LongTensor(word))
>>> 
>>> embed_
tensor([[[ 1.1774,  0.0840,  1.2077,  0.3492, -0.1718, -0.2256],
         [-1.1595,  0.2247,  0.8931, -0.6904, -0.3020, -0.4171],
         [-2.2413,  0.3168, -0.1047,  0.8574, -0.2742,  0.1194]],

        [[ 1.7274,  0.9688, -1.2030,  1.0693, -0.5663,  1.8714],
         [ 1.7274,  0.9688, -1.2030,  1.0693, -0.5663,  1.8714],
         [-2.2413,  0.3168, -0.1047,  0.8574, -0.2742,  0.1194]]],
       grad_fn=<EmbeddingBackward>)
>>> embed_.shape
torch.Size([2, 3, 6])
>>> 

输入维度:2x3的词
每个词编码:6维度向量

>>> word = [1,2,3]
>>> word
[1, 2, 3]
>>> embed_ = embeds(torch.LongTensor(word))
>>> embed_
tensor([[ 1.1774,  0.0840,  1.2077,  0.3492, -0.1718, -0.2256],
        [-1.1595,  0.2247,  0.8931, -0.6904, -0.3020, -0.4171],
        [-2.2413,  0.3168, -0.1047,  0.8574, -0.2742,  0.1194]],
       grad_fn=<EmbeddingBackward>)
>>> 
>>> embed_.shape
torch.Size([3, 6])
>>> 

输入维度:3的词
每个词编码:6维度向量

>>> embeds = nn.Embedding(4, 6)
>>> embeds
Embedding(4, 6)
>>> word
[1, 2, 3]
>>> x = torch.LongTensor(word)
>>> x
tensor([1, 2, 3])
>>> 
>>> embed_ = embeds(x)
>>> embed_
tensor([[ 1.8449,  0.3119, -0.0317, -0.7793, -0.6710,  0.0316],
        [ 0.8638, -1.2257,  0.5386, -0.6412,  1.4129,  0.0926],
        [ 0.5742,  1.2701,  0.3415,  0.3161, -1.7954,  0.8924]],
       grad_fn=<EmbeddingBackward>)
>>> 
>>> embeds = nn.Embedding(len("0123456789"), 6)
>>> embeds
Embedding(10, 6)
>>> 

以上是关于每天讲解一点PyTorch nn.Embedding的主要内容,如果未能解决你的问题,请参考以下文章

每天讲解一点PyTorch F.softmax

每天讲解一点PyTorch torch.matmul

每天讲解一点PyTorch isinstance

每天讲解一点PyTorch isinstance

每天讲解一点PyTorch 12enumerate

每天讲解一点PyTorch 12enumerate