如何创建模块列表列表

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何创建模块列表列表相关的知识,希望对你有一定的参考价值。

创建PyTorch模块列表的python列表可以吗?例如,如果我想在一个层中包含几个Conv1d,然后在另一个层中使用不同的Conv1d。在每一层中,我需要根据层号对输出进行不同的操作。构建此模块列表的“ python-list”的正确方法是什么?

这种方式:

    class test(nn.Module):
        def __init__(...):
            self.modulelists = []
            for i in range(4):
                self.modulelists.append(nn.ModuleList([nn.Conv1d(10, 10, kernel_size=5) for _ in range(5)]))

或这种方式:

    class test(nn.Module):
        def __init__(...):
            self.modulelists = nn.ModuleList()
            for i in range(4):
                self.modulelists.append(nn.ModuleList([nn.Conv1d(10, 10, kernel_size=5) for _ in range(5)]))

谢谢

答案

您需要正确注册网络的所有子模块,以便pytorch可以访问其参数,缓冲区等。仅当使用正确的containers时,才能完成此操作。如果您将子模块存储在简单的pythonic列表中,那么pytorch将不知道那里有子模块,它们将被忽略。

因此,如果您使用简单的pythonic列表来存储子模块,例如,当您调用model.cuda()时,列表中子模块的参数将not传输到GPU,而是保留在CPU上。如果调用model.parameters()将所有可训练的参数传递给优化器,则pytorch将检测到所有子模块参数,因此优化器将不会“看到”它们。

以上是关于如何创建模块列表列表的主要内容,如果未能解决你的问题,请参考以下文章

如何从片段内的列表视图打开链接网址?

如何使列表视图出现在片段中?

如何在片段中填充列表视图?

如何膨胀由 Android Studio 向导在 Activity 中创建的片段(列表)?

如何将列表视图中的数据从一个片段发送到另一个片段

如何创建片段以重复变量编号中的代码行