BERT-多标签文本分类实战之四——数据集预处理
Posted 征途黯然.
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了BERT-多标签文本分类实战之四——数据集预处理相关的知识,希望对你有一定的参考价值。
·请参考本系列目录:【BERT-多标签文本分类实战】之一——实战项目总览
·下载本实战项目资源:>=点击此处=<
[1] 数据集预处理的流程
在拿到数据集之后,我们关心接下来操作的步骤:
· 查看数据集的基本数据
· 分析数据集的标签构成
· 数据集拆分成训练集、验证集、测试集
· 处理数据集的文本数据(首先了解bert
模型的输入)
为什么要这样做?
对于一个新数据集,我们需要:1)把文本转化成嵌入向量;2)把文本的标签转化成独热数组;3)拆分数据集。除此之外,我们还应该关心数据集平均一个本文有几个标签、最多有几个标签、最少有几个标签,哪些标签出现的比较频繁,这些有助于我们加深对数据集、任务难点的了解。
【注】
1、嵌入向量:在使用预训练词向量的时候,首先会构建一个词典,然后把文本里面的每个单词逐个转化成词典里面对应的序号,最后根据序号再去预训练词向量里面找对应单词的d维向量,于是一条文本就变成了pad_size * d
的嵌入向量。
2、独热数组:是由0-1构成的数组。假设数据集中有5个标签a,b,c,d,e
,文本的实际标签是b,c
,那么它的独热数组就是[0,1,1,0,0]
。
[2] 查看数据集的基本数据
数据集Rruters-21578
的初始数据是放在xlsx
文件中的,如下图。
我们观察到该数据集:
· 共有10788条数据,其中第2-3020条是test
数据,第3021-10789条是training
数据,需要进一步划分;
· 标签是以数组转成字符串形式存储:['grain', 'rice']
;
· 文本以英文形式存储,里面有大小写、特殊字符等,需要我们额外处理一下。
[3] 分析数据集的标签构成
首先读取数据集,利用正则表达式把每个文本的标签提取出来,查看一共有多少标签:
dataset_path = r'@_数据集_原始/Reuters-21578/reutersNLTK.xlsx'
df = pd.read_excel(dataset_path, encoding='utf-8', sep=',')
result_single = // 统计每个标签出现的次数
result_multi =
a = [] // 存放每个文本的标签,可重复,一维数组
b = [] // 存放每个文本的标签数量,一维数组
# 提取标签
for ls in df['categories']:
b.append(len(re.compile(r"'(.*?)'").findall(ls)))
for l in re.compile(r"'(.*?)'").findall(ls):
a.append(l)
# 计算标签次数
for i in set(a):
result_single[i] = a.count(i)
# 倒叙排列后输出
label_list = sorted([_ for _ in result_single.items()], key=lambda x: x[1], reverse=True)
print(label_list)
输出结果: [('earn', 3964), ('acq', 2369), ('money-fx', 717), ('grain', 582), ('crude', 578), ('trade', 485), ('interest', 478), ('ship', 286), ('wheat', 283), ('corn', 237), ('dlr', 175), ('money-supply', 174), ('oilseed', 171), ('sugar', 162), ('coffee', 139), ('gnp', 136), ('gold', 124), ('veg-oil', 124), ('soybean', 111), ('nat-gas', 105), ('bop', 105), ('livestock', 99), ('cpi', 97), ('cocoa', 73), ('reserves', 73), ('carcass', 68), ('jobs', 67), ('copper', 65), ('cotton', 59), ('yen', 59), ('rice', 59), ('alum', 58), ('gas', 54), ('iron-steel', 54), ('ipi', 53), ('barley', 51), ('meal-feed', 49), ('rubber', 49), ('palm-oil', 40), ('sorghum', 34), ('zinc', 34), ('pet-chem', 32), ('tin', 30), ('silver', 29), ('wpi', 29), ('lead', 29), ('rapeseed', 27), ('strategic-metal', 27), ('orange', 27), ('soy-meal', 26), ('soy-oil', 25), ('retail', 25), ('fuel', 23), ('hog', 22), ('housing', 20), ('heat', 19), ('sunseed', 16), ('lumber', 16), ('income', 16), ('lei', 15), ('oat', 14), ('dmk', 14), ('tea', 13), ('platinum', 12), ('groundnut', 9), ('nickel', 9), ('rape-oil', 8), ('l-cattle', 8), ('coconut-oil', 7), ('sun-oil', 7), ('instal-debt', 6), ('potato', 6), ('propane', 6), ('naphtha', 6), ('coconut', 6), ('jet', 5), ('nzdlr', 4), ('cpu', 4), ('palladium', 3), ('nkr', 3), ('dfl', 3), ('copra-cake', 3), ('cotton-oil', 3), ('palmkernel', 3), ('rand', 3), ('lin-oil', 2), ('castor-oil', 2), ('sun-meal', 2), ('groundnut-oil', 2), ('rye', 2)]
可以看到,10788条文本数据中,标签为earn
的文本有3964个。接下来我们查看标签个数、文本标签数目:
label_key = [word_count[0] for idx, word_count in enumerate(label_list)]
label_val = [word_count[1] for idx, word_count in enumerate(label_list)]
print("标签个数:", len(label_key))
print("文本标签数目:", set(b))
print("每个文本平均有标签:", sum(label_val) / len(df['categories']))
输出结果:
标签个数: 90
文本标签数目: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15
每个文本平均有标签: 1.235446792732666
可以看到,每个文本平均有1.2个标签,相对其他数据集来说,是极低的。数据集中文本标签个数为1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15的都有,我们详细查看它们的分布:
k =
# 计算共现标签
for i in set(b):
k[i] = b.count(i)
print(k)
输出结果:1: 9160, 2: 1173, 3: 255, 4: 91, 5: 52, 6: 27, 7: 9, 8: 7, 9: 5, 10: 3, 11: 2, 12: 1, 14: 2, 15: 1
有9160条文本只有1个标签,有1173条文本是2个标签,再往后就很稀疏了。
【注】如何提高模型对这些稀疏数据的预测性能,这也是多标签文本分类中的研究难题。
[4] 数据集拆分成训练集、验证集、测试集
我们把数据集中3019条test,分为1500条验证文本和1519条测试文本。在拆分过程中记得要打乱顺序:
import random
import pandas as pd
dataset_path =r'@_数据集_原始/Reuters-21578/reutersNLTK.xlsx'
df = pd.read_excel(dataset_path, encoding='utf-8', sep=',')
""" 多标签数据集划分 Reuter-21578
label_name : 标签列名
content_name : 内容列名
export_name : 导出的文件名
proportion : 训练-验证-测试比例, 数组, 加和为10, eg: [7,1.5,1.5]
"""
def divideDataSet(label_name, content_name, export_name, proportion):
contents = df[content_name]
labels = df[label_name]
ids = df['ids']
train, val, test = [], [], []
for i, id in enumerate(ids):
if id[:3] == 'tra':
train.append('label': labels[i], 'content': contents[i])
else:
val.append('label': labels[i], 'content': contents[i])
random.shuffle(train)
random.shuffle(val)
test = val[:1500]
val = val[1500:]
print(len(train), len(val), len(test))
random.shuffle(train)
random.shuffle(val)
random.shuffle(test)
train.insert(0, 'label': 'label', 'content': 'content')
val.insert(0, 'label': 'label', 'content': 'content')
test.insert(0, 'label': 'label', 'content': 'content')
with open(export_name + r'train.csv', 'a', newline='', encoding='utf-8') as f:
xieru = csv.DictWriter(f, ['label', 'content'], delimiter=',')
xieru.writerows(train) # writerows方法是一下子写入多行内容
with open(export_name + r'dev.csv', 'a', newline='', encoding='utf-8') as f:
xieru = csv.DictWriter(f, ['label', 'content'], delimiter=',')
xieru.writerows(val) # writerows方法是一下子写入多行内容
with open(export_name + r'test.csv', 'a', newline='', encoding='utf-8') as f:
xieru = csv.DictWriter(f, ['label', 'content'], delimiter=',')
xieru.writerows(test) # writerows方法是一下子写入多行内容
divideDataSet('categories', 'text', root_path + r'@_数据集_已处理/Reuters-21578/data/', [])
拆分出来3个csv文件。
[5] 处理数据集的文本数据
在处理数据集的文本数据前,有必要了解一下使用预处理词向量的模型,是如何处理文本数据的。请参考博客:【英文文本分类实战】之三——数据清洗。
由于bert
模型的强大,我们甚至可以不作任何处理,就能直接把文本作为输入放到bert
模型中,因为bert
模型自带词典(很强大)。我倾向于不作处理,所以这里不再叙述。但其实如果处理的话,也不过是转化一下大小写、过滤特殊字符、过滤缩写,这些在【英文文本分类实战】之三——数据清洗都有详细介绍。
以上是关于BERT-多标签文本分类实战之四——数据集预处理的主要内容,如果未能解决你的问题,请参考以下文章