pytorch 中分类网络损失函数

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch 中分类网络损失函数相关的知识,希望对你有一定的参考价值。

参考技术A 如图搭建简单的分类网络,以二分类为例:

2,10,2分别代表:输入的特征数,隐藏神经元的个数,输出的概率(one-hot编码)

prediction=net(x):概率可以为负数

[0.6,-0.1]:表示预测为0,最大概率的索引为0

[-0.08,6.5]:表示预测为1,最大概率的索引为1

loss_func=torch.nn.CrossEntropyLoss()

loss_func中输入的是直接从网络中输出的量(prediction=net(x))

为了查看其预测的标签,通过打印这句话即可:

print(torch.max(F.softmax(prediction), 1)[1])

1、F.softmax(prediction):将二分类概率都变成0-1之间的数,且相加为1

[0.3,0.7]:预测为1

[0.6,0.4]:预测为0

2、torch.max(F.softmax(prediction), 1):

torch.max(out,1):按行取最大,每行的最大值放到一个矩阵中

torch.max(out,0):按列取最大,每列的最大值放到一个矩阵中

这句话最后返回两个tensor:

values=tensor(最大的概率值)

indices=tensor(最大概率值的索引)

3、torch.max(F.softmax(prediction), 1)[1]:

只保留indices的tensor

实例:

*注:在有的地方我们会看到torch.max(a, 1).data.numpy()的写法,这是因为在早期的pytorch的版本中,variable变量和tenosr是不一样的数据格式,variable可以进行反向传播,tensor不可以,需要将variable转变成tensor再转变成numpy。现在的版本已经将variable和tenosr合并,所以只用torch.max(a,1).numpy()就可以了。

参考链接:https://www.jianshu.com/p/3ed11362b54f

如何在一个模型中训练多个损失,并在 pytorch 中选择冻结网络的某些部分?

【中文标题】如何在一个模型中训练多个损失,并在 pytorch 中选择冻结网络的某些部分?【英文标题】:How to train multiple losses in one model with freezing some part of the network optionally in pytorch? 【发布时间】:2020-10-09 11:49:20 【问题描述】:

我是 pytorch 的新手。 我正在实现一个具有两个分类器(黄色和紫色)​​的网络,如figure 所示。 问题是我想在网络训练黄色分类器时冻结红色部分,在网络训练紫色分类器时解冻红色部分。

我想实现的简短代码如下

# x is input and y_yellow, y_purple are labels of yellow and purple classifiers respectively.
criterion = CrossEntropyLoss()
opt = SGD()
model = my_model()

opt.zero_grad()
yellow_out, purple_out = model(x)

# freeze red part requires_grad = False

yellow_loss = criterion(yellow_out, y_yellow)
yellow_loss.backward()
opt.step()

opt.zero_grad()
# unfreeze red part requires_grad = True

purple_loss = criterion(purple_out, y_purple)
purple_loss.backward()
opt.step

请告诉我实现这个想法的确切方法。

代码顺序对吗?

我是否正确使用了 zero_grad?

我错过了什么吗?

我需要使用任何可选参数吗?

【问题讨论】:

【参考方案1】:

最简单的策略是不将红色模块的参数传递给优化器opt。或者,您可以将红色模块的参数设置为requires_gradFalse

它看起来像:

for param in red_module.parameters():
    param.requires_grad = False

让我知道这是否适合你。

【讨论】:

以上是关于pytorch 中分类网络损失函数的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch CIFAR10图像分类 ResNet篇

Pytorch CIFAR10图像分类 MobileNetv2篇

Pytorch CIFAR10图像分类 MobileNetv2篇

如何在一个模型中训练多个损失,并在 pytorch 中选择冻结网络的某些部分?

Pytorch CIFAR10图像分类 DenseNet篇

Pytorch CIFAR10图像分类 AlexNet篇