带有 wiki 文章的 GPT2 输入大小
Posted
技术标签:
【中文标题】带有 wiki 文章的 GPT2 输入大小【英文标题】:GPT2 input size with wiki articles 【发布时间】:2020-07-10 02:54:15 【问题描述】:我正在尝试在***文本上使用一个小型预训练 GPT2 模型。我尝试使用尽可能多的文本作为 gpt2 模型的输入。该模型为我总结了文本。如何使用整个 wiki 文章进行输入。就像现在一样,我仅限于 768 个令牌之类的东西。典型的 wiki 文章比这更长。使用超过 768 个标记的文本段落有什么技巧吗?
【问题讨论】:
【参考方案1】:是的,你可以这样做。
但您需要记住一件事,模型的input_shape
保持不变,因此您必须指定要为模型提供输入的maximum
序列长度。
现在回到你能做什么 -
如果改变模型的输入,后续层的输入形状也会改变。
您可以做的是复制模型架构及其预训练的权重,并使用所需的输入形状对其进行微调。无论您使用的是 PyTorch 还是任何其他框架都没有关系。只有input_shape
会根据您的要求进行更改。 PyTorch 将模型权重保存在 OrderedDict
中,您可以加载预训练模型并从那里复制权重。
示例:
model1 = TheModelClass(*args, **kwargs)
model1.load_state_dict(torch.load(PATH_TO_PRETRAINED_MODEL))
model1.eval()
model2 = TheNewModelClass(*args, **kwargs)
params1 = model1.named_parameters()
params2 = model2.named_parameters()
dict_params2 = dict(params2)
for name1, param1 in params1:
if name1 in dict_params2:
dict_params2[name1].data.copy_(param1.data)
您可以在此处的 PyTorch 论坛中找到一些很好的参考链接:
https://discuss.pytorch.org/t/copying-weights-from-one-net-to-another/1492
https://discuss.pytorch.org/t/copy-weights-only-from-a-networks-parameters/5841
【讨论】:
我正在使用 pytorch。在我看来,你在说它不能完成。我根本不想改变模型。 嗨!无论您使用的是 PyTorch 还是任何其他框架都没有关系。我根本不会告诉你改变模型。整个架构保持不变。只有输入根据您的要求更改。 PyTorch 将模型权重保存在OrderedDict
中,您可以加载模型并从那里复制权重。
如果您对复制权重有进一步的疑问,我也可以为您提供帮助。随意在下面发表评论。以上是关于带有 wiki 文章的 GPT2 输入大小的主要内容,如果未能解决你的问题,请参考以下文章