Text Classification Based on the BERT Model: A TensorFlow 2 Implementation (Source Code Included) | NLP in 100 Examples

Posted by K同学啊


Hi everyone, I'm K同学啊!

In the previous article I explained what BERT is. In this one, I apply BERT to a hands-on text classification task on the THUCTC dataset, classifying news texts into 10 categories: finance (财经), real estate (房产), stocks (股票), education (教育), technology (科技), society (社会), politics (时政), sports (体育), games (游戏), and entertainment (娱乐). The final classification accuracy reaches 83.3%. The project overview is as follows:

NLP in 100 Examples: Text Classification Based on the BERT Model

I. Import and Prepare the Data

import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, TFBertForSequenceClassification

data_path  = "./5-data/data.txt"
model_path = "bert-base-chinese"

max_length       = 32
batch_size       = 128
learning_rate    = 2e-5
num_classes      = 10 # number of classes

# Load the data
df_raw = pd.read_csv(data_path, sep="\t", header=None, names=["text","label"])

class_names = ["财经","房产","股票","教育","科技","社会","时政","体育","游戏","娱乐"]
# Map the text labels to integer ids
df_label = pd.DataFrame({"label": class_names, "y": list(range(10))})

df_raw = pd.merge(df_raw, df_label, on="label", how="left")
df_raw.head(3)
                                text  label  y
0         中华女子学院:本科层次仅1专业招男生   教育   3
1   两天价网站背后重重迷雾:做个网站究竟要多少钱   科技   4
2    东5环海棠公社230-290平2居准现房98折优惠   房产   1

Check the proportion of each class in the data

# available in the full source code
plt.show()
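The plotting code is kept in the downloadable source. As a rough sketch (not the original code), the per-class counts could be inspected with pandas and matplotlib, using the df_raw defined above:

# Count how many samples each label has and draw a bar chart
label_counts = df_raw["label"].value_counts()
print(label_counts)

label_counts.plot(kind="bar", rot=45, title="Samples per class")
plt.xlabel("label")
plt.ylabel("count")
plt.show()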

Check the distribution of text lengths

# available in the full source code
plt.show()

# available in the full source code
plt.show()
The 0.9 quantile of sentence length is 24.
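The length statistics can be reproduced along these lines (again a sketch rather than the original code); the length column computed here also shows up in the train_data.head(3) output below:

# Character-level length of each text, its distribution, and the 0.9 quantile
df_raw["length"] = df_raw["text"].str.len()

df_raw["length"].hist(bins=30)
plt.xlabel("text length (characters)")
plt.ylabel("count")
plt.show()

print("0.9 quantile of sentence length:", df_raw["length"].quantile(0.9))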

II. Splitting the Dataset

train_data, x = train_test_split(df_raw, 
                                 stratify=df_raw['label'], # stratify by the class proportions in df_raw['label']
                                 test_size=0.1, 
                                 random_state=42)

val_data, test_data = train_test_split(x, 
                                       stratify=x['label'], 
                                       test_size=0.5, 
                                       random_state=43)
train_data.head(3)
                               text  label  y  length
603     美国普利策奖全部奖项将允许网络媒体参选   时政   6      19
2373   09考研强化复习策略:提纲挈领 变中稳进   教育   3      20
1759      嫌犯仰仗家中藏獒拒捕 并以自残相威胁   社会   5      18
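As a quick sanity check (not part of the original post), you can confirm the split sizes and that stratification preserved the class ratios:

# 90% train, 5% validation, 5% test after the two splits above
print(len(train_data), len(val_data), len(test_data))

# The per-class proportions should be nearly identical across the splits
print(train_data["label"].value_counts(normalize=True).round(3))
print(val_data["label"].value_counts(normalize=True).round(3))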
# tokenizer
tokenizer = BertTokenizer.from_pretrained(model_path)

# Reshape each example into the dict format the BERT model expects
def map_example_to_dict(input_ids, attention_masks, token_type_ids, label):
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_masks,
    }, label

def encode_examples(ds):
    input_ids_list      = []
    token_type_ids_list = []
    attention_mask_list = []
    label_list          = []
    for index, row in ds.iterrows():
        bert_input = tokenizer.encode_plus(row["text"], 
                                    add_special_tokens = True, # add [CLS], [SEP]
                                    max_length = max_length, # max length of the text that can go to BERT
                                    pad_to_max_length = True, # add [PAD] tokens
                                    return_attention_mask = True, # add attention mask to not focus on pad tokens
                                    truncation=True)

        input_ids_list.append(bert_input['input_ids'])
        token_type_ids_list.append(bert_input['token_type_ids'])
        attention_mask_list.append(bert_input['attention_mask'])
        label_list.append(row["y"])
    dataset = tf.data.Dataset.from_tensor_slices(
        (input_ids_list, attention_mask_list, token_type_ids_list, label_list))
    return dataset.map(map_example_to_dict)
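To make the encoding step concrete, here is what tokenizer.encode_plus returns for a single headline. This is an illustrative example rather than part of the original script, and the exact token ids depend on the bert-base-chinese vocabulary:

sample = tokenizer.encode_plus("中华女子学院:本科层次仅1专业招男生",
                               add_special_tokens=True,
                               max_length=max_length,
                               pad_to_max_length=True,
                               return_attention_mask=True,
                               truncation=True)

print(sample.keys())            # input_ids, token_type_ids, attention_mask
print(len(sample["input_ids"])) # 32, i.e. padded/truncated to max_length
print(sample["input_ids"][:3])  # starts with the [CLS] id (101 in BERT vocabularies)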

III. Build the Model

# Build the tf.data pipelines
ds_train_encoded = encode_examples(train_data).shuffle(10000).batch(batch_size)
ds_val_encoded   = encode_examples(val_data).batch(batch_size)
ds_test_encoded  = encode_examples(test_data).batch(batch_size)

# Initialize the model
model = TFBertForSequenceClassification.from_pretrained(model_path, num_labels=num_classes)
# Set up the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,epsilon=1e-08, clipnorm=1)

# If the choice of loss is unclear, see: https://mtyjkh.blog.csdn.net/article/details/122309754
loss   = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

model.compile(optimizer=optimizer, 
              loss=loss, 
              metrics=[metric])
All model checkpoint layers were used when initializing TFBertForSequenceClassification.

Some layers of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
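TFBertForSequenceClassification outputs raw logits rather than probabilities, which is why the loss is created with from_logits=True. A small, self-contained illustration with toy numbers (not model output):

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])  # raw scores for 3 classes
labels = tf.constant([0])                 # true class index

loss_from_logits = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_from_probs  = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

# Both calls give the same value once the logits are passed through softmax
print(float(loss_from_logits(labels, logits)))
print(float(loss_from_probs(labels, tf.nn.softmax(logits))))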

IV. Train the Model

# fit model
bert_history = model.fit(ds_train_encoded, epochs=10, validation_data=ds_val_encoded)
Epoch 1/10
22/22 [==============================] - 20s 277ms/step - loss: 1.9059 - accuracy: 0.4585 - val_loss: 1.2277 - val_accuracy: 0.7933
Epoch 2/10
22/22 [==============================] - 4s 204ms/step - loss: 0.9633 - accuracy: 0.8230 - val_loss: 0.6662 - val_accuracy: 0.8467
Epoch 3/10
22/22 [==============================] - 5s 204ms/step - loss: 0.5279 - accuracy: 0.8900 - val_loss: 0.5360 - val_accuracy: 0.8600
Epoch 4/10
22/22 [==============================] - 5s 205ms/step - loss: 0.3482 - accuracy: 0.9200 - val_loss: 0.4698 - val_accuracy: 0.8667
Epoch 5/10
22/22 [==============================] - 5s 204ms/step - loss: 0.2514 - accuracy: 0.9448 - val_loss: 0.4263 - val_accuracy: 0.8867
Epoch 6/10
22/22 [==============================] - 5s 205ms/step - loss: 0.1654 - accuracy: 0.9689 - val_loss: 0.4706 - val_accuracy: 0.8800
Epoch 7/10
22/22 [==============================] - 5s 205ms/step - loss: 0.1139 - accuracy: 0.9841 - val_loss: 0.4517 - val_accuracy: 0.8867
Epoch 8/10
22/22 [==============================] - 4s 204ms/step - loss: 0.0841 - accuracy: 0.9863 - val_loss: 0.4967 - val_accuracy: 0.8933
Epoch 9/10
22/22 [==============================] - 5s 205ms/step - loss: 0.0684 - accuracy: 0.9878 - val_loss: 0.4540 - val_accuracy: 0.8933
Epoch 10/10
22/22 [==============================] - 5s 204ms/step - loss: 0.0493 - accuracy: 0.9948 - val_loss: 0.5542 - val_accuracy: 0.8867

V. Evaluate the Model

# evaluate test_set
test_loss, test_accuracy = model.evaluate(ds_test_encoded)
print("test_set loss:", test_loss)
print("test_set accuracy:", test_accuracy)
2/2 [==============================] - 0s 28ms/step - loss: 0.6915 - accuracy: 0.8333
test_set loss: 0.691510796546936
test_set accuracy: 0.8333333134651184

1. Loss and Accuracy Curves

# available in the full source code
plt.show()
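The plotting code is in the full source; a minimal sketch of how the curves could be drawn from the bert_history object returned by model.fit:

history = bert_history.history

plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(history["loss"], label="train loss")
plt.plot(history["val_loss"], label="val loss")
plt.legend()
plt.title("Loss")

plt.subplot(1, 2, 2)
plt.plot(history["accuracy"], label="train accuracy")
plt.plot(history["val_accuracy"], label="val accuracy")
plt.legend()
plt.title("Accuracy")

plt.show()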

2. Other Evaluation Metrics

# available in the full source code
test_accuracy_report(model)
              precision    recall  f1-score   support

          体育       1.00      0.93      0.97        15
          娱乐       0.89      1.00      0.94        16
          房产       0.71      0.86      0.77        14
          教育       1.00      0.93      0.96        14
          时政       0.87      0.76      0.81        17
          游戏       1.00      0.73      0.85        15
          社会       0.75      0.94      0.83        16
          科技       0.92      0.80      0.86        15
          股票       0.62      0.57      0.59        14
          财经       0.69      0.79      0.73        14

    accuracy                           0.83       150
   macro avg       0.84      0.83      0.83       150
weighted avg       0.85      0.83      0.83       150

Loss function: 0.691510796546936, accuracy: 0.8333333134651184
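test_accuracy_report is defined in the full source. Here is a hedged sketch of how a similar report could be produced with scikit-learn; the names test_label and test_pre mirror the ones used for the confusion matrix below:

import numpy as np
from sklearn.metrics import classification_report

# Collect predictions and true labels over the whole test set
test_pre, test_label = [], []
for batch, labels in ds_test_encoded:
    outputs = model(batch)
    # .logits assumes a transformers version that returns ModelOutput objects;
    # older versions return a tuple, in which case use outputs[0]
    test_pre.extend(np.argmax(outputs.logits, axis=-1))
    test_label.extend(labels.numpy())

print(classification_report(test_label, test_pre, target_names=class_names))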

3. Confusion Matrix

# available in the full source code
plot_cm(test_label, test_pre)
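plot_cm is also kept in the source download. One possible implementation (a sketch, assuming the test_label and test_pre lists from the previous step) using scikit-learn and seaborn:

import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_cm(y_true, y_pred):
    # Rows are true classes, columns are predicted classes
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()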


Source code download: https://mp.weixin.qq.com/s/6K0ZInHfq-2acvbUwl5u3w
