torch_09_GAN
Posted shuangcao
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch_09_GAN相关的知识,希望对你有一定的参考价值。
1.生成对抗网络
让两个网络相互竞争,通过生成网络来生成假的数据,对抗网络通过判别器判别真伪,最后希望生成网络生成的数据能够以假乱真骗过判别器
2.生成模型
在生成对抗网络中,不再是将图片输入编码器得到隐含向量然后生成图片,而是随机初始化一个隐含向量,根据变分自动编码器的特点,初始化一个正态分布的隐含向量,通过类似解码的过程,将它映射到一个更高的维度,最后生成一个与输入数据相似的数据,这就是假的数据。生成对抗网络过程通过计算对抗过程来计算损失函数。
3.对抗过程
对抗过程是一个判断真假的判别器,相当于一个二分类问题。输入一张图片,如果是真的图片希望判别器输出为1,假的图片输出为0,这与图片的标签没有关系。
4.训练过程
先优化判别器,将真的图片和假的图片都输入给判别模型,让判别器计算损失,不断优化,让判别器能够区分出真的图片和假的图片,真的图片为1,假的为0.
然后优化生成器。固定判别器的参数,不断优化生成器,使生成器的生成的结果传递给判别器之后,尽可能的接近1,调整损失函数。
5.判别模型训练
开始需要自己创建label,真实的数据时1,生成的假的数据时0,然后将真实的数据输入给判别器,计算loss值,将假的数据输入判别器得到loss,将这两个loss加起来得到总的loss,然后反向传播更新参数能够得到一个优化好的判别器。
6.生成器训练
一个随机隐含向量通过生成网络得到了一个假的数据,然后希望假的数据经过判别模型尽可能和真实label接近,通过g_loss = criterion(output,real_label)实现,然后反向传播去优化生成器的参数,在这个过程中,判别器的参数不再发生变化,否则生成器永远无法骗过优化的判别器。
以上是关于torch_09_GAN的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch:_thnn_nll_loss_forward 未针对类型 torch.LongTensor 实现