How to test a model before fine-tuning in PyTorch Lightning?
Posted: 2021-11-13 20:27:42

Problem description: Working on Google Colab.
transformers: 4.10.2
pytorch-lightning: 1.2.7

import glob
import random

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertJapaneseTokenizer, BertForSequenceClassification
import pytorch_lightning as pl
dataset_for_loader = [
    {'data': torch.tensor([0, 1]), 'labels': torch.tensor(0)},
    {'data': torch.tensor([2, 3]), 'labels': torch.tensor(1)},
    {'data': torch.tensor([4, 5]), 'labels': torch.tensor(2)},
    {'data': torch.tensor([6, 7]), 'labels': torch.tensor(3)},
]
loader = DataLoader(dataset_for_loader, batch_size=2)
for idx, batch in enumerate(loader):
    print(f'# batch {idx}')
    print(batch)
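# With batch_size=2, the default collate function stacks the dicts, so the
# loop above should print two batches shaped like:
#   # batch 0
#   {'data': tensor([[0, 1], [2, 3]]), 'labels': tensor([0, 1])}
#   # batch 1
#   {'data': tensor([[4, 5], [6, 7]]), 'labels': tensor([2, 3])}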
category_list = [
'dokujo-tsushin',
'it-life-hack',
'kaden-channel',
'livedoor-homme',
'movie-enter',
'peachy',
'smax',
'sports-watch',
'topic-news'
]
# MODEL_NAME is not defined in the snippet; presumably a Japanese BERT
# checkpoint such as 'cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
max_length = 128
dataset_for_loader = []
for label, category in enumerate(tqdm(category_list)):
    # the ./text directory has lots of articles, categorized by category,
    # and they are just plain texts whose content begins from the fourth line
    for file in glob.glob(f'./text/{category}/{category}*'):
        lines = open(file).read().splitlines()
        text = '\n'.join(lines[3:])
        encoding = tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True
        )
        encoding['labels'] = label
        encoding = {k: torch.tensor(v) for k, v in encoding.items()}
        dataset_for_loader.append(encoding)
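# After the loop, each element should look like (assuming max_length=128):
# {'input_ids': tensor of shape (128,), 'token_type_ids': tensor of shape (128,),
#  'attention_mask': tensor of shape (128,), 'labels': scalar tensor}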
SEED = lambda: 0.0
# random.shuffle(dataset_for_loader)  # shuffle randomly
random.shuffle(dataset_for_loader, SEED)  # deterministic shuffle (note: this second argument was removed in Python 3.11)
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)
dataset_train = dataset_for_loader[:n_train]
dataset_val = dataset_for_loader[n_train:n_train+n_val]
dataset_test = dataset_for_loader[n_train+n_val:]
dataloader_train = DataLoader(
    dataset_train, batch_size=32, shuffle=True
)
dataloader_val = DataLoader(dataset_val, batch_size=256)
dataloader_test = DataLoader(dataset_test, batch_size=256)
class BertForSequenceClassification_pl(pl.LightningModule):
    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        self.save_hyperparameters()
        self.bert_sc = BertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels
        )

    def training_step(self, batch, batch_idx):
        output = self.bert_sc(**batch)
        loss = output.loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        output = self.bert_sc(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss)

    def test_step(self, batch, batch_idx):
        labels = batch.pop('labels')
        output = self.bert_sc(**batch)
        labels_predicted = output.logits.argmax(-1)
        num_correct = (labels_predicted == labels).sum().item()
        accuracy = num_correct / labels.size(0)
        self.log('accuracy', accuracy)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model/',
)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=10,
    callbacks=[checkpoint]
)
model = BertForSequenceClassification_pl(
    MODEL_NAME, num_labels=9, lr=1e-5
)
### (a) ###
# I think this is where I am doing fine-tuning
trainer.fit(model, dataloader_train, dataloader_val)
# this is to score after fine-tuning
test = trainer.test(test_dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.2f}')
But I am not quite sure how to run the test before fine-tuning, so that I can compare the two models before and after fine-tuning and show the effect of the fine-tuning.

Inserting the following two lines at ### (a) ###:
test = trainer.test(test_dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.2f}')
I got this result:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-13-c8b2c67f2d5c> in <module>()
9
10 # 6-19
---> 11 test = trainer.test(test_dataloaders=dataloader_test)
12 print(f'Accuracy: {test[0]["accuracy"]:.2f}')
13
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in test(self, model, test_dataloaders, ckpt_path, verbose, datamodule)
896 self.verbose_test = verbose
897
--> 898 self._set_running_stage(RunningStage.TESTING, model or self.lightning_module)
899
900 # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _set_running_stage(self, stage, model_ref)
563 the trainer and the model
564 """
--> 565 model_ref.running_stage = stage
566 self._running_stage = stage
567
AttributeError: 'NoneType' object has no attribute 'running_stage'
I noticed that Trainer.fit() can take None as arguments other than model, so I tried this:
trainer.fit(model)
test=trainer.test(test_dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.2f}')
Result:
MisconfigurationException: No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.
Thanks.
Answer 1:

The Trainer needs its .fit() to be called in order to set up a lot of things, and only then can you do .test() or the other methods.

Putting a .fit() before the .test() is right, but the fit call needs to be a valid one: you have to provide it with dataloaders / a datamodule. Since you don't want to train/validate in this fit call, just pass limit_[train/val]_batches=0 when constructing the Trainer.
trainer = Trainer(gpus=..., ..., limit_train_batches=0, limit_val_batches=0)
trainer.fit(model, dataloader_train, dataloader_val)
trainer.test(model, dataloader_test) # without fine-tuning
The fit call here will only do the setup for you and skip training/validation, and then the test runs. Next, run the same code without limit_[train/val]_batches, and this will do the fine-tuning for you:
trainer = Trainer(gpus=..., ...)
trainer.fit(model, dataloader_train, dataloader_val)
trainer.test(model, dataloader_test) # with fine-tuning
To clarify the point about .fit() taking None for everything other than the model: that is not quite right. You have to provide either DataLoaders or a DataModule.
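As a minimal sketch of the DataModule alternative (the LivedoorDataModule name is made up here, and it simply wraps the loaders already built in the question):

class LivedoorDataModule(pl.LightningDataModule):
    def __init__(self, train_dl, val_dl, test_dl):
        super().__init__()
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.test_dl = test_dl

    def train_dataloader(self):
        return self.train_dl

    def val_dataloader(self):
        return self.val_dl

    def test_dataloader(self):
        return self.test_dl

datamodule = LivedoorDataModule(dataloader_train, dataloader_val, dataloader_test)
trainer.fit(model, datamodule=datamodule)
trainer.test(datamodule=datamodule)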
Comments:
When I .test() now, it fails with: MisconfigurationException: ckpt_path is "best", but ModelCheckpoint is not configured to save the best model. But .test(..., ckpt_path=None) seems to work, though I am not sure whether that is really correct.
The .test() call without fine-tuning must have ckpt_path=None. For the .test() call after fine-tuning, if it raises an error, try passing the path of the fine-tuned model it saved to ckpt_path.