如何获得 T5 变压器模型的可重现结果
Posted
技术标签:
【中文标题】如何获得 T5 变压器模型的可重现结果【英文标题】:How to get reproducible results of T5 transformer model 【发布时间】:2021-02-26 13:33:59 【问题描述】:我正在尝试获得 T5 变压器模型的可重现结果:
import torch
from transformers import T5ForConditionalGeneration,T5Tokenizer
def set_seed(seed):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
set_seed(42)
t5model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_paraphraser')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
device = torch.device("cpu")
print ("device ",device)
t5model = t5model.to(device)
max_len = 256
text = "paraphrase: " + txt + " </s>"
encoding = tokenizer.encode_plus(text,pad_to_max_length=True, return_tensors="pt")
input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
beam_outputs = t5model.generate(
input_ids=input_ids, attention_mask=attention_masks,
do_sample=True,
max_length=max_len,
top_k=50,
top_p=0.98,
early_stopping=True,
num_return_sequences=10,
)
虽然我设置了种子编号,但t5model.generate
每次运行时都会给我不同的结果。
设置种子号的正确方法是什么,以便在多次执行后得到t5model.generate
的相同结果?
【问题讨论】:
【参考方案1】:您需要重新加载模型的 state_dict 以每次产生相同的输出。
这里发生的情况是,T5 模型初始化正在调用 pytorch 随机数生成器。这意味着,每次运行以下代码时,都会得到相同的输出:
set_seed(42)
t5model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_paraphraser')
t5model = t5model.to(device)
beam_outputs = []
for x in range(3):
beam_outputs.append(t5model.generate(
input_ids=input_ids, attention_mask=attention_masks,
do_sample=True,
max_length=max_len,
top_k=50,
top_p=0.98,
early_stopping=True,
num_return_sequences=5,
))
tokenizer.batch_decode([y for x in beam_outputs for y in x])
设置随机数生成器的种子并不意味着每次调用它都会生成相同的输出,而是意味着生成的数字序列由相同的种子初始化(查看此link 了解更多信息信息):
torch.manual_seed(42)
print(torch.randn(2))
print(torch.randn(2))
print(torch.randn(2))
torch.manual_seed(42)
print(torch.randn(2))
print(torch.randn(2))
print(torch.randn(2))
输出:
tensor([0.3367, 0.1288])
tensor([0.2345, 0.2303])
tensor([-1.1229, -0.1863])
tensor([0.3367, 0.1288])
tensor([0.2345, 0.2303])
tensor([-1.1229, -0.1863])
【讨论】:
以上是关于如何获得 T5 变压器模型的可重现结果的主要内容,如果未能解决你的问题,请参考以下文章