BERT text classification using pytorch

【Title】: BERT text classification using pytorch 【Posted】: 2021-11-13 21:08:19 【Question】:

I am trying to build a BERT model for text classification with the help of this code [https://towardsdatascience.com/bert-text-classification-using-pytorch-723dfb8b6b5b]. My dataset contains two columns (label, text). The label can take three values (0, 1, 2). The code runs without any errors, but all the values of the confusion matrix are 0. Is there something wrong with my code?

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torchtext.data import Field, TabularDataset, BucketIterator, Iterator
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification
import torch.optim as optim
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

import seaborn as sns

torch.manual_seed(42)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

MAX_SEQ_LEN = 128
PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)


label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float)
text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False, include_lengths=False, batch_first=True, fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_token=UNK_INDEX)
fields = [('label', label_field), ('text', text_field)]
CLASSIFICATION_REPORT = "classification_report.jsonl"


train, valid, test = TabularDataset.splits(path='', train='train.csv', validation='validate.csv', test='test.csv', format='CSV', fields=fields, skip_header=True)

train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.text), device=device, train=True, sort=True, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.text), device=device, train=True, sort=True, sort_within_batch=True)
test_iter = Iterator(test, batch_size=16, device=device, train=False, shuffle=False, sort=False)

class BERT(nn.Module):
        def __init__(self):
                super(BERT, self).__init__()
                options_name = "bert-base-uncased"
                self.encoder = BertForSequenceClassification.from_pretrained(options_name, num_labels = 3)

        def forward(self, text, label):
                loss, text_fea = self.encoder(text, labels=label)[:2]
                return loss, text_fea

def train(model, optimizer, criterion = nn.BCELoss(), train_loader = train_iter, valid_loader = valid_iter, num_epochs = 5, eval_every = len(train_iter) // 2, file_path = None, best_valid_loss = float("Inf")):
        running_loss = 0.0
        valid_running_loss = 0.0
        global_step = 0
        train_loss_list = []
        valid_loss_list = []
        global_steps_list = []

        model.train()

        for epoch in range(num_epochs):
                for (label, text), _ in train_loader:
                        label = label.type(torch.LongTensor)
                        label = label.to(device)
                        text = text.type(torch.LongTensor)
                        text = text.to(device)
                        output = model(text, label)
                        loss, _ = output
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()
                        global_step += 1
                        if global_step % eval_every == 0:
                                model.eval()
                                with torch.no_grad():
                                        for (label, text), _ in valid_loader:
                                                label = label.type(torch.LongTensor)
                                                label = label.to(device)
                                                text = text.type(torch.LongTensor)
                                                text = text.to(device)
                                                output = model(text, label)
                                                loss, _ = output
                                                valid_running_loss += loss.item()

                                average_train_loss = running_loss / eval_every
                                average_valid_loss = valid_running_loss / len(valid_loader)
                                train_loss_list.append(average_train_loss)
                                valid_loss_list.append(average_valid_loss)
                                global_steps_list.append(global_step)


                                # resetting running values
                                running_loss = 0.0
                                valid_running_loss = 0.0
                                model.train()

                                # print progress
                                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
                                      .format(epoch+1, num_epochs, global_step, num_epochs*len(train_loader), average_train_loss, average_valid_loss))
                                if best_valid_loss > average_valid_loss:
                                        best_valid_loss = average_valid_loss
        print('Finished Training!')

model = BERT().to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-5)

train(model=model, optimizer=optimizer)


def evaluate(model, test_loader):
        y_pred = []
        y_true = []
        model.eval()
        with torch.no_grad():
                for (label, text), _ in test_loader:
                        label = label.type(torch.LongTensor)
                        label = label.to(device)
                        text = text.type(torch.LongTensor)
                        text = text.to(device)
                        output = model(text, label)

                        _, output = output
                        y_pred.extend(torch.argmax(output, 2).tolist())
                        y_true.extend(label.tolist())
        print('Classification Report:')
        print(classification_report(y_true, y_pred, labels=[0,1,2], digits=4))
best_model = BERT().to(device)
evaluate(best_model, test_iter)

【Comments】:

Can you paste your confusion matrix?

【Answer 1】:

You are using criterion = nn.BCELoss(), i.e. binary cross-entropy, for a multi-class classification problem ("The label can take three values (0, 1, 2)"). Use a loss function suited to multi-class classification instead.
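For a 3-class problem, nn.CrossEntropyLoss is the usual choice: it expects raw logits of shape [batch_size, num_classes] and integer class labels. A minimal sketch (the tensor shapes and values below are illustrative assumptions, not taken from the question):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()        # multi-class loss; expects raw logits, no softmax

logits = torch.randn(16, 3)              # hypothetical model output: [batch_size, num_classes]
labels = torch.randint(0, 3, (16,))      # hypothetical integer labels in {0, 1, 2}
loss = criterion(logits, labels)

Note that BertForSequenceClassification already applies exactly this loss internally when num_labels=3 and integer labels are passed, which is also why the criterion argument of train() in the posted code is never actually used.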

【Discussion】:

Your answer could be improved with additional supporting information. Please edit to add further details, such as citations or documentation, so that others can confirm that your answer is correct. You can find more information on how to write good answers in the help center.
