在 python3 中:list(iterables) 的奇怪行为

Posted

技术标签:

【中文标题】在 python3 中:list(iterables) 的奇怪行为【英文标题】:In python3: strange behaviour of list(iterables) 【发布时间】:2020-03-24 06:28:02 【问题描述】:

我有一个关于 python 中迭代行为的具体问题。我的可迭代是 pytorch 中自定义构建的 Dataset 类:

import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        print('***********')
        print('getitem x = ', x)
        print('###########')
        y = self.X[x]
        print('getitem y = ', y)
        return y

现在,当我初始化该 datasetTest 类的特定实例时,就会出现奇怪的行为。根据我作为参数 X 传递的数据结构,当我调用 list(datasetTestInstance) 时,它的行为会有所不同。特别是,当传递一个 torch.tensor 作为参数时没有问题,但是当传递一个 dict 作为参数时,它会抛出一个 KeyError。原因是 list(iterable) 不仅调用 i=0, ..., len(iterable)-1,而且调用 i=0, ..., len(iterable)。也就是说,它将迭代直到(包括)索引等于可迭代的长度。显然,这个索引在任何 python 数据结构中都没有定义,因为最后一个元素总是有索引 len(datastructure)-1 而不是 len(datastructure)。如果 X 是 torch.tensor 或列表,则不会出现错误,即使我认为应该是错误。即使对于索引为 len(datasetTestinstance) 的(不存在的)元素,它仍然会调用 getitem,但它不会计算 y=self.X[len(datasetTestInstance]。有谁知道 pytorch 是否在内部以某种方式优雅地处理这个问题?

当将 dict 作为数据传递时,它将在最后一次迭代中抛出错误,此时 x=len(datasetTestInstance)。这实际上是我猜想的预期行为。但为什么这只发生在 dict 而不是 list 或 torch.tensor?

if __name__ == "__main__":
    a = datasetTest(torch.randn(5,2))
    print(len(a))
    print('++++++++++++')
    for i in range(len(a)):
        print(i)
        print(a[i])
    print('++++++++++++')
    print(list(a))

    print('++++++++++++')
    b = datasetTest(0: 12, 1:35, 2:99, 3:27, 4:33)
    print(len(b))
    print('++++++++++++')
    for i in range(len(b)):
        print(i)
        print(b[i])
    print('++++++++++++')
    print(list(b))

如果您想更好地理解我所观察到的内容,可以尝试使用该 sn-p 代码。

我的问题是:

1.) 为什么 list(iterable) 会迭代直到(包括)len(iterable)? for 循环不会这样做。

2.) 在作为数据 X 传递的 torch.tensor 或列表的情况下:为什么即使调用索引 len(datasetTestInstance) 的 getitem 方法也不会引发错误,因为它实际上应该超出范围未定义为张量/列表中的索引?或者,换句话说,当到达索引 len(datasetTestInstance) 然后进入 getitem 方法时,究竟发生了什么?它显然不再调用'y = self.X[x]'(否则会有一个IndexError),但它确实进入了getitem方法,我可以看到它从getitem方法中打印索引x。那么这种方法会发生什么?为什么它的行为会根据是 torch.tensor/list 还是 dict 而有所不同?

【问题讨论】:

关于第1点。通常我们使用for item in b:通过for循环迭代可迭代类型。在这种情况下,python 期望 b 引发 IndexError 以指示已到达列表末尾。 (有关特定文档链接,请参阅我的答案)。 你为什么不使用迭代器协议?以这种方式使对象可迭代是非常过时的,并且只保留不破坏向后兼容性,AFAIK 【参考方案1】:

一堆有用的链接:

    [Python 3.Docs]: Data model - Emulating container types [Python 3.Docs]: Built-in Types - Iterator Types [Python 3.Docs]: Built-in Functions - iter(object[, sentinel]) [SO]: Why does list ask about __len__?(所有答案)

关键是 list 构造函数使用 (iterable) 参数的 __len__ ((如果提供)来计算新的容器长度),但随后对其进行迭代(通过迭代器协议)。

由于非常巧合(请记住,dict 支持迭代器协议,这发生在它的键上(这是一个序列)):

您的字典只有 int(还有更多) 它们的值与它们的索引相同(按顺序)

改变以上 2 个项目符号所表达的任何条件,都会使实际的错误更有说服力。

两个对象(dictlist (of tensors))都支持迭代器协议。为了使事情正常工作,您应该将它包装在您的 Dataset 类中,并稍微调整映射类型之一(使用值而不是键)。 代码(key_func 相关部分)有点复杂,但只是为了易于配置(如果您想更改某些内容 - 用于 demo 目的)。 p>

code00.py

#!/usr/bin/env python3

import sys
import torch
from torch.utils.data import Dataset
from random import randint


class SimpleDataset(Dataset):

    def __init__(self, x):
        self.__iter = None
        self.x = x

    def __len__(self):
        print("    __len__()")
        return len(self.x)

    def __getitem__(self, key):
        print("    __getitem__(0:(1:s))".format(key, key.__class__.__name__))
        try:
            val = self.x[key]
            print("    0:".format(val))
            return val
        except:
            print("    exc")
            raise #IndexError

    def __iter__(self):
        print("    __iter__()")
        self.__iter = iter(self.x)
        return self

    def __next__(self):
        print("    __next__()")
        if self.__iter is None:
            raise StopIteration
        val = next(self.__iter)
        if isinstance(self.x, (dict,)):  # Special handling for dictionaries
            val = self.x[val]
        return val


def key_transformer(int_key):
    return str(int_key)  # You could `return int_key` to see that it also works on your original example


def dataset_example(inner, key_func=None):
    if key_func is None:
        key_func = lambda x: x
    print("\nInner object: 0:".format(inner))
    sd = SimpleDataset(inner)
    print("Dataset length: 0:d".format(len(sd)))
    print("\nIterating (old fashion way):")
    for i in range(len(sd)):
        print("  0:: 1:".format(key_func(i), sd[key_func(i)]))
    print("\nIterating (Python (iterator protocol) way):")
    for element in sd:
        print("  0:".format(element))
    print("\nTry building the list:")
    l = list(sd)
    print("  List: 0:\n".format(l))


def main():
    dict_size = 2

    for inner, func in [
        (torch.randn(2, 2), None),
        (key_transformer(i): randint(0, 100) for i in reversed(range(dict_size)), key_transformer),  # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
    ]:
        dataset_example(inner, key_func=func)


if __name__ == "__main__":
    print("Python 0:s 1:dbit on 2:s\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
    main()
    print("\nDone.")

输出

[cfati@CFATI-5510-0:e:\Work\Dev\***\q059091544]> "e:\Work\Dev\VEnvs\py_064_03.07.03_test0\Scripts\python.exe" code00.py
Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32


Inner object: tensor([[ 0.6626,  0.1107],
        [-0.1118,  0.6177]])
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(int))
    tensor([0.6626, 0.1107])
  0: tensor([0.6626, 0.1107])
    __getitem__(1(int))
    tensor([-0.1118,  0.6177])
  1: tensor([-0.1118,  0.6177])

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  tensor([0.6626, 0.1107])
    __next__()
  tensor([-0.1118,  0.6177])
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [tensor([0.6626, 0.1107]), tensor([-0.1118,  0.6177])]


Inner object: '1': 86, '0': 25
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(str))
    25
  0: 25
    __getitem__(1(str))
    86
  1: 86

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  86
    __next__()
  25
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [86, 25]


Done.

您可能还想检查[PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET (IterableDataset)。

【讨论】:

太好了,非常感谢!您的回答非常有见地! 不客气!您从一开始就尝试做的事情(尽管我不需要它)是可能的:)【参考方案2】:

这并不是一个真正的 pytorch 特定问题,而是一个一般性的 python 问题。

您正在使用list(iterable) 构建一个列表,其中iterable 类是实现sequence semantics 的类。

在此处查看__getitem__ 对于序列类型的预期行为(最相关的部分以粗体显示)

object.__getitem__(self, key)

调用以实施评估 self[key]。对于 sequence 类型,接受的键应该是整数 和切片对象。注意否定的特殊解释 索引(如果类希望模拟序列类型)取决于 __getitem__() 方法。如果 key 的类型不合适,TypeError 可能会被引发; 如果值在索引集之外 序列(在对负值进行任何特殊解释之后), IndexError 应该提高。 对于映射类型,如果缺少键(不是 在容器中),应该引发 KeyError。

注意:for 循环期望 IndexError 将引发非法 索引以允许正确检测序列的结尾。

这里的问题是,对于序列类型,在使用无效索引调用 __getitem__ 的情况下,python 需要一个 IndexError。看来list 构造函数依赖于这种行为。在您的示例中,当 X 是一个字典时,尝试访问无效密钥会导致 __getitem__ 引发 KeyError 而不是预期的,因此不会被捕获并导致列表的构建失败。


根据这些信息,您可以执行以下操作

class datasetTest:
    def __init__(self):
        self.X = 0: 12, 1:35, 2:99, 3:27, 4:33

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError
        return self.X[index]

d = datasetTest()
print(list(d))

我不建议在实践中这样做,因为它依赖于您的字典 X,它只包含整数键 01、...、len(X)-1,这意味着它最终的行为就像大多数情况下使用列表,因此您最好只使用列表。

【讨论】:

唯一让我感到奇怪的是为什么 list 和 for 循环有不同的行为。即,为什么 for 循环不会失败,因为它似乎还需要一个 IndexError 来知道何时结束迭代,但是,一个 dict 会抛出一个 KeyError 代替...... for 循环也应该失败。当然,如果您使用 for 循环遍历 range(len(b)),它不会失败,因为您明确地遍历 range 类型。如果您尝试使用 for item in b: 直接迭代 b,则在示例中使用 dict 案例时应该会遇到相同的错误。 @guest1 你是什么意思列表和for循环有不同的行为?

以上是关于在 python3 中:list(iterables) 的奇怪行为的主要内容,如果未能解决你的问题,请参考以下文章

Python3-笔记-C-003-函数-enumerate

python3基础-list&tuple

Python3 迭代器

Python 3.5+ 中的 list() 与可迭代解包

python3 迭代

Python3 迭代器和生成器