50行PyTorch代码实现生成对抗网络(GANs)
Posted 专知
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了50行PyTorch代码实现生成对抗网络(GANs)相关的知识,希望对你有一定的参考价值。
【导读】这是一份非常简单的PyTorch实现GAN教程和代码。文中另附有TensorFlow实现版本。
作者 | Dev Nag
编译 | Xiaowen
2014年,蒙特利尔大学的Ian Goodfellow和他的同事发表了一篇令人惊叹的论文,向世界介绍了生成对抗网络GAN。通过计算图和博弈论的创新结合,他们表明,如果建模能力足够强,两个相互对抗的模型将能够通过普通的反向传播进行协同训练。
模型扮演两个截然不同的角色(也就是对抗)。给定一些真实的数据集R,G是生成器(Generator),试图创建看起来像真实数据的假数据,而D是判别器(Discriminator),从真实的集合或G中获取数据并标记差异。Goodfellow的比喻是,G就像是一个伪造者试图把真实的画与他们的输出相匹配,而D则是侦探的团队,试图分辨出不同之处。(除了在这种情况下,伪造者永远无法看到原始数据,只有D的判断——他们就像盲伪造者。)
在理想情况下,随着时间的推移,D和G都会变得更好,直到G本质上成为真正物品的“主伪造者”,而D则不知所措,“无法区分这两种分布”。
在实践中, Goodfellow已经证明,G能够在原始数据集上执行一种形式的无监督学习,找到某种方式以(可能)更低维的方式来表示这些数据。正如Yann LeCun所指出的,无监督学习是真正的AI的“蛋糕”。
这种强大的技术似乎需要一吨的代码才可以开始,对吧?不。使用PyTorch,我们实际上可以在50行代码中创建一个非常简单的GAN。实际上只有5个因素需要考虑:
R:原始的真实数据集
I:进入生成器的随机噪声
G:试图复制/模仿原始数据集的生成器
D:试图区分G的输出与真实的R的判别器
Loop:实际的“训练”循环,我们教G来欺骗D,D来小心G。
(1)R:在我们的例子中,我们将从最简单的R (钟形曲线)开始。该函数采用均值和标准差,并返回一个函数,该函数提供了具有这些参数的高斯样本数据的正确形状。在我们的样本代码中,我们将使用平均值为4.0,标准差为1.25的数据。
(2)I:对生成器的输入也是随机的,但是为了使我们的工作有些难度,让我们使用统一的分布而不是普遍的分布。这意味着我们的模型G不能简单地移动/缩放输入来复制R,而是必须以非线性的方式重塑数据。
(3)G:生成器是标准前馈图——两个隐藏层,三个线性映射。我们使用ELU(指数线性单元)。G将从 I 获得均匀分布的数据样本并且以某种方式模拟来自R的正态分布的样本。
(4)D:判别码与G的生成码非常相似;一个包含两个隐藏层和三个线性映射的前馈图。它将从R或G中获取样本,并输出0到1之间的单个标量,解释为‘假’与‘真’。这是神经网络所能得到的最大限度的误差。
(5)最后,训练循环在两种模式之间交替进行:第一次用准确的标签训练D关于真实数据vs假数据;然后用不准确的标签来训练G以愚弄D。
即使你以前没见过PyTorch,你也可能知道上图代码的结构。在第一个(绿色)部分,我们把两种类型的数据都给D,并对D的猜测和实际的标签应用一个可微的标准。然后我们显式地调用‘back()’来计算梯度,用于更新d_optimizer.step()中的参数。G是有使用的,但是这里没有训练。
然后,在最后一节(红色)中,我们对G 做了同样的操作,注意,我们也在D中运行G的输出(我们实际上是给伪造者一个测试来练习),但是我们没有在这一步优化或更改D。我们不希望侦探D学习错误的标签。因此,我们只调用g_optimizer.step()。
仅此而已。还有其他一些示例代码,但GAN特有的东西只是这5个组件,没有别的。
在D和G之间进行了几千轮的训练之后,我们得到了什么呢?判别器D很快就好了(G在缓慢地上升),但是一旦它达到了一定的能力水平,G就有了一个值得尊敬的对手,并开始快速改进和提高。
超过20,000次训练回合,G的输出平均值超过4.0,然后回到一个相当稳定的正确范围(左)。同样,标准偏差最初是向错误的方向下降,然后上升到预期的1.25左右(右),匹配R。
让我们来展示G生成的最终分布。
还不错诶。左边的尾巴比右边长一点,但是偏态和峰态看起来应该是高斯分布了。
G几乎完全拟合了原始的数据分布R,而D正在角落里瑟瑟发抖,无法区分G和R。这正是我们想要的。
本文代码在这儿[1]。
最后,提供给大家一些参考资料。Goodfellow的其他GAN工作[2],包括这里适用的小型批处理识别方法。另外还有NIPS2016上一个两小时的演讲教程[3]。对于TensorFlow的用户来说,这里也有一份教程[4]。
参考链接:
1. https://github.com/devnag/pytorch-generative-adversarial-networks
2. https://arxiv.org/pdf/1606.03498.pdf
3. https://channel9.msdn.com/Events/Neural-Information-Processing-Systems-Conference/Neural-Information-Processing-Systems-Conference-NIPS-2016/Generative-Adversarial-Networks
4. http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/
原文链接:
https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f
-END-
专 · 知
人工智能领域26个主题知识资料全集获取与加入专知人工智能服务群: 欢迎微信扫一扫加入专知人工智能知识星球群,获取专业知识教程视频资料和与专家交流咨询!
请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料!
请加专知小助手微信(扫一扫如下二维码添加),加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~
点击“阅读原文”,使用专知
以上是关于50行PyTorch代码实现生成对抗网络(GANs)的主要内容,如果未能解决你的问题,请参考以下文章