日常轻松一刻:是谁偷改了我的参数名称?
Posted 囚生CY
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了日常轻松一刻:是谁偷改了我的参数名称?相关的知识,希望对你有一定的参考价值。
文章目录
前言
半夜coding写bug发现一个很冷僻的Python嵌套函数的坑点,先发出来给大家看看有没有遇到过类似的问题,暂时还不是很搞得清楚原理是什么。
不知道大家写嵌套函数的时候有没有加下划线作为变量名称前缀的习惯,笔者的习惯一直都是每嵌套一层,嵌套函数的名称及其所有的新变量都额外添加一个下划线作为前缀,这样的好处是可以确保内外变量名不会发生重复,从而防止嵌套函数内意外改动外层函数的变量值。
比如下面这种代码结构的写法就比较满足我的强迫症(凑字数):
def generate_dataloader(args, mode='train', do_export=False, pipeline='judgment', for_debug=False):
dataset = BasicDataset(args=args,
mode=mode,
do_export=do_export,
pipeline=pipeline,
for_debug=for_debug)
column = dataset.data.columns.tolist()
if mode.startswith('train'):
batch_size = args.train_batch_size
shuffle = True
if mode.startswith('valid'):
batch_size = args.valid_batch_size
shuffle = False
if mode.startswith('test'):
batch_size = args.test_batch_size
shuffle = False
def _collate_fn(_batch_data):
def __collate_id():
return [__data['id'] for __data in _batch_data]
def __collate_type():
return [__data['type'] for __data in _batch_data] # 这个其实没有什么用, 因为我已经把他给扔了
def __collate_subject():
return torch.LongTensor([__data['subject'] for __data in _batch_data]) # 2022/04/10 17:37:59 这个之后要用的话可能还得转为onehot
def __collate_label_choice():
return torch.LongTensor([__data['label_choice'] for __data in _batch_data]) # 0-15的选择题答案编码值: 数据类型Long
def __collate_label_judgment():
return torch.LongTensor([__data['label_judgment'] for __data in _batch_data]) # 零一的判断题答案编码值: 数据类型Long
# 2022/04/10 17:42:29 目前基本上已经告别judgment的pipeline了, 因此这个字段处理差不多算是废弃了
def __collate_option_id():
return [__data['option_id'] for __data in _batch_data] # 选项号需要记录进来: 限判断题
# 2022/04/10 15:16:07 词嵌入的处理, 分为顺序编码(数据类型为Long)与标准嵌入(数据类型为Float)
if args.word_embedding is None and args.document_embedding is None:
# 不使用词向量或文档向量的情况, 即使用顺序编号编码, 数据类型是long
def __collate_question():
return torch.LongTensor([__data['question'] for __data in _batch_data])
def __collate_reference():
return torch.LongTensor([__data['reference'] for __data in _batch_data])
def __collate_options():
# 选择题特有字段: 四个选项
return torch.LongTensor([__data['options'] for __data in _batch_data])
def __collate_option():
# 判断题特有字段: 一个选项
return torch.LongTensor([__data['option'] for __data in _batch_data])
else:
# 否则即使用向量转化, 此时转化为float类型
def __collate_question():
if isinstance(_batch_data[0]['question'], numpy.ndarray):
return torch.FloatTensor([__data['question'] for __data in _batch_data])
elif isinstance(_batch_data[0]['question'], torch.Tensor):
return torch.stack([__data['question'] for __data in _batch_data])
else:
raise NotImplementedError(type(_batch_data[0]['question']))
def __collate_reference():
if isinstance(_batch_data[0]['reference'], numpy.ndarray):
return torch.FloatTensor([__data['reference'] for __data in _batch_data])
elif isinstance(_batch_data[0]['reference'], torch.Tensor):
return torch.stack([__data['reference'] for __data in _batch_data])
else:
raise NotImplementedError(type(_batch_data[0]['reference']))
def __collate_options():
if isinstance(_batch_data[0]['options'], numpy.ndarray):
return torch.FloatTensor([__data['options'] for __data in _batch_data])
elif isinstance(_batch_data[0]['options'], torch.Tensor):
return torch.stack([__data['options'] for __data in _batch_data])
elif isinstance(_batch_data[0]['options'], list):
# 2021/12/27 22:35:27 目前只有options可能存在batch_data的每一个元素是一个列表, 该列表里面是四个选项的嵌入向量
# 2021/12/27 22:36:51 原因是options可能会涉及要转为judgment形式, 需要expand, 我担心如果不是list的话可能会失败, 因此只保留了options的list格式, reference之前也是list, 我已经转为numpy.ndarray了
if isinstance(_batch_data[0]['options'][0], numpy.ndarray):
# 2022/02/21 12:07:21 如果使用的是doc2vec, 则是这种情况, 因为doc2vec是针对段落直接进行编码
return torch.FloatTensor(numpy.stack([numpy.stack(__data['options']) for __data in _batch_data]))
elif isinstance(_batch_data[0]['options'][0], torch.Tensor):
# 2022/02/21 12:08:24 如果使用的是BERT或者word2vec, 则会在dataset相关处理中被转为torch.Tensor
return torch.stack([torch.stack(__data['options']) for __data in _batch_data])
else:
raise NotImplementedError(type(_batch_data[0]['options'][0]))
else:
raise NotImplementedError(type(_batch_data[0]['options']))
def __collate_option():
if isinstance(_batch_data[0]['option'], numpy.ndarray):
return torch.FloatTensor([__data['option'] for __data in _batch_data])
elif isinstance(_batch_data[0]['option'], torch.Tensor):
return torch.stack([__data['option'] for __data in _batch_data])
else:
raise NotImplementedError(type(_batch_data[0]['option']))
# 2022/04/09 10:26:18 增加词性特征和句法树的处理逻辑
if args.use_pos_tags:
def __collate_pos_tags(__column):
# 2022/05/18 23:22:09 为了对应padding字符的编号为0, 决定将填补从-1改为0, 其余所有编号依次+1
# 2022/05/18 23:40:40 __data[__column]就是类似[NR, P, NN, AD, VV, PU, VV, VV, PN, PU, CD, M,...]的一维列表
# return torch.LongTensor([list(map(lambda __pos_tag: STANFORD_POS_TAG_INDEX.get(__pos_tag, -1), __data[__column])) for __data in _batch_data])
return torch.LongTensor([list(map(lambda __pos_tag: STANFORD_POS_TAG_INDEX.get(__pos_tag, -1) + 1, __data[__column])) for __data in _batch_data])
def __collate_statement_pos_tags():
return __collate_pos_tags(__column='statement_pos_tags')
def __collate_option_a_pos_tags():
return __collate_pos_tags(__column='option_a_pos_tags')
def __collate_option_b_pos_tags():
return __collate_pos_tags(__column='option_b_pos_tags')
def __collate_option_c_pos_tags():
return __collate_pos_tags(__column='option_c_pos_tags')
def __collate_option_d_pos_tags():
return __collate_pos_tags(__column='option_d_pos_tags')
# 2022/05/18 23:31:44 参考书目文档的词性标注处理
if args.use_reference:
def __collate_reference_pos_tags():
# 2022/05/18 23:49:03 与__collate_pos_tags的区别在于这里的__data['reference_pos_tags']是二维列表, 第一维是参考段落的数量
return torch.LongTensor([[list(map(lambda ___pos_tag: STANFORD_POS_TAG_INDEX.get(___pos_tag, -1) + 1, ___pos_tags)) for ___pos_tags in __data['reference_pos_tags']] for __data in _batch_data])
# 2022/04/15 22:29:45 若使用句法树, 则必然使用词性标注, 因此嵌套在该循环中
if args.use_parse_tree:
def __collate_parse_tree(__column):
return [[parse_tree_to_graph(__parse_tree) for __parse_tree in __data[__column]] for __data in _batch_data]
def __collate_statement_tree():
return __collate_parse_tree(__column='statement_tree')
def __collate_option_a_tree():
return __collate_parse_tree(__column='option_a_tree')
def __collate_option_b_tree():
return __collate_parse_tree(__column='option_b_tree')
def __collate_option_c_tree():
return __collate_parse_tree(__column='option_c_tree')
def __collate_option_d_tree():
return __collate_parse_tree(__column='option_d_tree')
if args.use_reference:
def __collate_reference_tree():
return [[[parse_tree_to_graph(___parse_tree) for ___parse_tree in ___parse_trees] for ___parse_trees in __data['reference_tree']] for __data in _batch_data]
_collate_data =
for _column in column:
_collate_data[_column] = eval(f'__collate__column')()
return _collate_data
dataloader = DataLoader(dataset=dataset,
batch_size=batch_size,
num_workers=args.num_workers,
collate_fn=_collate_fn,
shuffle=shuffle)
return dataloader
本来一切都很完美,直到开始出现问题。
问题发现
下面是一个出现坑的demo:
class Dataset:
def __init__(self):
pass
def demo(self):
def _easy_plus(_x):
def __easy_plus(__y):
return _x + __y
return __easy_plus
one_plus_function = _easy_plus(_x=1)
help(one_plus_function)
print(one_plus_function(__y=2))
dataset = Dataset()
dataset.demo()
输出结果如下:
Help on function __easy_plus in module __main__:
__easy_plus(_Dataset__y)
Traceback (most recent call last):
File "sanity_test.py", line 21, in <module>
dataset.demo()
File "sanity_test.py", line 18, in demo
print(one_plus_function(__y=1))
TypeError: __easy_plus() got an unexpected keyword argument '__y'
可以发现__easy_plus
函数的参数名称事实上被改动成为_Dataset__y
,调用one_plus_function(__y=2)
会报错,调用one_plus_function(_Dataset__y=2)
则可以正常返回结果。
但是如果将函数_easy_plus
与__easy_plus
的参数名称前缀的下划线分别去掉一个,改动成下面的形式:
class Dataset:
def __init__(self):
pass
def demo(self):
def _easy_plus(x):
def __easy_plus(_y):
return x + _y
return __easy_plus
one_plus_function = _easy_plus(x=1)
help(one_plus_function)
print(one_plus_function(_y=2))
dataset = Dataset()
dataset.demo()
输出结果如下:
Help on function __easy_plus in module __main__:
__easy_plus(_y)
3
可以发现参数名称又没有发生修改了。
通过观察规律,似乎是编译器默认参数名称包含很多下划线前缀时就会自动修改名称,于是笔者又写了下面的demo来验证这一规律:
class Dataset:
def __init__(self):
pass
def demo(self):
def _easy_plus(__x):
def __easy_plus(___y):
return __x + ___y
return __easy_plus
help(_easy_plus)
one_plus_function = _easy_plus(_Dataset__x=1)
help(one_plus_function)
dataset = Dataset()
dataset.demo()
输出结果如下:
Help on function _easy_plus in module __main__:
_easy_plus(_Dataset__x)
Help on function __easy_plus in module __main__:
__easy_plus(_Dataset___y)
的确如此,结果显示连_easy_plus
的参数名称都被修改了。
然而如果嵌套函数是写在嵌套函数之外呢?
def demo():
def _easy_plus(__x):
def __easy_plus(___y):
return __x + ___y
return __easy_plus
help(_easy_plus)
one_plus_function = _easy_plus(__x=1)
help(one_plus_function)
demo()
输出结果如下:
Help on function _easy_plus in module __main__:
_easy_plus(__x)
Help on function __easy_plus in module __main__:
__easy_plus(___y)
此时不管有多少下划线前缀也不会发生参数名称被修改的现象。
现在只能姑且总结规律为:类函数中的嵌套函数参数名称的下划线前缀超过两个就会触发参数名称修改。
虽然并不是很能理解确切的原理是什么… [Facepalm]
规律更新
其实后来发现其实不止是嵌套函数,只要是写在类中的函数,不管是静态方法还是别的对象方法,只要参数名称的下划线前缀超过两个都会触发这种机制的修改。比如下面的示例:
class Dataset:
def __init__(self):
pass
def _demo(self, __x):
return __x
dataset = Dataset()
help(dataset._demo)
输出结果为:
Help on method _demo in module __main__:
_demo(_Dataset__x) method of __main__.Dataset instance
因此规律应该是:类域内的函数参数名称的下划线前缀超过两个就会触发参数名称修改。
关于Python的类函数,因为没有明确的public
与private
的区分,因此一般默认函数名称会使用下划线作为前缀来区分是否应当被外部调用(即是否为私有的,虽然这也并不是强制的,想要调用带下划线作为名称前缀的类函数依然是可行的,比如对于list
类型的变量token_list
来说,调用token_list[0]
与token_list.__getitem__(0)
是完全等价的),但是关于函数的参数名称会被修改的确是从来没有注意到过的事情,希望我不是最后一个发现这个问题的倒霉球…[Facepalm]
问题解决
本着闲得蛋疼求真务实的精神,笔者还是去查了一下官方文档,总结下来确实跟上面猜想的一样,虽然Python类中会用下划线前缀来不严格的区分私有方法,但是用两个下划线作为前缀确实是严格的区分了私有变量。具体原理如下(其中第2点解释了原理,第345点说明了这种参数名称修改的原因):
(摘自https://docs.python.org/zh-cn/3/tutorial/classes.html#private-variables)
-
那种仅限从一个对象内部访问的私有实例变量在Python中并不存在。但是,大多数Python代码都遵循这样一个约定:带有一个下划线的名称(例如
_spam
)应该被当作是API的非公有部分 (无论它是函数、方法或是数据成员)。这应当被视为一个实现细节,可能不经通知即加以改变。 -
由于存在对于类私有成员的有效使用场景(例如避免名称与子类所定义的名称相冲突),因此存在对此种机制的有限支持,称为名称改写(就是这个玩意儿)。 任何形式为
__spam
的标识符(至少带有两个前缀下划线,至多一个后缀下划线)的文本将被替换为_classname__spam
,其中classname
为去除了前缀下划线的当前类名称。这种改写不考虑标识符的句法位置,只要它出现在类定义内部就会进行。 -
名称改写有助于让子类重载方法而不破坏类内方法调用。例如:
class Mapping: def __init__(self, iterable): self.items_list = [] self.__update(iterable) def update(self, iterable): for item in iterable: self.items_list.append(item) __update = update # private copy of original update() method class MappingSubclass(Mapping): def update(self, keys, values): # provides new signature for update() # but does not break __init__() for item in zip(keys, values): self.items_list.append(item)
-
上面的示例即使在
MappingSubclass
引入了一个__update
标识符的情况下也不会出错,因为它会在Mapping
类中被替换为_Mapping__update
而在MappingSubclass
类中被替换为_MappingSubclass__update
。 -
请注意,改写规则的设计主要是为了避免意外冲突;访问或修改被视为私有的变量仍然是可能的。这在特殊情况下甚至会很有用,例如在调试器中。
-
请注意传递给
exec()
或eval()
的代码不会将发起调用类的类名视作当前类;这类似于global
语句的效果,因此这种效果仅限于同时经过字节码编译的代码。 同样的限制也适用于getattr()
,setattr()
和delattr()
,以及对于__dict__
的直接引用。
后记
其实这个问题在不知情的情况下确实很难被发现,而且笔者发现按照上面的说法是不是其实__easy_plus
这个函数其实也被改名了,只是目前还没有发现它可能报错的点。总之在类中得要慎用双下划线前缀的写法了,强迫症该改还是得改。
以上是关于日常轻松一刻:是谁偷改了我的参数名称?的主要内容,如果未能解决你的问题,请参考以下文章