[PyTorch Series - 56]: Recurrent Neural Networks - Defining and Reading the word2vec Word Vector Table (Embedding/GloVe)
Posted by 文火冰糖的硅基工坊
Author's homepage (文火冰糖的硅基工坊): 文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
Article URL: https://blog.csdn.net/HiWangWenBing/article/details/121707481
Chapter 1: Overview of Embedding (Word Embedding)
In PyTorch, an Embedding is essentially a word vector table made up of n words.
By instantiating this table, the word vectors can be both trained and read back.
In PyTorch, the word vector table is implemented by the torch.nn.Embedding class.
Chapter 2: How to Use torch.nn.Embedding
(1)Class initialization
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
- num_embeddings: size of the embedding dictionary (the number of words in the vocabulary);
- embedding_dim: the size (dimension) of each embedding vector;
- padding_idx: if given, the embedding vector at padding_idx (i.e. the vector that padding_idx maps to) is the zero vector and is not updated during training (see the sketch after this list);
- max_norm: if given, each embedding vector whose norm exceeds max_norm is renormalized to have norm max_norm;
- norm_type: the p of the p-norm used in the max_norm computation (default 2.0);
- scale_grad_by_freq: if given, gradients are scaled by the inverse frequency of the words in the mini-batch;
- sparse: if True, the gradient with respect to the weight matrix will be a sparse tensor.
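To make padding_idx concrete, here is a minimal sketch; the table size, dimension, and index values below are arbitrary choices for illustration:
import torch
import torch.nn as nn

# A 6-word table with 4-dimensional vectors; index 0 is reserved for padding.
embed = nn.Embedding(num_embeddings=6, embedding_dim=4, padding_idx=0)

print(embed(torch.tensor([0])))  # the padding row: all zeros, never updated by gradients
print(embed(torch.tensor([3])))  # an ordinary, randomly initialized row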
(2)Instantiating the class (creating an object)
embedding = nn.Embedding(10, 300)
This instantiates a word vector table, where 10 is the total number of words and 300 is the dimension (length) of each word's vector.
(3)Reading word vectors from the Embedding table
- Input: an index tensor of arbitrary shape containing the indices to extract; several word indices can be looked up at once.
- Output: the word vectors for those indices, with shape (*, H), where * is the shape of the input index tensor and H = embedding_dim (if the input shape is N×M, the output shape is N×M×H); a sketch follows this list.
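A minimal sketch of this shape rule (the sizes here are arbitrary): a 2×3 index tensor looked up in a table with embedding_dim=5 yields a 2×3×5 output.
import torch
import torch.nn as nn

embed = nn.Embedding(10, 5)          # 10 words, 5 dimensions each
indices = torch.tensor([[0, 1, 2],
                        [3, 4, 5]])  # input shape (N, M) = (2, 3)
out = embed(indices)
print(out.shape)                     # torch.Size([2, 3, 5]), i.e. (N, M, H)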
Chapter 3: torch.nn.Embedding Code Examples
3.1 Prerequisites
# Environment setup
import numpy as np                # numpy array library
import math                       # math library
import matplotlib.pyplot as plt   # plotting library
import time as time
import torch                      # torch base library
import torch.nn as nn             # torch neural network library
import torch.nn.functional as F
import torchnlp
from torchnlp.word_to_vector import GloVe
print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)
print(torch.backends.cudnn.version())
Hello World
1.10.0
True
10.2
7605
3.2 Defining a Word-to-Index Mapping (a Dictionary)
# Define a dictionary mapping word strings to integer indices
word_to_index = {"hello": 0, "world": 1, "I": 2, "Love": 3, "you": 4}
print(word_to_index)
# Get the index corresponding to a word string
index_hello = word_to_index["hello"]
index_world = word_to_index["world"]
print(index_hello)
print(index_world)
{'hello': 0, 'world': 1, 'I': 2, 'Love': 3, 'you': 4}
0
1
3.3 Defining a Word Vector Table with Default (Random) Initialization
# Define a two-dimensional 10×30 word vector table (10 words, 30 dimensions each)
# A different random table is generated on every run
embed_code = nn.Embedding(10,30)
print(embed_code)
Embedding(10, 30)
Note: the word vector table defined above has not been trained yet; it is filled with random initial values.
It still has to be trained on text data from a particular domain before it yields a word vector table that is useful for that domain.
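Under the hood the table is just a learnable weight matrix, embed_code.weight (an nn.Parameter), so it is trained like any other layer. Below is a minimal, hypothetical sketch of one update step reusing the embed_code table defined above; the mean-squared-error loss and all-zero target are placeholders for illustration, not a real word2vec training objective:
import torch
import torch.nn.functional as F

optimizer = torch.optim.SGD(embed_code.parameters(), lr=0.1)

indices = torch.tensor([0, 1], dtype=torch.long)
target = torch.zeros(2, 30)        # placeholder target, for illustration only

loss = F.mse_loss(embed_code(indices), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()                   # only rows 0 and 1 of embed_code.weight change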
3.4 Reading n Word Vectors from the Instantiated Table
(1)Reading a word vector directly by index
# Create a tensor holding the word index
lookup_tensor = torch.tensor([0], dtype=torch.long)
print(lookup_tensor)
# Look up the word vectors for the given index tensor
# Each word is identified by a single index
# Each word maps to a multi-dimensional word vector
hello_embed = embed_code(lookup_tensor)
print(hello_embed)
tensor([0])
tensor([[ 0.2982, -1.0843, 0.7222, 0.5297, 0.9194, 1.9022, -0.0898, -1.4350, -0.4850, -1.2868, 0.3609, -0.1462, 0.9040, -0.2600, -1.1349, -0.6525, -2.7043, -0.7888, 0.2020, 1.7826, 0.7497, -1.4442, 0.1443, 2.7103, -1.0916, -0.7368, -0.6998, 0.7912, -0.6554, 0.1601]], grad_fn=<EmbeddingBackward0>)
Note: word vectors are looked up with a tensor, so the index must first be converted to a tensor before reading its word vector.
(2)Reading word vectors from a list of indices
# Create a tensor holding the word indices
lookup_tensor = torch.tensor([1, 0], dtype=torch.long)
print(lookup_tensor)
# Look up the word vectors for the given index tensor
# Each word is identified by a single index
# Each word maps to a multi-dimensional word vector
hello_embed = embed_code(lookup_tensor)
print(hello_embed)
tensor([1, 0])
tensor([[-0.2788, -0.4159, -0.7487, -1.2894, 0.5326, -0.1796, 1.5914, -1.0262, -1.0611, -0.6709, 0.4998, -0.5366, 0.3190, 0.2558, 0.2866, 0.3740, 1.4530, 0.6485, -0.4284, -0.4048, 1.0075, 0.1939, -1.3701, -1.2756, -0.4662, -0.0618, 0.5167, 1.3689, -0.3640, 1.1894], [ 0.2982, -1.0843, 0.7222, 0.5297, 0.9194, 1.9022, -0.0898, -1.4350, -0.4850, -1.2868, 0.3609, -0.1462, 0.9040, -0.2600, -1.1349, -0.6525, -2.7043, -0.7888, 0.2020, 1.7826, 0.7497, -1.4442, 0.1443, 2.7103, -1.0916, -0.7368, -0.6998, 0.7912, -0.6554, 0.1601]], grad_fn=<EmbeddingBackward0>)
(3)Reading a word vector for a single word
lookup_tensor = torch.tensor([index_hello], dtype=torch.long)
print(lookup_tensor)
# Look up the word vectors for the given index tensor
# Each word is identified by a single index
# Each word maps to a multi-dimensional word vector
hello_embed = embed_code(lookup_tensor)
print(hello_embed)
tensor([0])
tensor([[ 0.2982, -1.0843, 0.7222, 0.5297, 0.9194, 1.9022, -0.0898, -1.4350, -0.4850, -1.2868, 0.3609, -0.1462, 0.9040, -0.2600, -1.1349, -0.6525, -2.7043, -0.7888, 0.2020, 1.7826, 0.7497, -1.4442, 0.1443, 2.7103, -1.0916, -0.7368, -0.6998, 0.7912, -0.6554, 0.1601]], grad_fn=<EmbeddingBackward0>)
To read a word vector from a word string, first convert the word to its index, then wrap that index in a tensor.
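Both steps can be wrapped in a small helper. A minimal sketch (word_vector is a hypothetical helper name) reusing the word_to_index dictionary and embed_code table defined above:
def word_vector(word):
    """Look up the word vector for a word string via its index."""
    index = word_to_index[word]                                  # word -> index
    return embed_code(torch.tensor([index], dtype=torch.long))  # index -> vector

print(word_vector("hello").shape)  # torch.Size([1, 30])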
(4)Reading multiple word vectors for multiple words
print("index_hello=", index_hello)
print("index_world=", index_world)
lookup_tensor = torch.tensor([index_hello, index_world], dtype=torch.long)
print(lookup_tensor)
# Look up the word vectors for the given index tensor
# Each word is identified by a single index
# Each word maps to a multi-dimensional word vector
hello_embed = embed_code(lookup_tensor)
print(hello_embed)
index_hello= 0
index_world= 1
tensor([0, 1])
tensor([[ 0.2982, -1.0843, 0.7222, 0.5297, 0.9194, 1.9022, -0.0898, -1.4350, -0.4850, -1.2868, 0.3609, -0.1462, 0.9040, -0.2600, -1.1349, -0.6525, -2.7043, -0.7888, 0.2020, 1.7826, 0.7497, -1.4442, 0.1443, 2.7103, -1.0916, -0.7368, -0.6998, 0.7912, -0.6554, 0.1601], [-0.2788, -0.4159, -0.7487, -1.2894, 0.5326, -0.1796, 1.5914, -1.0262, -1.0611, -0.6709, 0.4998, -0.5366, 0.3190, 0.2558, 0.2866, 0.3740, 1.4530, 0.6485, -0.4284, -0.4048, 1.0075, 0.1939, -1.3701, -1.2756, -0.4662, -0.0618, 0.5167, 1.3689, -0.3640, 1.1894]], grad_fn=<EmbeddingBackward0>)
Chapter 4: GloVe Pretrained Word Vector Code Examples
4.1 Overview
The nn.Embedding(10, 30) table defined above was initialized with random values. Such random vectors carry no information about the cosine distance between particular words or the relatedness of different words.
That table would still have to be trained on text for a particular application before it meets that application's needs.
GloVe() is different: it provides predefined word vectors, a dataset that others have already trained on a large text corpus.
GloVe is a word vector encoding scheme; its vector file is about 2.18 GB and is normally used as-is rather than modified.
Note that word vectors encode whole words, not individual letters.
4.2 Instantiating a GloVe Word Vector Store
vectors = torchnlp.word_to_vector.GloVe()
print(vectors)
glove.840B.300d.txt
Note: in glove.840B.300d, "840B" means the vectors were trained on a corpus of roughly 840 billion tokens, and "300d" means each word vector has 300 dimensions (the vocabulary contains about 2.2 million words).
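These pretrained vectors can also be loaded into an nn.Embedding layer via nn.Embedding.from_pretrained, so they can be used inside a model in place of a randomly initialized table. A minimal sketch for the five-word vocabulary defined earlier, using the per-word lookup demonstrated in 4.3 below (freeze=True keeps the pretrained vectors fixed during training):
import torch
import torch.nn as nn

words = ["hello", "world", "I", "Love", "you"]     # our small vocabulary
weight = torch.stack([vectors[w] for w in words])  # shape (5, 300)

glove_embed = nn.Embedding.from_pretrained(weight, freeze=True)
print(glove_embed(torch.tensor([0])))              # the GloVe vector for "hello"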
4.3 Extracting the Word Vector of an Arbitrary Word
vector = vectors["hello"]
print(vector.shape)
print(vector)
torch.Size([300]) tensor([ 0.2523, 0.1018, -0.6748, 0.2112, 0.4349, 0.1654, 0.4826, -0.8122, 0.0413, 0.7850, -0.0779, -0.6632, 0.1464, -0.2929, -0.2549, 0.0193, -0.2026, 0.9823, 0.0283, -0.0813, -0.1214, 0.1313, -0.1765, 0.1356, -0.1636, -0.2257, 0.0550, -0.2031, 0.2072, 0.0958, 0.2248, 0.2154, -0.3298, -0.1224, -0.4003, -0.0794, -0.1996, -0.0151, -0.0791, -0.1813, 0.2068, -0.3620, -0.3074, -0.2442, -0.2311, 0.0980, 0.1463, -0.0627, 0.4293, -0.0780, -0.1963, 0.6509, -0.2281, -0.3031, -0.1248, -0.1757, -0.1465, 0.1536, -0.2952, 0.1510, -0.5173, -0.0336, -0.2311, -0.7833, 0.0180, -0.1572, 0.0229, 0.4964, 0.0292, 0.0567, 0.1462, -0.1919, 0.1624, 0.2390, 0.3643, 0.4526, 0.2456, 0.2380, 0.3140, 0.3487, -0.0358, 0.5611, -0.2535, 0.0520, -0.1062, -0.3096, 1.0585, -0.4202, 0.1822, -0.1126, 0.4058, 0.1178, -0.1971, -0.0753, 0.0807, -0.0278, -0.1562, -0.4468, -0.1516, 0.1692, 0.0983, -0.0319, 0.0871, 0.2608, 0.0027, 0.1319, 0.3444, -0.3789, -0.4114, 0.0816, -0.1167, -0.4371, 0.0111, 0.0994, 0.2661, 0.4002, 0.1890, -0.1844, -0.3036, -0.2725, 0.2247, -0.4061, 0.1562, -0.1604, 0.4715, 0.0080, 0.5686, 0.2193, -0.1118, 0.7993, 0.1071, -0.5015, 0.0636, 0.0695, 0.1529, -0.2747, -0.2099, 0.2074, -0.1068, 0.4065, -2.6438, -0.3114, -0.3216, -0.2646, -0.3562, 0.0700, -0.1884, 0.4877, -0.2617, -0.0208, 0.1782, 0.1576, -0.1375, 0.0565, 0.3077, -0.0661, 0.4748, -0.2734, 0.0973, -0.2083, 0.0039, 0.3460, -0.0870, -0.5492, -0.1876, -0.1717, 0.0603, -0.1352, 0.1042, 0.3016, 0.0580, 0.2187, -0.0736, -0.2042, -0.2528, -0.1047, -0.3216, 0.1252, -0.3128, 0.0097, -0.2678, -0.6112, -0.1109, -0.1365, 0.0351, -0.4939, 0.0849, -0.1549, -0.0635, -0.2394, 0.2827, 0.1085, -0.3365, -0.6076, 0.3858, -0.0095, 0.1750, -0.5272, 0.6221, 0.1954, -0.4898, 0.0366, -0.1280, -0.0168, 0.2565, -0.3170, 0.4826, -0.1418, 0.1105, -0.3098, -0.6314, -0.3727, 0.2318, -0.1427, -0.0234, 0.0223, -0.0447, -0.1640, -0.2585, 0.1629, 0.0248, 0.2335, 0.2793, 0.3900, -0.0590, 0.1135, 0.1567, 0.1858, -0.1981, -0.4812, -0.0351, 0.0785, -0.4983, 0.1085, -0.2013, 0.0529, -0.1158, -0.1601, 0.1677, 0.4236, -0.2311, 0.0825, 0.2430, -0.1679, 0.0080, 0.0859, 0.3803, 0.0730, 0.1633, 0.2470, -0.1109, 0.1512, -0.2207, -0.0619, -0.0371, -0.0879, -0.2318, 0.1504, -0.1909, -0.1911, -0.1189, 0.0949, -0.0043, 0.1536, -0.4120, -0.3073, 0.1838, 0.4021, -0.0035, -0.1092, -0.6952, 0.1016, -0.0793, 0.4033, 0.2228, -0.1937, -0.1331, 0.0732, 0.0998, 0.1169, -0.2164, -0.1108, 0.1034, 0.0973, 0.1120, -0.3894, -0.0089, 0.2881, -0.1079, 0.0288, 0.3255, 0.2605, -0.0389, 0.0752, 0.4603, -0.0629, 0.2166, 0.1787, -0.5192, 0.3359])
vector = vectors["hi"]
print(vector.shape)
print(vector)
torch.Size([300]) tensor([ 2.8796e-02, 4.1306e-01, -4.6690e-01, -7.8175e-02, 3.7058e-01, 1.2867e-01, 4.7714e-01, -9.2372e-01, -6.7789e-02, 6.2381e-01, -2.9670e-01, -4.4328e-01, -8.4224e-02, -3.1270e-01, -1.8197e-01, 3.2360e-01, -7.7793e-02, 1.3314e+00, -1.5676e-01, 1.2857e-01, 4.3474e-02, 7.9883e-02, 1.1311e-02, 1.4428e-01, 1.7653e-01, -2.2321e-01, -4.2480e-02, 2.1707e-03, -4.7640e-02, 3.8532e-01, -5.9911e-02, 1.8338e-01, -1.9145e-01, -1.3184e-01, -2.2440e-01, -3.4313e-01, -1.9527e-01, 2.0129e-01, -2.8915e-01, -2.0750e-01, 1.9230e-01, -4.3318e-01, -3.5914e-02, -1.7492e-01, 5.1793e-03, 4.1998e-01, 1.0637e-01, 1.6559e-01, 2.8926e-01, 2.1868e-01, -7.7643e-02, 6.1037e-01, -1.7432e-02, -2.9676e-03, -3.0160e-01, -1.1983e-02, -9.4832e-02, 9.5424e-02, -3.7713e-01, -1.1239e-01, -7.8399e-01, -1.7278e-01, 4.9498e-02, -2.0969e-01, 3.1968e-01, -3.0732e-01, 1.0192e-01, 2.0580e-01, 3.2505e-01, -2.5291e-01, -9.3692e-02, 5.2662e-03, 4.5696e-01, -1.1763e-01, 2.6193e-01, 3.2966e-02, -4.7883e-03, 4.7738e-01, -3.3887e-02, 3.6247e-01, -1.9945e-01, 4.4342e-01, -3.7178e-01, 3.2319e-01, -1.1709e-01, -1.5551e-01, 1.4257e+00, -4.7203e-01, 2.4915e-01, 1.2907e-01, 1.3357e-01, -1.5880e-01, -3.0594e-01, -9.4597e-02, 1.3255e-01, -8.9818e-02, 5.0826e-01, -2.0685e-01, -6.9602e-01, 4.8778e-01, -1.4408e-01, 5.1481e-02, -1.6557e-02, 3.3421e-01, 6.7242e-02, -1.1685e-01, -4.6423e-02, -3.9958e-01, -3.1008e-01, -2.4609e-01, 5.8174e-02, -5.2140e-01, -6.0439e-02, 3.0534e-03, 4.1036e-01, 4.4092e-01, 2.8334e-01, -5.6422e-01, 5.5707e-02, 1.7791e-04, 3.4433e-01, -3.0717e-01, 2.5623e-01, -2.6241e-01, 4.0216e-01, 3.3964e-01, 5.5718e-01, 6.9994e-02, -1.6490e-01, 3.2947e-01, -6.9621e-02, -3.7227e-01, -1.0987e-01, -3.7106e-01, 4.0310e-01, -4.1511e-01, -9.0917e-02, 1.7001e-01, -3.4748e-01, -1.6285e-01, -2.3767e+00, -3.5290e-01, -8.5539e-02, -5.0965e-01, -1.5912e-01, 2.4123e-01, -2.0030e-01, 2.9155e-01, -3.3438e-01, -2.1440e-01, 1.0519e-01, -6.0930e-02, -3.5564e-01, -3.5314e-01, 1.1538e-01, 1.3500e-01, 3.1325e-01, -1.0790e-01, 2.4903e-01, -4.0942e-01, 1.9815e-01, 1.5635e-01, 4.5990e-01, -1.7499e-04, -8.7480e-02, 3.3567e-02, 1.2889e-01, -1.3793e-02, -1.3751e-01, 3.8376e-01, -4.8534e-01, 1.0498e-01, -3.0883e-01, -3.9634e-01, -7.5734e-02, -3.9470e-01, -3.3696e-01, 2.5969e-02, -2.9933e-02, 2.1998e-01, -3.6887e-01, -6.3065e-02, -4.5264e-01, -7.6559e-02, -9.0896e-02, -2.7469e-01, 2.3256e-01, -6.9002e-02, 8.1259e-02, -2.9682e-01, 5.0958e-01, -2.4812e-01, -4.1866e-01, -4.9677e-01, 5.3641e-02, -1.6098e-01, -1.3070e-01, -2.1058e-01, 7.0593e-01, 4.1502e-01, -2.9617e-01, -1.2387e-01, 1.9504e-02, -2.1288e-01, 6.1103e-02, -4.0131e-01, 5.3975e-01, -3.7639e-01, -1.8536e-01, -3.6357e-01, -4.5547e-01, -9.0210e-02, -2.0425e-01, -2.4413e-01, -8.1124e-02, 2.4698e-02, 7.5438e-02, 6.2125e-03, 1.6757e-01, -2.1207e-01, -6.1182e-02, 4.4722e-01, 4.1641e-01, 8.2606e-01, -6.2413e-03, 5.5281e-01, -1.5134e-01, -7.9939e-02, 8.4223e-02, -3.3734e-01, -2.3321e-02, -2.9588e-01, -9.3586e-01, 3.2397e-01, -2.4314e-01, -2.0533e-02, -5.2084e-01, 5.2986e-02, -1.1679e-01, 4.7422e-01, -1.8861e-01, 2.8550e-01, 5.2586e-01, 3.2893e-01, 3.1098e-01, -1.5665e-01, 5.8859e-01, 1.2991e-01, 4.8790e-01, 7.1808e-02, 1.3260e-01, 1.6146e-01, -4.9939e-01, -1.5210e-01, 7.6596e-02, 3.7449e-01, -1.8812e-01, 1.6209e-01, -3.0729e-01, -3.8459e-01, -8.6934e-02, -3.4415e-01, 2.1309e-01, 9.9894e-02, -9.2105e-01, -2.6550e-01, 5.2581e-03, 8.0952e-01, -2.6002e-01, -1.9374e-01, -4.7203e-01, 4.0053e-01, -1.3437e-01, 2.0369e-01, 1.2778e-01, -5.7577e-02, 2.4322e-03, -4.2885e-02, -8.2562e-02, 2.7829e-01, 4.3434e-02, 
-9.3094e-02, -3.0028e-01, 1.9869e-01, -2.7712e-02, -2.8615e-01, 5.4265e-02, 1.7516e-01, 9.4575e-02, 4.7020e-01, 3.6270e-01, -2.0331e-01, -3.2928e-01, -4.8915e-02, 6.3414e-01, -1.1668e-01, 2.0476e-01, -5.3029e-02, -3.3494e-01, 3.6282e-01])
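As a quick sanity check on the claim in 4.1 that trained vectors capture relatedness, here is a minimal sketch comparing the two vectors just extracted; the cosine similarity of related words such as "hello" and "hi" should come out noticeably higher than that of a less related pair:
import torch.nn.functional as F

hello_vec = vectors["hello"]
hi_vec = vectors["hi"]
world_vec = vectors["world"]

# Cosine similarity along the single 300-dimensional axis.
print(F.cosine_similarity(hello_vec, hi_vec, dim=0))
print(F.cosine_similarity(hello_vec, world_vec, dim=0))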
That concludes the main content of [PyTorch Series - 56]: Recurrent Neural Networks - Defining and Reading the word2vec Word Vector Table (Embedding/GloVe). Related articles in this series:
[PyTorch Series - 58]: Recurrent Neural Networks - Automatically Building Word Vectors and Model Training Code Examples
[PyTorch Series - 55]: Recurrent Neural Networks - Predicting Stock Trends with an LSTM Network
[PyTorch Series - 53]: Recurrent Neural Networks - torch.nn.LSTM() Parameters Explained