tokenizer.max len 在这个类定义中做了啥?

Posted

技术标签:

【中文标题】tokenizer.max len 在这个类定义中做了啥?【英文标题】:What is tokenizer.max len doing in this class definition?tokenizer.max len 在这个类定义中做了什么? 【发布时间】:2021-06-28 06:42:58 【问题描述】:

我正在按照 Rostylav 的教程找到 here,但遇到了一个我不太明白的错误:

AttributeError                            
Traceback (most recent call last)
<ipython-input-22-523c0d2a27d3> in <module>()
----> 1 main(trn_df, val_df)

<ipython-input-20-1f17c050b9e5> in main(df_trn, df_val)
     59     # Training
     60     if args.do_train:
---> 61         train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)
     62 
     63         global_step, tr_loss = train(args, train_dataset, model, tokenizer)

<ipython-input-18-3c4f1599e14e> in load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate)
     40 
     41 def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):
---> 42     return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)
     43 
     44 def set_seed(args):

<ipython-input-18-3c4f1599e14e> in __init__(self, tokenizer, args, df, block_size)
      8     def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):
      9 
---> 10         block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence)
     11 
     12         directory = args.cache_dir

AttributeError: 'GPT2TokenizerFast' object has no attribute 'max_len'

这是我认为导致错误的类,但是我无法理解 Tokenize.max_len 应该做什么,因此我可以尝试修复它:

   class ConversationDataset(Dataset):

    def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):

        block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence)

        directory = args.cache_dir
        cached_features_file = os.path.join(
            directory, args.model_type + "_cached_lm_" + str(block_size)
        )

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", directory)

            self.examples = []
            for _, row in df.iterrows():
                conv = construct_conv(row, tokenizer)
                self.examples.append(conv)

            logger.info("Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item], dtype=torch.long)
 
# Cacheing and storing of data/checkpoints

def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):
    return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)

感谢您的阅读!

【问题讨论】:

【参考方案1】:

属性max_len 是migrated 到model_max_length。它表示模型可以处理的最大令牌数(即包括特殊令牌)(documentation)。

另一边的max_len_single_sentence表示单个句子可以拥有的最大标记数(即没有特殊标记)(documentation)。

【讨论】:

你是一个绅士和一个学者,它立即解决了它;我也非常感谢您链接的信息,谢谢!

以上是关于tokenizer.max len 在这个类定义中做了啥?的主要内容,如果未能解决你的问题,请参考以下文章

ES之RestAPI实现自动补全

李宏毅 机器学习 p5学习 笔记

李宏毅 机器学习 p5学习 笔记

python进阶-- 04 如何定制类

定义一个类rectangle,描述一个矩形,包含有长、宽两种属性,以及计算面积的方法;

Python类的内置方法