每天讲解一点PyTorch nn.Embedding
Posted knowform
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了每天讲解一点PyTorch nn.Embedding相关的知识,希望对你有一定的参考价值。
今天我们讲解创建词嵌入模型
nn.Embedding(5, 6)
代表5个词(总共有多少个词)6维度(每个词的编码维度),5x6矩阵
>>> import torch
>>> import torch.nn as nn
>>>
>>> embeds = nn.Embedding(5, 6)
>>> embeds
Embedding(5, 6)
>>>
>>> import torch
>>> import torch.nn as nn
>>>
>>> embeds = nn.Embedding(5, 6)
>>> embeds
Embedding(5, 6)
>>> # I like cv
>>> # 1 2 3
>>> word = [[1,2,3],[0,0,3]]
>>> x = torch.LongTensor(word)
>>> x
tensor([[1, 2, 3],
[0, 0, 3]])
>>> embed_ = embeds(torch.LongTensor(word))
>>>
>>> embed_
tensor([[[ 1.1774, 0.0840, 1.2077, 0.3492, -0.1718, -0.2256],
[-1.1595, 0.2247, 0.8931, -0.6904, -0.3020, -0.4171],
[-2.2413, 0.3168, -0.1047, 0.8574, -0.2742, 0.1194]],
[[ 1.7274, 0.9688, -1.2030, 1.0693, -0.5663, 1.8714],
[ 1.7274, 0.9688, -1.2030, 1.0693, -0.5663, 1.8714],
[-2.2413, 0.3168, -0.1047, 0.8574, -0.2742, 0.1194]]],
grad_fn=<EmbeddingBackward>)
>>> embed_.shape
torch.Size([2, 3, 6])
>>>
输入维度:2x3的词
每个词编码:6维度向量
>>> word = [1,2,3]
>>> word
[1, 2, 3]
>>> embed_ = embeds(torch.LongTensor(word))
>>> embed_
tensor([[ 1.1774, 0.0840, 1.2077, 0.3492, -0.1718, -0.2256],
[-1.1595, 0.2247, 0.8931, -0.6904, -0.3020, -0.4171],
[-2.2413, 0.3168, -0.1047, 0.8574, -0.2742, 0.1194]],
grad_fn=<EmbeddingBackward>)
>>>
>>> embed_.shape
torch.Size([3, 6])
>>>
输入维度:3的词
每个词编码:6维度向量
>>> embeds = nn.Embedding(4, 6)
>>> embeds
Embedding(4, 6)
>>> word
[1, 2, 3]
>>> x = torch.LongTensor(word)
>>> x
tensor([1, 2, 3])
>>>
>>> embed_ = embeds(x)
>>> embed_
tensor([[ 1.8449, 0.3119, -0.0317, -0.7793, -0.6710, 0.0316],
[ 0.8638, -1.2257, 0.5386, -0.6412, 1.4129, 0.0926],
[ 0.5742, 1.2701, 0.3415, 0.3161, -1.7954, 0.8924]],
grad_fn=<EmbeddingBackward>)
>>>
>>> embeds = nn.Embedding(len("0123456789"), 6)
>>> embeds
Embedding(10, 6)
>>>
以上是关于每天讲解一点PyTorch nn.Embedding的主要内容,如果未能解决你的问题,请参考以下文章