由于内存问题,如何保存与预训练的 bert 模型的分类器层相关的参数?
Posted
技术标签:
【中文标题】由于内存问题,如何保存与预训练的 bert 模型的分类器层相关的参数?【英文标题】:How to save parameters just related to classifier layer of pretrained bert model due to the memory concerns? 【发布时间】:2021-10-19 04:03:05 【问题描述】:我通过冻结除分类器层之外的所有层来微调预训练模型here。我使用 pytorch 将重量文件保存为 .bin 格式。
现在不是加载400mb的预训练模型,有没有办法加载我重新训练的刚刚分类器层的参数?顺便说一句,我知道我必须加载原始的预训练模型,我只是不想加载整个微调模型。由于内存问题。
我可以从 state_dict 访问最后一层的参数,如下所示,但是如何将它们保存在单独的文件中以便以后使用它们以减少内存使用?
model = PosTaggingModel(num_pos_tag=num_pos_tag)
state_dict = torch.load("model.bin")
print("state dictionary:",state_dict)
with torch.no_grad():
model.out_pos_tag.weight.copy_(state_dict['out_pos_tag.weight'])
model.out_pos_tag.bias.copy_(state_dict['out_pos_tag.bias'])
这是模型类:
class PosTaggingModel(nn.Module):
def __init__(self, num_pos_tag):
super(PosTaggingModel, self).__init__()
self.num_pos_tag = num_pos_tag
self.model = AutoModel.from_pretrained("dbmdz/bert-base-turkish-cased")
for name, param in self.model.named_parameters():
if 'classifier' not in name: # classifier layer
param.requires_grad = False
self.bert_drop = nn.Dropout(0.3)
self.out_pos_tag = nn.Linear(768, self.num_pos_tag)
def forward(self, ids, mask, token_type_ids, target_pos_tag):
o1, _ = self.model(ids, attention_mask = mask, token_type_ids = token_type_ids)
bo_pos_tag = self.bert_drop(o1)
pos_tag = self.out_pos_tag(bo_pos_tag)
loss = loss_fn(pos_tag, target_pos_tag, mask, self.num_pos_tag)
return pos_tag, loss
我不知道这是否可能,但我只是在寻找一种方法来保存和重用最后一层的参数,而不需要冻结层的参数。我在documentation 中找不到它。 提前感谢那些愿意提供帮助的人。
【问题讨论】:
【参考方案1】:你可以这样做
import torch
# creating a dummy model
class Classifier(torch.nn.Module):
def __init__(self):
super(Classifier, self).__init__()
self.first = torch.nn.Linear(10, 10)
self.second = torch.nn.Linear(10, 20)
self.last = torch.nn.Linear(20, 1)
def forward(self, x):
pass
# Creating its object
model = Classifier()
# Extracting the layer to save
to_save = model.last
# Saving the state dict of that layer
torch.save(to_save.state_dict(), './classifier.bin')
# Recreating the object of that model
model = Classifier()
# Updating the saved layer of model
model.last.load_state_dict(torch.load('./classifier.bin'))
【讨论】:
这正是我想要的。非常感谢。以上是关于由于内存问题,如何保存与预训练的 bert 模型的分类器层相关的参数?的主要内容,如果未能解决你的问题,请参考以下文章
使用领域文本预训练 BERT/RoBERTa 语言模型,估计需要多长时间?哪个更快?