bert文本分类代码解析及accelerate使用

Posted zhouzhou0929

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了bert文本分类代码解析及accelerate使用相关的知识,希望对你有一定的参考价值。

数据格式及预处理

	工作以来代码基本靠git-clone,很多细节都不会或者忘了,最近看了hugging-face的accelerate,决定选个项目对比一下。
	数据选取的头条新闻数据,格式如下:


对数据进行处理,生成训练集和验证集csv,以及标签映射

import json
import pandas as pd
from sklearn.model_selection import train_test_split

with open("data/raw_data/toutiao_cat_data.txt", "r", encoding="utf-8") as f:
    data = f.readlines()
    data = [i.strip() for i in data]
    text = [i.split("_!_")[3] for i in data]
    label = [i.split("_!_")[2] for i in data]
    labels = list(set(label))
    label2id = dict(zip(labels, range(len(labels))))
    id2label = dict(zip(range(len(labels)), labels))
    label = [label2id[i] for i in label]

with open("src/conf/label2id.json", "w", encoding="utf-8") as f:
    json.dump(label2id, f, ensure_ascii=False)
with open("src/conf/id2label.json", "w", encoding="utf-8") as f:
    json.dump(id2label, f, ensure_ascii=False)

X_train, X_valid, y_train, y_valid = train_test_split(text, label, test_size=0.2, random_state=9, shuffle=True)
df_train = pd.DataFrame("text": X_train, "label":y_train)
df_valid = pd.DataFrame("text": X_valid, "label":y_valid)
df_train.to_csv("data/train_data/train_0225.csv", sep=",", index=False)
df_valid.to_csv("data/train_data/valid_0225.csv", sep=",", index=False)

Dataset

继承torch的Dataset类

from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
    def __init__(self, path):
        self.df = pd.read_csv(path, sep=",")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df["text"][idx]
        label = self.df["label"][idx]
        return text, label
train_set = TextDataset("data/train_data/train_0225.csv")
valid_set = TextDataset("data/train_data/valid_0225.csv")

Dataloader

Dataloader注意两个点:num_workers=cpu核心数合理,训练集shuffle=True,

train_loader = DataLoader(train_set, batch_size=512, shuffle=True, num_workers=32)
valid_loader = DataLoader(valid_set, batch_size=512, shuffle=False, num_workers=32)

Tokenizer

tokenizer之后的结果包含三部分信息:
1.input_ids 文本转id的结果
2.token_type_ids 如果输入是句子对,用0/1区分前后句
3.attention_mask 用0/1区分原文本和padding部分

inputs = tokenizer.batch_encode_plus(list(text), max_length=50, padding="max_length",  truncation=True)

bert输入

转为long tensor后传入bert

input_ids = torch.tensor(inputs["input_ids"],dtype=torch.long)
token_type_ids = torch.tensor(inputs["token_type_ids"],dtype=torch.long)
attention_mask = torch.tensor(inputs["attention_mask"],dtype=torch.long)

bert输出

bert输出有两部分:
1.pooler_output: 取cls对应的输出,加线性层和tanh激活函数后的输出[batch x 768]
2.last_hidden_state:最后隐层的输出[batch x max_length x 768]

outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
print(outputs["pooler_output"].size())
print(outputs["last_hidden_state"].size())

torch.Size([4, 768])
torch.Size([4, 12, 768])

对于分类任务一般取pooler_output,加线性层将结果映射到分类的类别数
ner任务取last_hidden_state去除cls的结果

完整的bert-class-model代码如下

class BertClass(nn.Module):
    def __init__(self):
        super(BertClass, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        with open("src/conf/label2id.json", "r", encoding="utf-8") as f:
            self.label2id = json.load(f)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
        self.model = BertModel.from_pretrained("bert-base-chinese").to(self.device)
        self.linear = nn.Linear(768, len(self.label2id)).to(self.device)  # 最终映射到的类别数15

    def forward(self, inputs):
        input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).to(self.device)
        token_type_ids = torch.tensor(inputs["token_type_ids"], dtype=torch.long).to(self.device)
        attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).to(self.device)
        outputs = self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        outputs = self.linear(outputs["pooler_output"])
        return outputs

训练

model = BertClass()

def train(epochs):
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=5e-5)
    for i in range(epochs):
        acc = 0
        for text, label in tqdm(train_loader):
            optimizer.zero_grad()
            label = torch.tensor(label, dtype=torch.long).to(model.device)
            inputs = model.tokenizer.batch_encode_plus(list(text), max_length=50, padding="max_length", truncation=True)
            outs = model(inputs)
            batch_out = outs.clone().detach().requires_grad_(False)
            acc += (batch_out.argmax(1) == label).float().sum().item()
            loss = criterion(outs, label)
            loss.backward()
            optimizer.step()
        print(f"acc for epochi is ", acc/(len(train_loader)*512))
        valid()

def valid():
    with torch.no_grad():
        acc = 0
        for text, label in tqdm(valid_loader):
            label = torch.tensor(label, dtype=torch.long).to(model.device)
            inputs = model.tokenizer.batch_encode_plus(list(text), max_length=50, padding="max_length", truncation=True)
            outs = model(inputs)
            batch_out = outs.clone().detach().requires_grad_(False)
            acc += (batch_out.argmax(1) == label).float().sum().item()
        print(f"acc for valid is ", acc/(len(valid_loader)*512))

train(6)

准确率和训练时间:验证集acc 0.89 ,训练一个epoch 花费4:53
100%|█████████████████████████████████████████| 598/598 [04:53<00:00, 2.04it/s]
acc for epoch1 is 0.9197226431856187
100%|█████████████████████████████████████████| 150/150 [00:42<00:00, 3.54it/s]
acc for valid is 0.8920833333333333

accelerate改造

未完待续···

以上是关于bert文本分类代码解析及accelerate使用的主要内容,如果未能解决你的问题,请参考以下文章

使用 pytorch 进行 BERT 文本分类

[Python人工智能] 三十三.Bert模型 keras-bert库构建Bert模型实现文本分类

Bert实战:使用Bert实现文本分类。

自然语言处理动手学Bert文本分类

自然语言处理动手学Bert文本分类

BERT实战:使用DistilBERT进行文本情感分类