在 python3 中:list(iterables) 的奇怪行为
Posted
技术标签:
【中文标题】在 python3 中:list(iterables) 的奇怪行为【英文标题】:In python3: strange behaviour of list(iterables) 【发布时间】:2020-03-24 06:28:02 【问题描述】:我有一个关于 python 中迭代行为的具体问题。我的可迭代是 pytorch 中自定义构建的 Dataset 类:
import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
def __init__(self, X):
self.X = X
def __len__(self):
return len(self.X)
def __getitem__(self, x):
print('***********')
print('getitem x = ', x)
print('###########')
y = self.X[x]
print('getitem y = ', y)
return y
现在,当我初始化该 datasetTest 类的特定实例时,就会出现奇怪的行为。根据我作为参数 X 传递的数据结构,当我调用 list(datasetTestInstance) 时,它的行为会有所不同。特别是,当传递一个 torch.tensor 作为参数时没有问题,但是当传递一个 dict 作为参数时,它会抛出一个 KeyError。原因是 list(iterable) 不仅调用 i=0, ..., len(iterable)-1,而且调用 i=0, ..., len(iterable)。也就是说,它将迭代直到(包括)索引等于可迭代的长度。显然,这个索引在任何 python 数据结构中都没有定义,因为最后一个元素总是有索引 len(datastructure)-1 而不是 len(datastructure)。如果 X 是 torch.tensor 或列表,则不会出现错误,即使我认为应该是错误。即使对于索引为 len(datasetTestinstance) 的(不存在的)元素,它仍然会调用 getitem,但它不会计算 y=self.X[len(datasetTestInstance]。有谁知道 pytorch 是否在内部以某种方式优雅地处理这个问题?
当将 dict 作为数据传递时,它将在最后一次迭代中抛出错误,此时 x=len(datasetTestInstance)。这实际上是我猜想的预期行为。但为什么这只发生在 dict 而不是 list 或 torch.tensor?
if __name__ == "__main__":
a = datasetTest(torch.randn(5,2))
print(len(a))
print('++++++++++++')
for i in range(len(a)):
print(i)
print(a[i])
print('++++++++++++')
print(list(a))
print('++++++++++++')
b = datasetTest(0: 12, 1:35, 2:99, 3:27, 4:33)
print(len(b))
print('++++++++++++')
for i in range(len(b)):
print(i)
print(b[i])
print('++++++++++++')
print(list(b))
如果您想更好地理解我所观察到的内容,可以尝试使用该 sn-p 代码。
我的问题是:
1.) 为什么 list(iterable) 会迭代直到(包括)len(iterable)? for 循环不会这样做。
2.) 在作为数据 X 传递的 torch.tensor 或列表的情况下:为什么即使调用索引 len(datasetTestInstance) 的 getitem 方法也不会引发错误,因为它实际上应该超出范围未定义为张量/列表中的索引?或者,换句话说,当到达索引 len(datasetTestInstance) 然后进入 getitem 方法时,究竟发生了什么?它显然不再调用'y = self.X[x]'(否则会有一个IndexError),但它确实进入了getitem方法,我可以看到它从getitem方法中打印索引x。那么这种方法会发生什么?为什么它的行为会根据是 torch.tensor/list 还是 dict 而有所不同?
【问题讨论】:
关于第1点。通常我们使用for item in b:
通过for循环迭代可迭代类型。在这种情况下,python 期望 b
引发 IndexError
以指示已到达列表末尾。 (有关特定文档链接,请参阅我的答案)。
你为什么不使用迭代器协议?以这种方式使对象可迭代是非常过时的,并且只保留不破坏向后兼容性,AFAIK
【参考方案1】:
一堆有用的链接:
-
[Python 3.Docs]: Data model - Emulating container types
[Python 3.Docs]: Built-in Types - Iterator Types
[Python 3.Docs]: Built-in Functions - iter(object[, sentinel])
[SO]: Why does list ask about __len__?(所有答案)
关键是 list 构造函数使用 (iterable) 参数的 __len__ ((如果提供)来计算新的容器长度),但随后对其进行迭代(通过迭代器协议)。
由于非常巧合(请记住,dict 支持迭代器协议,这发生在它的键上(这是一个序列)):
您的字典只有 int 键(还有更多) 它们的值与它们的索引相同(按顺序)改变以上 2 个项目符号所表达的任何条件,都会使实际的错误更有说服力。
两个对象(dict 和 list (of tensors))都支持迭代器协议。为了使事情正常工作,您应该将它包装在您的 Dataset 类中,并稍微调整映射类型之一(使用值而不是键)。 代码(key_func 相关部分)有点复杂,但只是为了易于配置(如果您想更改某些内容 - 用于 demo 目的)。 p>
code00.py:
#!/usr/bin/env python3
import sys
import torch
from torch.utils.data import Dataset
from random import randint
class SimpleDataset(Dataset):
def __init__(self, x):
self.__iter = None
self.x = x
def __len__(self):
print(" __len__()")
return len(self.x)
def __getitem__(self, key):
print(" __getitem__(0:(1:s))".format(key, key.__class__.__name__))
try:
val = self.x[key]
print(" 0:".format(val))
return val
except:
print(" exc")
raise #IndexError
def __iter__(self):
print(" __iter__()")
self.__iter = iter(self.x)
return self
def __next__(self):
print(" __next__()")
if self.__iter is None:
raise StopIteration
val = next(self.__iter)
if isinstance(self.x, (dict,)): # Special handling for dictionaries
val = self.x[val]
return val
def key_transformer(int_key):
return str(int_key) # You could `return int_key` to see that it also works on your original example
def dataset_example(inner, key_func=None):
if key_func is None:
key_func = lambda x: x
print("\nInner object: 0:".format(inner))
sd = SimpleDataset(inner)
print("Dataset length: 0:d".format(len(sd)))
print("\nIterating (old fashion way):")
for i in range(len(sd)):
print(" 0:: 1:".format(key_func(i), sd[key_func(i)]))
print("\nIterating (Python (iterator protocol) way):")
for element in sd:
print(" 0:".format(element))
print("\nTry building the list:")
l = list(sd)
print(" List: 0:\n".format(l))
def main():
dict_size = 2
for inner, func in [
(torch.randn(2, 2), None),
(key_transformer(i): randint(0, 100) for i in reversed(range(dict_size)), key_transformer), # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
]:
dataset_example(inner, key_func=func)
if __name__ == "__main__":
print("Python 0:s 1:dbit on 2:s\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
main()
print("\nDone.")
输出:
[cfati@CFATI-5510-0:e:\Work\Dev\***\q059091544]> "e:\Work\Dev\VEnvs\py_064_03.07.03_test0\Scripts\python.exe" code00.py Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32 Inner object: tensor([[ 0.6626, 0.1107], [-0.1118, 0.6177]]) __len__() Dataset length: 2 Iterating (old fashion way): __len__() __getitem__(0(int)) tensor([0.6626, 0.1107]) 0: tensor([0.6626, 0.1107]) __getitem__(1(int)) tensor([-0.1118, 0.6177]) 1: tensor([-0.1118, 0.6177]) Iterating (Python (iterator protocol) way): __iter__() __next__() tensor([0.6626, 0.1107]) __next__() tensor([-0.1118, 0.6177]) __next__() Try building the list: __iter__() __len__() __next__() __next__() __next__() List: [tensor([0.6626, 0.1107]), tensor([-0.1118, 0.6177])] Inner object: '1': 86, '0': 25 __len__() Dataset length: 2 Iterating (old fashion way): __len__() __getitem__(0(str)) 25 0: 25 __getitem__(1(str)) 86 1: 86 Iterating (Python (iterator protocol) way): __iter__() __next__() 86 __next__() 25 __next__() Try building the list: __iter__() __len__() __next__() __next__() __next__() List: [86, 25] Done.
您可能还想检查[PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET (IterableDataset)。
【讨论】:
太好了,非常感谢!您的回答非常有见地! 不客气!您从一开始就尝试做的事情(尽管我不需要它)是可能的:)【参考方案2】:这并不是一个真正的 pytorch 特定问题,而是一个一般性的 python 问题。
您正在使用list(iterable) 构建一个列表,其中iterable 类是实现sequence semantics 的类。
在此处查看__getitem__
对于序列类型的预期行为(最相关的部分以粗体显示)
object.__getitem__(self, key)
调用以实施评估
self[key]
。对于 sequence 类型,接受的键应该是整数 和切片对象。注意否定的特殊解释 索引(如果类希望模拟序列类型)取决于__getitem__()
方法。如果 key 的类型不合适,TypeError
可能会被引发; 如果值在索引集之外 序列(在对负值进行任何特殊解释之后),IndexError
应该提高。 对于映射类型,如果缺少键(不是 在容器中),应该引发 KeyError。注意:
for
循环期望IndexError
将引发非法 索引以允许正确检测序列的结尾。
这里的问题是,对于序列类型,在使用无效索引调用 __getitem__
的情况下,python 需要一个 IndexError
。看来list
构造函数依赖于这种行为。在您的示例中,当 X
是一个字典时,尝试访问无效密钥会导致 __getitem__
引发 KeyError
而不是预期的,因此不会被捕获并导致列表的构建失败。
根据这些信息,您可以执行以下操作
class datasetTest:
def __init__(self):
self.X = 0: 12, 1:35, 2:99, 3:27, 4:33
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if index < 0 or index >= len(self):
raise IndexError
return self.X[index]
d = datasetTest()
print(list(d))
我不建议在实践中这样做,因为它依赖于您的字典 X
,它只包含整数键 0
、1
、...、len(X)-1
,这意味着它最终的行为就像大多数情况下使用列表,因此您最好只使用列表。
【讨论】:
唯一让我感到奇怪的是为什么 list 和 for 循环有不同的行为。即,为什么 for 循环不会失败,因为它似乎还需要一个 IndexError 来知道何时结束迭代,但是,一个 dict 会抛出一个 KeyError 代替...... for 循环也应该失败。当然,如果您使用 for 循环遍历range(len(b))
,它不会失败,因为您明确地遍历 range
类型。如果您尝试使用 for item in b:
直接迭代 b
,则在示例中使用 dict 案例时应该会遇到相同的错误。
@guest1 你是什么意思列表和for循环有不同的行为?以上是关于在 python3 中:list(iterables) 的奇怪行为的主要内容,如果未能解决你的问题,请参考以下文章