从 pytorch 数据集返回索引:更改 __getitem__ 的函数导致元类冲突

Posted

技术标签:

【中文标题】从 pytorch 数据集返回索引:更改 __getitem__ 的函数导致元类冲突【英文标题】:Returning indices from pytorch Dataset: Function to alter __getitem__ results in metaclass conflict 【发布时间】:2021-07-20 14:43:14 【问题描述】:

我有多个类(用于不同的数据集)继承自 pytorch 的 Dataset 类。它们具有一般结构,如下所示:

from torch.utils.data import Dataset

class SomeDataset(Dataset):

    def __init__(self, data, labels):
        super(SomeDataset, self).__init__()
        self.data = data
        self.labels = labels
        self.__name__ = 'SomeDataset'

    def __getitem__(self, index):
        return 'data': self.data[index], 'label': self.labels[index]

    def __len__(self):
        return len(data)

最近我意识到在批处理时跟踪传递给 Dataloader 的标签会很有好处,所以在谷歌搜索如何执行此操作时,我遇到了this thread,这是我调整代码以编写此代码的地方功能:

def return_indices(dataset_class):
    
    def __getitem__(self, index):
        return 'index':1, **dataset_class.__getitem__(self, index)

    return type(dataset_class.__name__, (dataset_class, ), '__getitem__': __getitem__)

我以前从未见过type 像这样使用,但在谷歌上搜索后,它有些 有意义,所以我试了一下。不幸的是,这导致了这个错误:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases

这导致了更多的谷歌搜索,即使我开始掌握元类是什么以及它们是如何使用的,我仍然无法弄清楚这种方法有什么问题或如何解决它 -而且我开始认为将这个功能重写到我的数据集类中可能会更容易,而不是使用一些简洁的包装器来为我做这件事。任何人都可以权衡我缺少的任何东西吗?

【问题讨论】:

你不能在方法中使用__getitem__ 【参考方案1】:

这样做:

def return_indices(dataset_class):
    
    def __getitem__(self, index):
        return 'index':1, **dataset_class.__getitem__(self, index)
    metacls = type(dataset_class)
    return metacls(dataset_class.__name__, (dataset_class, ), '__getitem__': __getitem__)

发生了什么:正如您所发现的,对 type 的 3 参数调用是在 Python 中以编程方式创建新类的方法,而不需要“类”语句及其主体。

但是type 是“基本元类”——虽然它的实例将是普通类,但它也会将你正在创建的类的元类“硬编码”到自身——相比之下,使用class 语句将使Python 在您正在创建的类的基础中搜索合适的元类。

只需使用您的派生类元类(通过单参数形式的类型,如上,或通过类的__class__ 属性,如dataset_class.__class__)。

使用 this 作为 callable 代替 type 将把它自己作为元类,并且事情应该可以工作。

注意:由于元类还有更多机制,例如 __prepare__,因此仅调用元类而不是 type 并不总是有效 - 正确的通用方法涉及调用types.prepare_classtypes.new_class 并有一个回调来执行相当于类语句体中的类体执行。大多数情况下不需要这样做。

【讨论】:

这是有道理的,但是当我运行你的 sn-p 时,我得到 "TypeError: __init__() takes 3 positional arguments but 4 were given",我猜如果我从 SomeDataset 的定义中更简单的东西继承,它就不会出现。我会调查types.prepare_classtypes.new_class 看看是否有帮助。谢谢! 我检查了一个“普通的元类”——type 单独调用解决了元类——这个 sn-p 也可以工作。还有更多关于 pytorch 元类的内容 - 如果您不成功,我可以直接查看 pytorch 的操作方法。

以上是关于从 pytorch 数据集返回索引:更改 __getitem__ 的函数导致元类冲突的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch学习笔记 3.数据集和数据加载器

PyTorch学习笔记 3.数据集和数据加载器

pytorch土堆pytorch教程学习torchvision 中的数据集的使用

[基于Pytorch的MNIST识别02]用户数据集的读取

如何在 Pytorch 中测试自定义数据集?

如何在 pytorch 中处理大型数据集