使用/指定 attention_mask 使用 Trainer 和 TrainingArguments 训练 GPT2

Posted

技术标签:

【中文标题】使用/指定 attention_mask 使用 Trainer 和 TrainingArguments 训练 GPT2【英文标题】:Train GPT2 with Trainer & TrainingArguments using/specifying attention_mask 【发布时间】:2021-07-21 02:45:23 【问题描述】:

我正在使用 Trainer & TrainingArguments 来训练 GPT2 模型,但这似乎效果不佳。

我的数据集有我的语料库标记的 id 和每个文本的掩码,以指示在哪里应用注意力:

Dataset(
features: ['attention_mask', 'input_ids', 'labels'],
num_rows: 2012860
))

我正在使用 Trainer & TrainingArguments 进行培训,传递我的模型和我之前的数据集,如下所示。但是我没有在任何地方指定关于 attention_mask 的任何内容:

training_args = TrainingArguments(
output_dir=path_save_checkpoints,
overwrite_output_dir=True,
num_train_epochs=1,
per_device_train_batch_size = 4,
gradient_accumulation_steps = 4,
logging_steps = 5_000, save_steps=5_000,
fp16=True,
deepspeed="ds_config.json",
remove_unused_columns = True,
debug = True
)

trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=dataset,
tokenizer=tokenizer,
)

trainer.train()

我应该如何告诉 Trainer 使用此功能 (attention_mask)? 如果您查看文件 /transformers/trainer.py,则没有提及“注意”或“掩码”。

提前致谢!

【问题讨论】:

【参考方案1】:

在源代码的某处,您会看到输入被传递给模型,就像这样

outputs = model(**inputs)

只要您的整理器返回包含 attention_mask 键的字典,您的注意力掩码就会传递给您的 GPT2 模型。

【讨论】:

以上是关于使用/指定 attention_mask 使用 Trainer 和 TrainingArguments 训练 GPT2的主要内容,如果未能解决你的问题,请参考以下文章

使用 pytorch 闪电进行多 GPU 训练时出错

tput:在使用 Ruby Net:SSH 时没有指定 $TERM 的值并且没有指定 -T

T-SQL拆分使用指定分隔符的字符串(split string)

泛型约束

习题6-4 使用函数输出指定范围内的Fibonacci数 (20 分)

C#泛型约束