pytorch model()[] 模型对象类型

Posted 脑洞的分析与证明

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch model()[] 模型对象类型相关的知识,希望对你有一定的参考价值。

model = Model() model(input) 直接调用Model类中的forward(input)函数,因其实现了__call__

举个例子

 import math, random
 import numpy as np
 
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.autograd as autograd 
 import torch.nn.functional as F
 USE_CUDA = torch.cuda.is_available()
 Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs)
 
 class Encoder(nn.Module):
     def __init__(self, din=32, hidden_dim=128):
         super(Encoder, self).__init__()
         self.fc = nn.Linear(din, hidden_dim)
 
     def forward(self, x):
         embedding = F.relu(self.fc(x))
         return embedding
 
 class AttModel(nn.Module):
     def __init__(self, n_node, din, hidden_dim, dout):
         super(AttModel, self).__init__()
         self.fcv = nn.Linear(din, hidden_dim)
         self.fck = nn.Linear(din, hidden_dim)
         self.fcq = nn.Linear(din, hidden_dim)
         self.fcout = nn.Linear(hidden_dim, dout)
 
     def forward(self, x, mask):
         v = F.relu(self.fcv(x))
         q = F.relu(self.fcq(x))
         k = F.relu(self.fck(x)).permute(0,2,1)
         att = F.softmax(torch.mul(torch.bmm(q,k), mask) - 9e15*(1 - mask),dim=2)
 
         out = torch.bmm(att,v)
         #out = torch.add(out,v)
         out = F.relu(self.fcout(out))
         return out
 
 class Q_Net(nn.Module):
     def __init__(self, hidden_dim, dout):
         super(Q_Net, self).__init__()
         self.fc = nn.Linear(hidden_dim, dout)
 
     def forward(self, x):
         q = self.fc(x)
         return q

 

 class DGN(nn.Module):
     def __init__(self,n_agent,num_inputs,hidden_dim,num_actions):
         super(DGN, self).__init__()
         
         self.encoder = Encoder(num_inputs,hidden_dim)
         self.att_1 = AttModel(n_agent,hidden_dim,hidden_dim,hidden_dim)
         self.att_2 = AttModel(n_agent,hidden_dim,hidden_dim,hidden_dim)
         self.q_net = Q_Net(hidden_dim,num_actions)
         
     def forward(self, x, mask):
         h1 = self.encoder(x)
         h2 = self.att_1(h1, mask)
         h3 = self.att_2(h2, mask)
         q = self.q_net(h3)
         return q 

 

在监视窗口查看

 model是Tensor类型

故model(input)[0]是取第一个batch

以上是关于pytorch model()[] 模型对象类型的主要内容,如果未能解决你的问题,请参考以下文章

pytorch自动删除之前保存的pt文件

来自 PyTorch 模型的 ONNX 对象,无需导出

Pytorch 保存模型用户警告:无法检索网络类型容器的源代码

Pytorch模型保存与加载,并在加载的模型基础上继续训练

如何使用 PyTorch 模型进行预测?

PyTorch保存和加载模型