在数据增强蒸馏剪枝下ERNIE3.0分类模型性能提升

Posted 汀、

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了在数据增强蒸馏剪枝下ERNIE3.0分类模型性能提升相关的知识,希望对你有一定的参考价值。

在数据增强、蒸馏剪枝下ERNIE3.0模型性能提升

项目链接:
https://aistudio.baidu.com/aistudio/projectdetail/4436131?contributionType=1

以CBLUE数据集中医疗搜索检索词意图分类为例:

本项目首先讲解了数据增强和数据蒸馏的方案,并在后面章节进行效果展示,结果预览:

模型ACCPrecisionRecallF1average_of_acc_and_f1
ERNIE 3.0 Base0.802550.93171470.9082840.9198500.86120
ERNIE 3.0 Base+数据增强0.79795390.9010040.928990.914780.8563
ERNIE 3.0 Base+剪裁保留比0.50.798460.9512570.894970.922250.8603
ERNIE 3.0 Base +剪裁保留比2/30.80920710.94153840.9053250.9230760.86614

gensim安装最新版本:pip install gensim

tqdm安装:pip install tqdm

LAC安装最新版本:pip install lac


Gensim库介绍

Gensim是在做自然语言处理时较为经常用到的一个工具库,主要用来以无监督的方式从原始的非结构化文本当中来学习到文本隐藏层的主题向量表达。

主要包括TF-IDF,LSA,LDA,word2vec,doc2vec等多种模型。

Tqdm

是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。目的为了程序显示的美观

中文词法分析-LAC

LAC是一个联合的词法分析模型,整体性地完成中文分词、词性标注、专名识别任务。LAC既可以认为是Lexical Analysis of Chinese的首字母缩写,也可以认为是LAC Analyzes Chinese的递归缩写。

LAC基于一个堆叠的双向GRU结构,在长文本上准确复刻了百度AI开放平台上的词法分析算法。效果方面,分词、词性、专名识别的整体准确率95.5%;单独评估专名识别任务,F值87.1%(准确90.3,召回85.4%),总体略优于开放平台版本。在效果优化的基础上,LAC的模型简洁高效,内存开销不到100M,而速度则比百度AI开放平台提高了57%

LAC链接:https://www.paddlepaddle.org.cn/modelbasedetail/lac

!pip install --upgrade paddlenlp
!pip install gensim
!pip install tqdm
!pip install lac

2.数据增强方案介绍

数据增强工具提供4种增强策略:遮盖、删除、同词性词替换、词向量近义词替换

!unzip ERNIE-.zip -d ./ERNIE
#添加ERNIE工具包

如果程序报错:
可以发现提示有一个.ipynb_checkpoints的文件。但当我去对应的文件夹找时根本看不到这个文件,所以猜测是一个隐藏文件。所以通过终端进入对应的目录:输入cd coco进入对应目录,输入ls -a显示所有文件。然后输入rm -rf .ipynb_checkpoints删除该文件。再次输入ls -a查看文件是否被删除。

下载词表,词表有1.7G会花点时间。下面以情感分析数据样例展示demo,看看数据增强的效果。

!wget -q --no-check-certificate http://bj.bcebos.com/wenxin-models/vec2.txt

python data_aug.py “输入文件夹的目录” “输出文件夹的目录”

  • data_aug.py脚本传参说明
shell输入:
    python data_aug.py -h

shell输出:
    usage: data_aug.py [-h] [-n AUG_TIMES] [-c COLUMN_NUMBER] [-u UNK]
                       [-t TRUNCATE] [-r POS_REPLACE] [-w W2V_REPLACE]
                       [-e ERNIE_REPLACE] [--unk_token UNK_TOKEN]
                       input output
    
    main
    
    positional arguments:
      input                                                #原始待增强数据文件所在文件夹,带label的,一个或多个文本列
      output                                               #输出文件路径
    
    optional arguments:
      -h, --help            show this help message and exit
      -n AUG_TIMES, --aug_times AUG_TIMES                  #数据集数目放大n倍,output行数为input的n+1倍      
      -c COLUMN_NUMBER, --column_number COLUMN_NUMBER      #明文文件中所要增强列的列序号,多列用逗号分割,如:1,2
      -u UNK, --unk UNK                                    #unk 增强策略的概率
      -t TRUNCATE, --truncate TRUNCATE                     #truncate 增强策略的概率
      -r POS_REPLACE, --pos_replace POS_REPLACE            #pos_replace 增强策略的概率
      -w W2V_REPLACE, --w2v_replace W2V_REPLACE            #w2v_replace 增强策略的概率
      --unk_token UNK_TOKEN                    

分类问题中:推荐使用前三种即可,w2v词向量近义词替换可以不用,花费时间太长。

!python data_aug.py --unk 0.25 --truncate 0.25 --pos 0.5 --w2v 0 ./train ./output
demo结果展示:

机器 背面 似乎 被 撕 了 张 什么 标签 , 残 胶 还在 。 但是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪	0
机器 背面 似乎 被 撕 了 张 什么 标签 , 胶 还在 。 但是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪	0
机器 背面 了 张 什么 标签 , 残 胶 还在 。 但是 又 看 不 出 是 什么 标签  了 , 该在 , 怪	0
呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。	0
呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我😄妈 爱 看 , 我自己 也 学 着 找 一些 穴位 😄	0
呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还😄 能 看得出来 是 盗😄😄😄。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 ,😄😄😄😄😄😄😄学 着 找 😄😄😄😄😄😄😄	0
😄😄😄😄😄虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。	0
😄😄😄😄😄😄😄 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。	0
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近 。	1
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近。。	1
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 机器 还算 干净 , 离 湖南路小吃街 近 。	1
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近 。	1
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽
我 看 是 书 的 还 可以 , 但是 我 订 的 书 迟迟 还 到 能 半个月 , 都 没有 收到 打电话 也 没

3.数据蒸馏技术

ERNIE数据蒸馏三步

Step 1. 使用ERNIE模型对输入标注数据对进行fine-tune,得到Teacher Model

Step 2. 使用ERNIE Service对以下无监督数据进行预测:

  • 用户提供的大规模无标注数据,需与标注数据同源
  • 对标注数据进行数据增强,具体增强策略
  • 对无标注数据和数据增强数据进行一定比例混合

Step 3. 使用步骤2的数据训练出Student Model

数据增强

目前采用三种数据增强策略策略,对于不用的任务可以特定的比例混合。三种数据增强策略包括:

添加噪声:对原始样本中的词,以一定的概率(如0.1)替换为”UNK”标签

同词性词替换:对原始样本中的所有词,以一定的概率(如0.1)替换为本数据集钟随机一个同词性的词

N-sampling:从原始样本中,随机选取位置截取长度为m的片段作为新的样本,其中片段的长度m为0到原始样本长度之间的随机值
模型剪裁,基于 PaddleNLP 的 Trainer API 发布提供了模型裁剪 API。裁剪 API 支持用户对 ERNIE 等Transformers 类下游任务微调模型进行裁剪。

具体效果在下一节展现,先安装好paddleslim库

4.基于ERNIR3.0文本模型微调

加载已有数据集:CBLUE数据集中医疗搜索检索词意图分类(训练)

数据集定义:
以公开数据集CBLUE数据集中医疗搜索检索词意图分类(KUAKE-QIC)任务为示例,在训练集上进行模型微调,并在开发集上使用准确率Accuracy评估模型表现。

数据集默认为:默认为"cblue"。

save_dir:保存训练模型的目录;默认保存在当前目录checkpoint文件夹下。

dataset:训练数据集;默认为"cblue"。

dataset_dir:本地数据集路径,数据集路径中应包含train.txt,dev.txt和label.txt文件;默认为None。

task_name:训练数据集;默认为"KUAKE-QIC"。

max_seq_length:ERNIE模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。

model_name:选择预训练模型;默认为"ernie-3.0-base-zh"。

device: 选用什么设备进行训练,可选cpu、gpu、xpu、npu。如使用gpu训练,可使用参数gpus指定GPU卡号。

batch_size:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。

learning_rate:Fine-tune的最大学习率;默认为6e-5。

weight_decay:控制正则项力度的参数,用于防止过拟合,默认为0.01。

early_stop:选择是否使用早停法(EarlyStopping);默认为False。

early_stop_nums:在设定的早停训练轮次内,模型在开发集上表现不再上升,训练终止;默认为4。
epochs: 训练轮次,默认为100。

warmup:是否使用学习率warmup策略;默认为False。

warmup_proportion:学习率warmup策略的比例数,如果设为0.1,则学习率会在前10%steps数从0慢慢增长到learning_rate, 而后再缓慢衰减;默认为0.1。

logging_steps: 日志打印的间隔steps数,默认5。

init_from_ckpt: 模型初始checkpoint参数地址,默认None。

seed:随机种子,默认为3。

#修改后的训练文件train_new2.py ,主要使用了paddlenlp.metrics.glue的AccuracyAndF1:准确率及F1-score,可用于GLUE中的MRPC 和QQP任务
#不过吐槽一下:    return (acc,precision,recall,f1,(acc + f1) / 2,) 最后一个指标竟然是加权平均.....
!python train_new2.py --warmup --early_stop --epochs 10 --save_dir "./checkpoint2" --batch_size 16 --model_name ernie-3.0-base-zh

训练结果部分展示:

[2022-08-16 19:58:36,834] [    INFO] - global step 1280, epoch: 3, batch: 412, loss: 0.23292, acc: 0.87106, speed: 16.54 step/s
[2022-08-16 19:58:37,392] [    INFO] - global step 1290, epoch: 3, batch: 422, loss: 0.22339, acc: 0.87130, speed: 17.94 step/s
[2022-08-16 19:58:37,960] [    INFO] - global step 1300, epoch: 3, batch: 432, loss: 0.22791, acc: 0.87182, speed: 17.68 step/s
(acc, precision, recall, f1, average_of_acc_and_f1):(0.8025575447570332, 0.9317147192716236, 0.908284023668639, 0.9198501872659175, 0.8612038660114754)

[2022-08-16 20:01:36,060] [ INFO] - Early stop!
[2022-08-16 20:01:36,060] [ INFO] - Save best accuracy text classification model in ./checkpoint2

4.1 加载自定义数据集(并通过数据增强训练)

从本地文件创建数据集

使用本地数据集来训练我们的文本分类模型,本项目支持使用固定格式本地数据集文件进行训练
如果需要对本地数据集进行数据标注,可以参考文本分类任务doccano数据标注使用指南进行文本分类数据标注。[这个放到下个项目讲解]

本项目将以CBLUE数据集中医疗搜索检索词意图分类(KUAKE-QIC)任务为例进行介绍如何加载本地固定格式数据集进行训练:

本地数据集目录结构如下:

data/
├── train.txt # 训练数据集文件
├── dev.txt # 开发数据集文件
├── label.txt # 分类标签文件
└── data.txt # 可选,待预测数据文件

部分结果展示

[2022-08-16 23:43:18,093] [    INFO] - global step 2400, epoch: 2, batch: 234, loss: 0.60859, acc: 0.84437, speed: 19.27 step/s
(acc, precision, recall, f1, average_of_acc_and_f1):(0.7979539641943734, 0.9010043041606887, 0.9289940828402367, 0.9147851420247632, 0.8563695531095683)
[2022-08-16 23:43:24,522] [    INFO] - Save best F1 text classification model in ./checkpoint3
[2022-08-16 23:43:24,523] [    INFO] - best F1 performence has been updated: 0.91450 --> 0.91479

4.2 数据蒸馏

!unset CUDA_VISIBLE_DEVICES
!python -m paddle.distributed.launch --gpus "0" prune.py \\
    --device "gpu" \\
    --output_dir "./prune" \\
    --per_device_train_batch_size 32 \\
    --per_device_eval_batch_size 32 \\
    --learning_rate 3e-5 \\
    --num_train_epochs 5 \\
    --logging_steps 10 \\
    --save_steps 50 \\
    --seed 3 \\
    --dataset_dir "KUAKE_QIC" \\
    --max_seq_length 128 \\
    --params_dir "./checkpoint3" \\
    --width_mult '0.5'

部分结果展示:

[2022-08-17 14:22:30,954] [    INFO] - width_mult: 0.5, eval loss: 0.63535, acc: 0.79847
(acc, precision, recall, f1, average_of_acc_and_f1):(0.7984654731457801, 0.9512578616352201, 0.8949704142011834, 0.9222560975609755, 0.8603607853533778)
[2022-08-17 14:22:35,870] [    INFO] - Save best F1 text classification model in ./prune/0.5
[2022-08-17 14:22:35,870] [    INFO] - best F1 performence has been updated: 0.92226 --> 0.92226
!unset CUDA_VISIBLE_DEVICES
!python -m paddle.distributed.launch --gpus "0" prune.py \\
    --device "gpu" \\
    --output_dir "./prune" \\
    --per_device_train_batch_size 32 \\
    --per_device_eval_batch_size 32 \\
    --learning_rate 3e-5 \\
    --num_train_epochs 5 \\
    --logging_steps 10 \\
    --save_steps 50 \\
    --seed 3 \\
    --dataset_dir "KUAKE_QIC" \\
    --max_seq_length 128 \\
    --params_dir "./checkpoint3" \\
    --width_mult '2/3'
2022-08-17 14:53:45,544] [    INFO] - global step 3070, epoch: 2, batch: 904, loss: 0.709566, speed: 9.93 step/s
[2022-08-17 14:53:46,550] [    INFO] - global step 3080, epoch: 2, batch: 914, loss: 0.607238, speed: 9.94 step/s
[2022-08-17 14:53:47,558] [    INFO] - global step 3090, epoch: 2, batch: 924, loss: 0.718484, speed: 9.93 step/s
[2022-08-17 14:53:48,563] [    INFO] - global step 3100, epoch: 2, batch: 934, loss: 0.546288, speed: 9.95 step/s
[2022-08-17 14:53:50,206] [    INFO] - teacher model, eval loss: 0.66438, acc: 0.80358
[2022-08-17 14:53:50,207] [    INFO] - eval done total : 1.6434180736541748 s
[2022-08-17 14:53:53,568] [    INFO] - width_mult: 0.6666666666666666, eval loss: 0.60219, acc: 0.80921
(acc, precision, recall, f1, average_of_acc_and_f1):(0.8092071611253197, 0.9415384615384615, 0.9053254437869822, 0.923076923076923, 0.8661420421011213)
[2022-08-17 14:53:58,489] [    INFO] - Save best F1 text classification model in ./prune/0.6666666666666666
[2022-08-17 14:53:58,489] [    INFO] - best F1 performence has been updated: 0.92308 --> 0.92308

4.3 模型预测

输入待预测数据和数据标签对照列表,模型预测数据对应的标签

使用默认数据进行预测:

#也可以选择使用本地数据文件data/data.txt进行预测:
!python predict.py --params_path ./checkpoint3/ --dataset_dir ./KUAKE_QIC --device "cpu"
黑苦荞茶的功效与作用及食用方法 功效作用
交界痣会凸起吗 疾病表述
检查是否能怀孕挂什么科 就医建议
鱼油怎么吃咬破吃还是直接咽下去 其他
幼儿挑食的生理原因是 病因分析
!python predict.py \\
    --device "cpu" \\
    --dataset_dir ./KUAKE_QIC \\
    --params_path "./prune/0.5" \\

5.总结

本项目首先讲解了数据增强和数据蒸馏的方案,并在后面章节进行效果展示,现在进行汇总

模型ACCPrecisionRecallF1average_of_acc_and_f1
ERNIE 3.0 Base0.802550.93171470.9082840.9198500.86120
ERNIE 3.0 Base+数据增强0.79795390.9010040.928990.914780.8563
ERNIE 3.0 Base+剪裁保留比0.50.798460.9512570.894970.922250.8603
ERNIE 3.0 Base +剪裁保留比2/30.80920710.94153840.9053250.9230760.86614

分析可得,

  • 首先数据增强后导致性能部分下降部分和预期的原因:
    随机mask、删除会产生过多噪声样本影响结果,推荐只使用同义词替换,本次样本数据量足够,且ERNIE性能本就优越,数据增强对结果提升在较大样本集可以忽略。

  • 其次,可以看到通过数据蒸馏后,模型性能变化不大,甚至在剪裁1/3之后,性能有小幅度提升

本次主要对分类模型加入数据增强、数据蒸馏,已经对性能指标进行细化,不只是ACC,个人比较关注F1情况,并作为保存模型依据。

展望: 后续将完善动态图和静态图转化部分,让蒸馏下来模型可以继续线上加载使用;其次将会考虑小样本学习在分类模型应用情况;最后将完成模型融合环节提升性能,并做可解释性分析。

本人博客:https://blog.csdn.net/sinat_39620217?type=blog

以上是关于在数据增强蒸馏剪枝下ERNIE3.0分类模型性能提升的主要内容,如果未能解决你的问题,请参考以下文章

《模型轻量化-剪枝蒸馏量化系列》YOLOv5无损剪枝(附源码)

《模型轻量化-剪枝蒸馏量化系列》YOLOv5无损剪枝(附源码)

知识蒸馏轻量化模型架构剪枝…几种深度学习模型压缩方法

深度学习之模型优化模型剪枝模型量化知识蒸馏概述

PaddleHub实战篇{ERNIE实现文新闻本分类ERNIE3.0 实现序列标注}

知识蒸馏DEiT算法实战:使用RegNet蒸馏DEiT模型