如何使用 pytorch 构建多任务 DNN,例如 100 多个任务?
Posted
技术标签:
【中文标题】如何使用 pytorch 构建多任务 DNN,例如 100 多个任务?【英文标题】:How to use pytorch to construct multi-task DNN, e.g., for more than 100 tasks? 【发布时间】:2020-05-02 22:47:40 【问题描述】:以下是使用 pytorch 为两个回归任务构建 DNN 的示例代码。 forward
函数返回两个输出 (x1, x2)。用于大量回归/分类任务的网络怎么样?例如,100 或 1000 个输出。硬编码所有输出(例如,x1、x2、...、x100)绝对不是一个好主意。有没有一种简单的方法可以做到这一点?谢谢。
import torch
from torch import nn
import torch.nn.functional as F
class mynet(nn.Module):
def __init__(self):
super(mynet, self).__init__()
self.lin1 = nn.Linear(5, 10)
self.lin2 = nn.Linear(10, 3)
self.lin3 = nn.Linear(10, 4)
def forward(self, x):
x = self.lin1(x)
x1 = self.lin2(x)
x2 = self.lin3(x)
return x1, x2
if __name__ == '__main__':
x = torch.randn(1000, 5)
y1 = torch.randn(1000, 3)
y2 = torch.randn(1000, 4)
model = mynet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
for epoch in range(100):
model.train()
optimizer.zero_grad()
out1, out2 = model(x)
loss = 0.2 * F.mse_loss(out1, y1) + 0.8 * F.mse_loss(out2, y2)
loss.backward()
optimizer.step()
【问题讨论】:
related question, 【参考方案1】:您可以(并且应该)使用nn
containers,例如nn.ModuleList
或nn.ModuleDict
来管理任意数量的子模块。
例如(使用nn.ModuleList
):
class MultiHeadNetwork(nn.Module):
def __init__(self, list_with_number_of_outputs_of_each_head):
super(MultiHeadNetwork, self).__init__()
self.backbone = ... # build the basic "backbone" on top of which all other heads come
# all other "heads"
self.heads = nn.ModuleList([])
for nout in list_with_number_of_outputs_of_each_head:
self.heads.append(nn.Sequential(
nn.Linear(10, nout * 2),
nn.ReLU(inplace=True),
nn.Linear(nout * 2, nout)))
def forward(self, x):
common_features = self.backbone(x) # compute the shared features
outputs = []
for head in self.heads:
outputs.append(head(common_features))
return outputs
请注意,在此示例中,每个头都比单个 nn.Linear
层更复杂。
不同“头”的数量(和输出的数量)由参数list_with_number_of_outputs_of_each_head
的长度决定。
重要提示:重要的是使用nn
containers,而不是简单的pythonic 列表/字典来存储所有子模块。否则 pytorch 将难以管理所有子模块。
参见,例如,this answer、this question 和 this one。
【讨论】:
以上是关于如何使用 pytorch 构建多任务 DNN,例如 100 多个任务?的主要内容,如果未能解决你的问题,请参考以下文章
如果我希望 OpenCV dnn 模块可以加载 PyTorch 的模型,我应该如何保存它
PyTorch 到 onnx 并与 opencv-dnn 一起使用?