在 pytorch 模型中保存嵌入层

Posted

技术标签:

【中文标题】在 pytorch 模型中保存嵌入层【英文标题】:save embedding layer in pytorch model 【发布时间】:2020-04-19 02:18:56 【问题描述】:

我有这个模型:

class model(nn.Module):
    def __init__(self):
      super().__init__()
      self.conv1 = nn.Conv2d(in_channels=12,out_channels=64,kernel_size=3,stride= 1,padding=1)
      # self.conv2 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride= 1,padding=1)
      self.fc1 = nn.Linear(24576, 128)
      self.bn = nn.BatchNorm1d(128)
      self.dropout1 = nn.Dropout2d(0.5)
      self.fc2 = nn.Linear(128, 10)
      self.fc3 = nn.Linear(10, 3)

    def forward(self, x):
      x = F.relu(self.conv1(x))
      # x = F.relu(self.conv2(x))
      x = F.max_pool2d(x, (2,2))
      # print(x.shape)
      x = x.view(-1,24576)
      x = self.bn(F.relu(self.fc1(x)))
      x = self.dropout1(x)
      embeding_stage = F.relu(self.fc2(x))
      x = self.fc3(embeding_stage)

      return x

我想保存 embeding_stage 层,就像我在这里保存模型一样:

model = model()
torch.save(model.state_dict(), 'C:\project\count_speakers\model_pytorch.h5')

谢谢, 阿亚尔

【问题讨论】:

保存embeding_stage是什么意思?你想保存self.fc2()层吗? 为什么torch.save 不能满足您的需求? 【参考方案1】:

我不确定我是否理解您所说的“保存 embedding_stage 层”的意思,但如果您想保存 fc2 或 fc3 之类的东西,那么您可以使用 torch.save() 来做到这一点。 例如:保存fc3:torch.save(model.fc3),'C:\...\fc3.pt')

编辑:

Op 想要得到 embedding_stage 的输出。 您可以通过多种方式做到这一点:

使用model.load_state_dict(torch.load('C:\...\model_pytorch.h5')) 加载您的模型 然后model = nn.Sequential(*list(model.children())[:-1])。模型的输出是embeding_stage。

创建一个Model2(nn.Module),与您的第一个Model() 完全相同,但将def forward(self, x): 中的return x 替换为return embeding_stage。然后像这样将第一个模型的状态加载到第二个模型中:model2.load_state_dict(torch.load('C:\...\model_pytorch.h5')) 像这样 fc3 会被加载,但不会被使用。 model2(x) 的输出将是 embeding_stage。

【讨论】:

我的意思是保存后我想再次加载模型然后测试模型但没有得到完整模型的结果而是提前了一步,即:val_outputs = model(val_inputs) 但是val_output 大小为 10,如 fc2

以上是关于在 pytorch 模型中保存嵌入层的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 模型保存与加载 cpu转GPU

pytorch 模型保存与加载 cpu转GPU

pytorch如何给预训练模型添加新的层

PyTorch:保存权重和模型定义

在 pytorch 的嵌入层中“究竟”发生了啥?

如何正确地为 PyTorch 中的嵌入、LSTM 和线性层提供输入?