从 pytorch 数据集返回索引:更改 __getitem__ 的函数导致元类冲突
Posted
技术标签:
【中文标题】从 pytorch 数据集返回索引:更改 __getitem__ 的函数导致元类冲突【英文标题】:Returning indices from pytorch Dataset: Function to alter __getitem__ results in metaclass conflict 【发布时间】:2021-07-20 14:43:14 【问题描述】:我有多个类(用于不同的数据集)继承自 pytorch 的 Dataset 类。它们具有一般结构,如下所示:
from torch.utils.data import Dataset
class SomeDataset(Dataset):
def __init__(self, data, labels):
super(SomeDataset, self).__init__()
self.data = data
self.labels = labels
self.__name__ = 'SomeDataset'
def __getitem__(self, index):
return 'data': self.data[index], 'label': self.labels[index]
def __len__(self):
return len(data)
最近我意识到在批处理时跟踪传递给 Dataloader 的标签会很有好处,所以在谷歌搜索如何执行此操作时,我遇到了this thread,这是我调整代码以编写此代码的地方功能:
def return_indices(dataset_class):
def __getitem__(self, index):
return 'index':1, **dataset_class.__getitem__(self, index)
return type(dataset_class.__name__, (dataset_class, ), '__getitem__': __getitem__)
我以前从未见过type
像这样使用,但在谷歌上搜索后,它有些 有意义,所以我试了一下。不幸的是,这导致了这个错误:
TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases
这导致了更多的谷歌搜索,即使我开始掌握元类是什么以及它们是如何使用的,我仍然无法弄清楚这种方法有什么问题或如何解决它 -而且我开始认为将这个功能重写到我的数据集类中可能会更容易,而不是使用一些简洁的包装器来为我做这件事。任何人都可以权衡我缺少的任何东西吗?
【问题讨论】:
你不能在方法中使用__getitem__
【参考方案1】:
这样做:
def return_indices(dataset_class):
def __getitem__(self, index):
return 'index':1, **dataset_class.__getitem__(self, index)
metacls = type(dataset_class)
return metacls(dataset_class.__name__, (dataset_class, ), '__getitem__': __getitem__)
发生了什么:正如您所发现的,对 type
的 3 参数调用是在 Python 中以编程方式创建新类的方法,而不需要“类”语句及其主体。
但是type
是“基本元类”——虽然它的实例将是普通类,但它也会将你正在创建的类的元类“硬编码”到自身——相比之下,使用class
语句将使Python 在您正在创建的类的基础中搜索合适的元类。
只需使用您的派生类元类(通过单参数形式的类型,如上,或通过类的__class__
属性,如dataset_class.__class__
)。
使用 this 作为 callable 代替 type 将把它自己作为元类,并且事情应该可以工作。
注意:由于元类还有更多机制,例如 __prepare__
,因此仅调用元类而不是 type
并不总是有效 - 正确的通用方法涉及调用types.prepare_class
和 types.new_class
并有一个回调来执行相当于类语句体中的类体执行。大多数情况下不需要这样做。
【讨论】:
这是有道理的,但是当我运行你的 sn-p 时,我得到"TypeError: __init__() takes 3 positional arguments but 4 were given"
,我猜如果我从 SomeDataset 的定义中更简单的东西继承,它就不会出现。我会调查types.prepare_class
和types.new_class
看看是否有帮助。谢谢!
我检查了一个“普通的元类”——type
单独调用解决了元类——这个 sn-p 也可以工作。还有更多关于 pytorch 元类的内容 - 如果您不成功,我可以直接查看 pytorch 的操作方法。以上是关于从 pytorch 数据集返回索引:更改 __getitem__ 的函数导致元类冲突的主要内容,如果未能解决你的问题,请参考以下文章