如何创建模块列表列表
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何创建模块列表列表相关的知识,希望对你有一定的参考价值。
创建PyTorch模块列表的python列表可以吗?例如,如果我想在一个层中包含几个Conv1d,然后在另一个层中使用不同的Conv1d。在每一层中,我需要根据层号对输出进行不同的操作。构建此模块列表的“ python-list”的正确方法是什么?
这种方式:
class test(nn.Module):
def __init__(...):
self.modulelists = []
for i in range(4):
self.modulelists.append(nn.ModuleList([nn.Conv1d(10, 10, kernel_size=5) for _ in range(5)]))
或这种方式:
class test(nn.Module):
def __init__(...):
self.modulelists = nn.ModuleList()
for i in range(4):
self.modulelists.append(nn.ModuleList([nn.Conv1d(10, 10, kernel_size=5) for _ in range(5)]))
谢谢
答案
您需要正确注册网络的所有子模块,以便pytorch可以访问其参数,缓冲区等。仅当使用正确的containers时,才能完成此操作。如果您将子模块存储在简单的pythonic列表中,那么pytorch将不知道那里有子模块,它们将被忽略。
因此,如果您使用简单的pythonic列表来存储子模块,例如,当您调用model.cuda()
时,列表中子模块的参数将not传输到GPU,而是保留在CPU上。如果调用model.parameters()
将所有可训练的参数传递给优化器,则pytorch将检测到所有子模块参数,因此优化器将不会“看到”它们。
以上是关于如何创建模块列表列表的主要内容,如果未能解决你的问题,请参考以下文章