BERT text classification: code walkthrough and using accelerate
Posted by zhouzhou0929
Data format and preprocessing
Since I started working I have mostly relied on git-clone for code, so many of the details I either never learned or have forgotten. I recently looked at Hugging Face's accelerate and decided to pick a project to compare it against a plain PyTorch loop.
The data used is the Toutiao news dataset, in the following format:
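Each line holds one news item with fields separated by "_!_": news id, category code, category name, title, and keywords. The sample line below is illustrative rather than copied from the file; the preprocessing code uses field 3 (the title) as the text and field 2 (the category name) as the label.

6552368441838272771_!_101_!_news_culture_!_一个值得一去的博物馆_!_博物馆,文化,展览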
Process the data to produce training and validation CSVs plus the label mappings:
import json
import pandas as pd
from sklearn.model_selection import train_test_split
with open("data/raw_data/toutiao_cat_data.txt", "r", encoding="utf-8") as f:
data = f.readlines()
data = [i.strip() for i in data]
text = [i.split("_!_")[3] for i in data]
label = [i.split("_!_")[2] for i in data]
labels = list(set(label))
label2id = dict(zip(labels, range(len(labels))))
id2label = dict(zip(range(len(labels)), labels))
label = [label2id[i] for i in label]
with open("src/conf/label2id.json", "w", encoding="utf-8") as f:
json.dump(label2id, f, ensure_ascii=False)
with open("src/conf/id2label.json", "w", encoding="utf-8") as f:
json.dump(id2label, f, ensure_ascii=False)
X_train, X_valid, y_train, y_valid = train_test_split(text, label, test_size=0.2, random_state=9, shuffle=True)
df_train = pd.DataFrame({"text": X_train, "label": y_train})
df_valid = pd.DataFrame({"text": X_valid, "label": y_valid})
df_train.to_csv("data/train_data/train_0225.csv", sep=",", index=False)
df_valid.to_csv("data/train_data/valid_0225.csv", sep=",", index=False)
Dataset
Subclass torch's Dataset:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, path):
        self.df = pd.read_csv(path, sep=",")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df["text"][idx]
        label = self.df["label"][idx]
        return text, label

train_set = TextDataset("data/train_data/train_0225.csv")
valid_set = TextDataset("data/train_data/valid_0225.csv")
DataLoader
Two things to note for the DataLoader: set num_workers roughly in line with the number of CPU cores, and use shuffle=True for the training set (shuffle=False for validation).
train_loader = DataLoader(train_set, batch_size=512, shuffle=True, num_workers=32)
valid_loader = DataLoader(valid_set, batch_size=512, shuffle=False, num_workers=32)
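With the default collate_fn, each batch comes out as a list of raw title strings plus a LongTensor of labels, which is why the training loop below tokenizes the text batch by batch. A quick illustrative check (not part of the original code):

text, label = next(iter(train_loader))
print(type(text), len(text))     # list of 512 raw strings
print(label.shape, label.dtype)  # torch.Size([512]), torch.int64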
Tokenizer
The tokenizer output contains three pieces of information:
1. input_ids: the text converted to token ids
2. token_type_ids: for sentence-pair inputs, 0/1 distinguishes the first and second sentence
3. attention_mask: 0/1 distinguishes the original text from the padding
inputs = tokenizer.batch_encode_plus(list(text), max_length=50, padding="max_length", truncation=True)
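The tokenizer itself is not instantiated in the snippet above; a minimal self-contained sketch, assuming the same bert-base-chinese checkpoint the model code below loads (the example sentences are made up):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
inputs = tokenizer.batch_encode_plus(["今天天气不错", "股市迎来大涨"], max_length=50, padding="max_length", truncation=True)
print(list(inputs.keys()))  # ['input_ids', 'token_type_ids', 'attention_mask']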
BERT input
Convert to long tensors before passing them to BERT:
input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
token_type_ids = torch.tensor(inputs["token_type_ids"], dtype=torch.long)
attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)
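The model these tensors are fed to in the next section is the bare encoder; a minimal way to load it, assuming the same checkpoint as the tokenizer:

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-chinese")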
BERT output
BERT has two outputs:
1. pooler_output: the output at the [CLS] position passed through a linear layer and a tanh activation, shape [batch x 768]
2. last_hidden_state: the output of the last hidden layer, shape [batch x max_length x 768]
outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
print(outputs["pooler_output"].size())
print(outputs["last_hidden_state"].size())
torch.Size([4, 768])
torch.Size([4, 12, 768])
For classification tasks you generally take pooler_output and add a linear layer to map it to the number of classes.
For NER tasks you take last_hidden_state with the [CLS] position removed, as sketched below.
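A rough sketch of that slicing (num_ner_labels is a placeholder; NER is not covered further in this post):

ner_head = nn.Linear(768, num_ner_labels)               # hypothetical per-token classification head
token_states = outputs["last_hidden_state"][:, 1:, :]   # drop the [CLS] position
ner_logits = ner_head(token_states)                     # [batch x (max_length - 1) x num_ner_labels]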
The full BERT classification model code is as follows:
import json
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

class BertClass(nn.Module):
    def __init__(self):
        super(BertClass, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        with open("src/conf/label2id.json", "r", encoding="utf-8") as f:
            self.label2id = json.load(f)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
        self.model = BertModel.from_pretrained("bert-base-chinese").to(self.device)
        self.linear = nn.Linear(768, len(self.label2id)).to(self.device)  # map to the number of classes (15)

    def forward(self, inputs):
        input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).to(self.device)
        token_type_ids = torch.tensor(inputs["token_type_ids"], dtype=torch.long).to(self.device)
        attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).to(self.device)
        outputs = self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        outputs = self.linear(outputs["pooler_output"])
        return outputs
Training
from tqdm import tqdm
from torch.optim import AdamW  # assumption: AdamW could equally be imported from transformers

model = BertClass()

def train(epochs):
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=5e-5)
    for i in range(epochs):
        acc = 0
        for text, label in tqdm(train_loader):
            optimizer.zero_grad()
            label = torch.tensor(label, dtype=torch.long).to(model.device)
            inputs = model.tokenizer.batch_encode_plus(list(text), max_length=50, padding="max_length", truncation=True)
            outs = model(inputs)
            batch_out = outs.detach()  # gradient-free copy used only for accuracy bookkeeping
            acc += (batch_out.argmax(1) == label).float().sum().item()
            loss = criterion(outs, label)
            loss.backward()
            optimizer.step()
        # the last batch may hold fewer than 512 samples, so this slightly underestimates accuracy
        print(f"acc for epoch{i + 1} is", acc / (len(train_loader) * 512))
        valid()
def valid():
    model.eval()  # disable dropout for evaluation
    with torch.no_grad():
        acc = 0
        for text, label in tqdm(valid_loader):
            label = torch.tensor(label, dtype=torch.long).to(model.device)
            inputs = model.tokenizer.batch_encode_plus(list(text), max_length=50, padding="max_length", truncation=True)
            outs = model(inputs)
            acc += (outs.argmax(1) == label).float().sum().item()
        print("acc for valid is", acc / (len(valid_loader) * 512))
    model.train()

train(6)
Accuracy and training time: validation acc 0.89, and one training epoch took 4:53.
100%|█████████████████████████████████████████| 598/598 [04:53<00:00, 2.04it/s]
acc for epoch1 is 0.9197226431856187
100%|█████████████████████████████████████████| 150/150 [00:42<00:00, 3.54it/s]
acc for valid is 0.8920833333333333
Refactoring with accelerate
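Until the follow-up is written, here is a rough sketch of the usual accelerate recipe, not the final refactor: wrap the model, optimizer and dataloader with prepare() and replace loss.backward() with accelerator.backward(loss). Note that BertClass above moves tensors to its own device internally, which a real refactor would likely hand over to accelerate.

from accelerate import Accelerator

accelerator = Accelerator()
model = BertClass()
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=5e-5)
# prepare() handles device placement and, when launched distributed, wraps the model;
# in that case reach the tokenizer via accelerator.unwrap_model(model)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

for text, label in train_loader:
    optimizer.zero_grad()
    inputs = model.tokenizer.batch_encode_plus(list(text), max_length=50, padding="max_length", truncation=True)
    outs = model(inputs)
    loss = criterion(outs, label.to(outs.device))
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()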
To be continued...