BERT文本分类实践

Posted 一只小鱼儿

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了BERT文本分类实践相关的知识，希望对你有一定的参考价值。

最近有一个需求：需要给UGC（用户生成内容）作品打上标签。由于人力成本有限，只能自己边学边做，也借此机会学习一下文本分类模型。我用下面的代码跑了一版离线实验，在此留个记录，后面有空再进一步研究。

参考自：GitHub - Timaos123/BERTClassifier。注意：下文代码省略了数据清洗部分，训练脚本假定该步骤产出一个名为 `dataset`、包含 `content_extract`（文本）和 `tab`（类别）两列的 DataFrame。

import numpy as np
import pandas as pd
import os
import tqdm
import bert
from tensorflow import keras
from tensorflow.keras.models import Sequential
import tensorflow as tf
from sklearn.model_selection import train_test_split
import json
from sklearn.preprocessing import OneHotEncoder
import pickle as pkl
from sklearn.metrics import f1_score

# Data cleaning (omitted in the original post).
# NOTE(review): the training script further down reads a DataFrame named
# `dataset` and uses its "content_extract" (raw text) and "tab" (label)
# columns, so this step presumably has to define it — confirm.
...

# Model definition
class MyBERTClassier:
    """Text classifier: a pretrained BERT encoder followed by a softmax head.

    Each input sample is a list of tokens (the script below uses single
    characters wrapped in "[CLS]" ... "[SEP]"). Tokens missing from
    ``vocab.txt`` are dropped; the rest are converted to ids and right-padded
    with 0 or truncated to ``XMaxLen``.

    Args:
        classNum: number of output classes.
        preModelPath: directory of the pretrained BERT release; ``vocab.txt``
            and the ``bert_model.ckpt`` checkpoint are read from it.
        learning_rate: RMSprop learning rate. NOTE(review): the default 0.1
            looks far too high for fine-tuning BERT; the script below
            passes 1e-4.
        XMaxLen: fixed input sequence length (also stored as ``maxLen``).
    """

    def __init__(self,
                 classNum=2,
                 preModelPath="chinese_L-12_H-768_A-12/",
                 learning_rate=0.1,
                 XMaxLen=5,
                 ):
        self.preModelPath = preModelPath
        self.learning_rate = learning_rate
        self.XMaxLen = XMaxLen
        self.classNum = classNum
        self.maxLen = XMaxLen  # alias of XMaxLen used by the model/encoding code
        self.buildVocab()
        self.tokenizer = bert.bert_tokenization.FullTokenizer(
            os.path.join(self.preModelPath, "vocab.txt"), do_lower_case=True)

        self.buildModel()

    def buildVocab(self):
        """Load vocab.txt into ``XVocabList`` (file order) and ``XVocabSize``."""
        with open(os.path.join(self.preModelPath, "vocab.txt"), "r", encoding="utf8") as vocabFile:
            self.XVocabList = [row.strip() for row in tqdm.tqdm(vocabFile)]
        self.XVocabSize = len(self.XVocabList)
        # Membership tests against the vocabulary list were O(V) per token;
        # a set keeps token filtering linear in the corpus size.
        self._XVocabSet = set(self.XVocabList)

    def removeUNK(self, seqList):
        """Return ``seqList`` with every token that is not in the vocabulary removed."""
        vocab = self._XVocabSet
        return [[wItem for wItem in row if wItem in vocab] for row in seqList]

    @staticmethod
    def _padOrTruncate(idRows, maxLen):
        """Right-pad each id list with 0, or cut it, to ``maxLen``.

        Returns an int32 array of shape ``(len(idRows), maxLen)``. Filling a
        preallocated array also works for rows of different lengths, which
        ``np.array`` on a ragged list rejects on NumPy >= 1.24.
        """
        out = np.zeros((len(idRows), maxLen), dtype=np.int32)
        for i, row in enumerate(idRows):
            row = list(row)[:maxLen]
            out[i, :len(row)] = row
        return out

    def _encode(self, X):
        """Convert token sequences into the padded id matrix the model consumes."""
        # Unknown tokens are filtered out first because convert_tokens_to_ids
        # looks every token up in the vocabulary.
        idRows = [self.tokenizer.convert_tokens_to_ids(row) for row in self.removeUNK(X)]
        return self._padOrTruncate(idRows, self.maxLen)

    def _compile(self):
        """(Re)compile the model with a fresh RMSprop optimizer.

        Must be called again after changing a layer's ``trainable`` flag;
        Keras only picks up that change on compile().
        """
        self.model.compile(loss="SparseCategoricalCrossentropy",
                           optimizer=tf.keras.optimizers.RMSprop(learning_rate=self.learning_rate))

    def buildModel(self):
        """Build ids -> BERT -> Flatten -> softmax Dense, load BERT weights, compile."""
        inputLayer = keras.layers.Input(shape=(self.maxLen,), dtype='int32')

        bert_params = bert.params_from_pretrained_ckpt(self.preModelPath)
        # Named "bert" so fit() can find it with get_layer() to freeze it.
        bertLayer = bert.BertModelLayer.from_params(bert_params, name="bert")
        bertOutput = bertLayer(inputLayer)

        # Flatten the per-position encoder outputs into a single feature vector.
        flattenLayer = keras.layers.Flatten()(bertOutput)
        outputLayer = keras.layers.Dense(
            self.classNum, activation="softmax")(flattenLayer)

        self.model = keras.models.Model(inputLayer, outputLayer)
        # from_params() only builds the architecture from the checkpoint
        # config; without this call the encoder starts from random weights.
        # NOTE(review): assumes the standard Google release layout
        # (bert_model.ckpt.*) inside preModelPath.
        bert.load_stock_weights(bertLayer, os.path.join(self.preModelPath, "bert_model.ckpt"))
        self._compile()

    def fit(self, X, y, epochs=1, batch_size=64):
        """Train in two phases.

        Phase 1 (up to 3 epochs) fine-tunes the whole network. Phase 2 (the
        remaining epochs) freezes the BERT layer and trains only the head,
        with a 3x larger batch and early stopping on the training loss.

        Args:
            X: token sequences (lists of tokens).
            y: integer class ids (anything ``np.asarray(...).astype`` accepts).
            epochs: total number of epochs across both phases.
            batch_size: phase-1 batch size; phase 2 uses ``3 * batch_size``.
        """
        X = self._encode(X)
        y = np.asarray(y).astype(np.int32)

        bertLayer = self.model.get_layer("bert")
        # A previous fit() call may have left BERT frozen; phase 1 trains it.
        if not bertLayer.trainable:
            bertLayer.trainable = True
            self._compile()

        # Never run more pre-training epochs than were requested in total.
        preEpochs = min(3, epochs)
        if preEpochs > 0:
            self.model.fit(X, y, epochs=preEpochs, batch_size=batch_size)

        restEpochs = epochs - preEpochs
        if restEpochs <= 0:
            return

        bertLayer.trainable = False
        self._compile()  # required for the freeze to take effect

        es = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=15)
        self.model.fit(X, y, epochs=restEpochs,
                       batch_size=3 * batch_size, callbacks=[es])

    def predict(self, X):
        """Return softmax class probabilities, shape (len(X), classNum), for token sequences X."""
        return self.model.predict(self._encode(X))


# Model training and evaluation
print("loading data ...")
# NOTE(review): `dataset` is expected to come from the omitted data-cleaning step.
blogDf = dataset

blogDf.dropna(inplace=True)
# Remove every space from the raw text.
blogDf["text"] = blogDf["content_extract"].apply(lambda row: row.replace(" ", "").strip())

# Merge the traditional-Chinese spelling of this tab into the simplified one.
blogDf.loc[blogDf["tab"] == "社團", "tab"] = "社团"

blogDf["class"] = blogDf["tab"]

print("restructure Y (could be tagged) ...")
# Sort the labels so the label -> id mapping is identical on every run;
# iteration order of a set of strings depends on the hash seed.
classList = sorted(set(blogDf["class"].values.tolist()))
classNum = len(classList)
classIndex = {label: idx for idx, label in enumerate(classList)}
blogDf["class"] = blogDf["class"].map(classIndex)

print("prefix and suffix ...")
# Character-level tokens wrapped in BERT's sentence markers.
blogDf["text"] = blogDf["text"].apply(lambda row: ["[CLS]"] + list(row) + ["[SEP]"])

blogDf = blogDf[["text", "class"]]
blog = blogDf.values
print(blog)

print("splitting train/test ...")
trainX, testX, trainY, testY = train_test_split(blog[:, 0], blog[:, 1], test_size=0.3)

print("building model ...")
seqList = blog[:, 0].tolist()
# NOTE: despite the name, this is the MEAN token count; longer samples are
# truncated to it when the classifier encodes its input.
maxLen = int(np.mean([len(row) for row in seqList]))
print("max_len:", maxLen)
myModel = MyBERTClassier(classNum, XMaxLen=maxLen, learning_rate=0.0001)
# summary() prints by itself and returns None, so it is not wrapped in print().
myModel.model.summary()

print("training model ...")
myModel.fit(trainX, trainY, epochs=7, batch_size=32)

print("testing model ...")
preY = myModel.predict(testX)
print("test f1:", f1_score(testY.astype(np.int32), np.argmax(preY, axis=-1), average="macro"))

以上是关于BERT文本分类实践的主要内容。如果未能解决你的问题，请参考以下文章：

Bert文本分类实践

BERT 预训练模型及文本分类

广告行业中那些趣事系列6:BERT线上化ALBERT优化原理及项目实践(附github)

BERT-多标签文本分类实战之一——实战项目总览

BERT-多标签文本分类实战之一——实战项目总览

BERT-多标签文本分类实战之一——实战项目总览