在掩蔽语言建模期间掩蔽每个输入句子中的特定标记
Posted
技术标签:
【中文标题】在掩蔽语言建模期间掩蔽每个输入句子中的特定标记【英文标题】:Masking specific token in each input sentence during Masked language modelling 【发布时间】:2021-12-04 17:55:29 【问题描述】:我有一个包含 2 列的数据集:token, sentence
。例如:
'token':'shrouded', 'sentence':'A mist shrouded the sun'
我想在 Masked Language Modeling 任务中微调其中一个 Huggingface Transformers 模型。 (现在我按照this 教程使用distilroberta-base
)
现在,我尝试在训练时专门屏蔽sentence
中的token
,而不是随机屏蔽。例如。 A mist [MASK] the sun
然后得到模型来预测token shrouded
。
现在我了解到,在随机屏蔽中,我们可以简单地使用 DataCollatorForLanguageModeling
并将其输入到 Trainer
。但是,在此用例中,必须在预处理阶段进行屏蔽。我不知道该怎么做。
这是目前为止的代码:
...
datasets = load_dataset('csv', data_files=['word_sentence_1.csv'])
model_checkpoint = "distilroberta-base"
def tokenize_function(examples):
return tokenizer(examples["sentence"])
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
model_name = model_checkpoint.split("/")[-1]
training_args = TrainingArguments(
f"model_name-word_sentence_1_1",
evaluation_strategy = "epoch",
learning_rate=2e-5,
weight_decay=0.01,
push_to_hub=False,
)
##### Need to remove this and add logic of static masking ####
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets,
data_collator=data_collator,
)
trainer.train()
【问题讨论】:
【参考方案1】:我自己解决了这个问题:
这个想法是手动进行遮罩,同时为遮罩提供“labels
”。我所要做的就是在tokenize_function
中进行一些更改并删除data_collator
。
MASK_TOKEN = tokenizer.convert_ids_to_tokens(tokenizer.mask_token_id)
MASK_TOKEN_ID = tokenizer.mask_token_id
def tokenize_function(examples):
usage_arr = [ examples['sentence'][i].replace(examples['word'][i], MASK_TOKEN) for i in range(len(examples['word']))]
tokenized_data = tokenizer(usage_arr, padding="max_length", truncation=True)
label_arr_list = []
for i in range(len(usage_arr)):
label_arr = [-100] * len(tokenized_data.input_ids[i])
if MASK_TOKEN_ID in tokenized_data.input_ids[i]:
label_arr[tokenized_data.input_ids[i].index(MASK_TOKEN_ID)] = tokenizer.convert_tokens_to_ids(examples['word'][i])
label_arr_list.append(label_arr)
tokenized_data['labels'] = label_arr_list
return tokenized_data
【讨论】:
以上是关于在掩蔽语言建模期间掩蔽每个输入句子中的特定标记的主要内容,如果未能解决你的问题,请参考以下文章