PyTorch Notes - IMDB Dataset Text Classification: Model and Training

Posted by SpikeKing


IMDB dataset: downloadable from Kaggle; a collection of movie reviews labeled as positive or negative for sentiment classification.

The PyTorch dataset: torchtext.datasets.IMDB

# pip install torchdata torchtext
# the torchtext / torchdata versions must match the installed PyTorch version

from torchtext.datasets import IMDB
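
To get a feel for the data format, one sample can be pulled from the iterator; a minimal sketch (each sample is a (label, review-text) pair, and the label type depends on the torchtext version):

from torchtext.datasets import IMDB

train_iter = IMDB(root="./data", split="train")
label, text = next(iter(train_iter))
print(label)      # e.g. "pos"/"neg" or an integer label, depending on the torchtext version
print(text[:80])  # first 80 characters of the review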

IMDB text classification with a custom network:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import utils

import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB

from torchtext.datasets.imdb import NUM_LINES
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset

import os
import sys
import logging
logging.basicConfig(
    level=logging.WARN, stream=sys.stdout,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")

VOCAB_SIZE = 15000

# Step 1: define the GCNN model (gated convolutional network)
class GCNN(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):
        super(GCNN, self).__init__()
        
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        nn.init.xavier_uniform_(self.embedding_table.weight)
        
        # all 1-D convolutions
        self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)
        self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)
        
        self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)
        self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)
        
        self.output_linear1 = nn.Linear(64, 128)
        self.output_linear2 = nn.Linear(128, num_class)
        
    def forward(self, word_index):
        """
        定义GCN网络的算子操作流程,基于句子单词ID输入得到分类logits输出
        """
        # 1. 通过word_index得到word_embedding
        # word_index shape: [bs, max_seq_len]
        word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]
        
        # 2. first gated 1-D convolution block; Conv1d expects channels on dim 1
        word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]
        A = self.conv_A_1(word_embedding)
        B = self.conv_B_1(word_embedding)
        H = A * torch.sigmoid(B)  # gated linear unit, [bs, 64, L1] (the stride-7 conv shrinks the sequence)
        
        A = self.conv_A_2(H)
        B = self.conv_B_2(H)
        H = A * torch.sigmoid(B)  # [bs, 64, L2]
        
        # 3. pool over the time dimension and pass through the fully connected layers
        pool_output = torch.mean(H, dim=-1)  # mean pooling -> [bs, 64]
        linear1_output = self.output_linear1(pool_output)
        
        # the last linear layer maps the hidden features to the number of classes
        logits = self.output_linear2(linear1_output)  # [bs, 2]
        
        return logits
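
# Optional shape sanity check for GCNN (illustrative only; the batch size 4 and sequence length 800
# are arbitrary assumptions): each Conv1d(kernel_size=15, stride=7) shrinks the time dimension,
# mean pooling leaves [bs, 64], and the two linear layers map that to [bs, 2] logits.
_gcnn = GCNN()
_dummy_ids = torch.randint(0, VOCAB_SIZE, (4, 800))
print("GCNN logits shape:", _gcnn(_dummy_ids).shape)  # torch.Size([4, 2])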
        
        
# The simple baseline model from the PyTorch text-classification tutorial
class TextClassificationModel(nn.Module):
    """
    A simple EmbeddingBag + linear model.
    """
    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        
    def forward(self, token_index):
        # bag-of-words style pooling: EmbeddingBag averages the embedding of every token in each row
        embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]
        return self.fc(embedded)
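
# Optional check of the EmbeddingBag behaviour used above (illustrative; the dummy shapes are
# arbitrary assumptions): called with a 2D [bs, seq_len] index tensor and no offsets,
# nn.EmbeddingBag pools each row with its default mode="mean", so the model outputs [bs, num_class].
_tc_model = TextClassificationModel()
_dummy_ids = torch.randint(0, VOCAB_SIZE, (4, 100))
print("TextClassificationModel logits shape:", _tc_model(_dummy_ids).shape)  # torch.Size([4, 2])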
        

# Step 2: build the IMDB DataLoader
BATCH_SIZE = 64

def yield_tokens(train_data_iter, tokenizer):
    for i, sample in enumerate(train_data_iter):
        label, comment = sample
        yield tokenizer(comment)  # the tokenizer splits the review string into a list of token strings
        
train_data_iter = IMDB(root="./data", split="train")  # an iterable dataset object
tokenizer = get_tokenizer("basic_english")
# keep only tokens that appear at least 20 times
vocab = build_vocab_from_iterator(yield_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
vocab.set_default_index(0)  # unknown tokens map to the special index 0 ("<unk>")
print(f"Vocabulary size: {len(vocab)}")
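
# The vocab object maps token strings to integer IDs; tokens below min_freq are not in the
# vocabulary and fall back to the default index 0, i.e. "<unk>". A small illustration
# (the example sentence is made up):
print("index of '<unk>':", vocab["<unk>"])           # 0
print(vocab(tokenizer("This movie was qwertyzzz")))  # the last, unseen token maps to 0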

# Collate function: post-processes one mini-batch of (label, text) samples returned by the dataset
def collate_fn(batch):
    """
    Post-process a mini-batch produced by the DataLoader.
    """
    target = []
    token_index = []
    max_length = 0  # length of the longest token sequence in the batch
    for i, (label, comment) in enumerate(batch):
        tokens = tokenizer(comment)
        token_index.append(vocab(tokens))  # convert the token strings to vocabulary indices
        
        # track the longest sentence in the batch
        if len(tokens) > max_length:
            max_length = len(tokens)
        
        if label == "pos":  # assumes the dataset yields string labels "pos"/"neg"
            target.append(0)
        else:
            target.append(1)

    token_index = [index + [0]*(max_length-len(index)) for index in token_index]
    # F.one_hot needs int64 targets, so cast the labels to int64
    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))
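
# A toy illustration of collate_fn (the two example reviews are made up): labels become 0/1
# targets, and every token-ID list is right-padded with 0 to the longest sentence in the batch.
_targets, _ids = collate_fn([("pos", "a great movie"), ("neg", "boring and far too long")])
print(_targets, _ids.shape)  # tensor([0, 1]) torch.Size([2, 5])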


# Step 3: the training loop
def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval, save_step_interval,
          eval_step_interval, save_path, resume=""):
    """
    Both data loaders here wrap map-style datasets.
    """
    start_epoch = 0
    start_step = 0
    if resume != "":
        # load the parameters saved from a previous run
        logging.warning(f"loading from {resume}")
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['step']
        
    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):
        ema_loss = 0
        num_batches = len(train_data_loader)
        
        for batch_index, (target, token_index) in enumerate(train_data_loader):
            optimizer.zero_grad()
            step = num_batches*(epoch_index) + batch_index + 1
            logits = model(token_index)
            # the one-hot targets must be cast to float32 for binary_cross_entropy
            bce_loss = F.binary_cross_entropy(torch.sigmoid(logits), F.one_hot(target, num_classes=2).to(torch.float32))
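            # Note: F.binary_cross_entropy_with_logits(logits, one_hot_targets) computes the same
            # loss directly from raw logits and is numerically more stable than sigmoid + binary_cross_entropy.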
            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss.item()  # exponential moving average of the loss (.item() avoids retaining the graph)
            bce_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # clip the gradient norm to keep training stable
            optimizer.step()   # update the parameters
            
            if step % log_step_interval == 0:
                logging.warning(f"epoch_index: epoch_index, batch_index: batch_index, ema_loss: ema_loss")
                
            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
                torch.save({
                    "epoch": epoch_index,
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": bce_loss,
                }, save_file)
                logging.warning(f"checkpoint has been saved to {save_file}")
            
            if step % eval_step_interval == 0:
                logging.warning("start to do evaluation...")
                model.eval()
                ema_eval_loss = 0
                total_acc_account = 0
                total_account = 0
                for eval_batch_index, (eval_target, eval_token_index) in enumerate(eval_data_loader):
                    total_account += eval_target.shape[0]
                    eval_logits = model(eval_token_index)
                    total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()
                    eval_bce_loss = F.binary_cross_entropy(torch.sigmoid(eval_logits), F.one_hot(eval_target, num_classes=2).to(torch.float32))
                    ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss.item()
                logging.warning(f"ema_eval_loss: ema_eval_loss, eval_acc: total_acc_account / total_account")
                model.train()

# model = GCNN()
model = TextClassificationModel()
print("模型总参数:", sum(p.numel() for p in model.parameters()))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_data_iter = IMDB(root="data", split="train")  # Dataset类型的对象
train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

eval_data_iter = IMDB(root="data", split="test")  # iterable dataset, converted to map-style below
# reuse the same collate function
eval_data_loader = utils.data.DataLoader(
    to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)

# resume = "./data/step_500.pt"
resume = ""

train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20,
      save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_text_classification", resume=resume)
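
After training, a saved checkpoint can be reloaded for inference. A minimal sketch, assuming a checkpoint file such as ./log_imdb_text_classification/step_500.pt exists and reusing the tokenizer and vocab built above (the example review is made up):

model = TextClassificationModel()
checkpoint = torch.load("./log_imdb_text_classification/step_500.pt")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

review = "This film is a beautiful, moving piece of work."
token_ids = torch.tensor([vocab(tokenizer(review))])  # shape [1, seq_len]
with torch.no_grad():
    probs = torch.softmax(model(token_ids), dim=-1)
print(probs)  # index 0 corresponds to "pos" and index 1 to "neg", per the mapping in collate_fn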
