浅显理解torchtext对文本预处理的过程
Posted Icy Hunter
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了浅显理解torchtext对文本预处理的过程相关的知识,希望对你有一定的参考价值。
这里处理数据是面对英文文本及分类问题的,中文应该也差不多。
数据集:https://download.csdn.net/download/qq_52785473/78428834?spm=1001.2014.3001.5503
0积分下载。
下面案例主要是对一个数据集进行处理,作为跑流程的实例,完整的流程应该是三个数据集,其中训练集需要构建字典,其他不用构建字典,最后训练的时候直接调用迭代就可以进行训练了。
下面代码其实是学长给的代码,我涂涂改改得来的,毕竟水平不太够,只能用自己的理解来解释解释以便自己能够看的懂一点。
import pandas as pd
import os
import torchtext
from tqdm import tqdm
class mydata(object):
def __init__(self):
self.n_class = 2
self.data_dir = '../datasets'
def _generator(self, filename): # 读取文件,取一个用一个比较节省内存
path = os.path.join(self.data_dir, filename) # 拼接路径
df = pd.read_csv(path, sep='\\t', header=None) # 读取文件
for index, line in df.iterrows(): # 取出句子和对应标签
sentence = line[0]
label = line[1]
yield sentence, label
def load_train_data(self): # 文件路径
return self._generator('')
def load_dev_data(self):
return self._generator('dev.tsv')
def load_test_data(self):
return self._generator('')
class Dataset(object):
def __init__(self, dataset: mydata, batch_size, fix_length):
self.dataset = dataset # 使得能够调用mydata这个类
self.batch_size = batch_size
self.fix_length = fix_length
def load_data(self):
tokenizer = lambda sentence: [x for x in sentence.split() if x != ' '] # 定义切词工具,以空格切词且过滤空格
# fix_length为截断,多退少补0
# sequential为判断是否为序列,文本为True
# 全都转换成小写
TEXT = torchtext.data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=self.fix_length)
# 标签不是序列为false,且不需要创建词典
LABEL = torchtext.data.Field(sequential=False, use_vocab=False)
# tex, label(自己定义)能取出example对应的数据
# Field相当于定义了一种数据类型吧
datafield = [("text", TEXT), ("label", LABEL)]
dev_gen = self.dataset.load_dev_data() # 加载数据
dev_example = [torchtext.data.Example.fromlist(it, datafield) for it in tqdm(dev_gen)] # 封装成example对象,数据+标签,tqdm为进度条
# dev_example = [torchtext.data.Example.fromlist(it, datafield) for it in dev_gen] # 只执行这个就没有进度条了,tqdm用法就明确了
print(dev_example[0].text) # 取出第一句话的text
dev_data = torchtext.data.Dataset(dev_example, datafield) # 转换成dataset类
# 验证集其实不用字典但是因为我上传的数据为验证集的数据,因此就拿验证集作为例子了
TEXT.build_vocab(dev_data) # 创建字典且将词向量化
self.vocab = TEXT.vocab
# BucketIterator创建一个迭代器,使得数据可迭代,方便训练的时候调用
self.dev_iterator = torchtext.data.BucketIterator(
(dev_gen),
batch_size=self.batch_size,
sort_key=lambda x: len(x.text),
shuffle=False
)
print(f"load len(dev_data) dev examples")
if __name__ == '__main__':
data_class = mydata()
datasets = Dataset(data_class, 10, 10)
datasets.load_data()
以上是关于浅显理解torchtext对文本预处理的过程的主要内容,如果未能解决你的问题,请参考以下文章