GAN实战:半监督生成对抗网络
Posted 人邮异步社区
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN实战:半监督生成对抗网络相关的知识,希望对你有一定的参考价值。
半监督学习(semi-supervised learning)是GAN在实际应用中最有前途的领域之一。与监督学习(数据集中的每个样本有一个标签)和无监督学习(不使用任何标签)不同,半监督学习只为训练数据集的一小部分提供类别标签。通过内化数据中的隐藏结构,半监督学习努力从标注数据点的小子集中归纳,以有效地对从未见过的新样本进行分类。要使半监督学习有效,标签数据和无标签数据必须来自相同的基本分布。
缺少标签数据集是机器学习研究和实际应用中的主要瓶颈之一。尽管无标签数据非常丰富(互联网实际上就是无标签图像、视频和文本的无限来源),但为它们分配类别标签通常非常昂贵、不切实际且耗时。在 ImageNet 中手工标注 320 万张图像用了两年半的时间。ImageNet是一个标签图像的数据库,在过去的十年中对于图像处理和计算机视觉取得的许多进步均有帮助。[2]
深度学习先驱、美国斯坦福大学教授、百度前首席科学家Andrew Ng认为,训练需要大量标签数据是监督学习的致命弱点。目前,工业中的人工智能应用绝大多数使用监督学习。[3] 缺乏大型标签数据集的一个领域是医学,医学上获取数据(如来自临床试验的结果)通常需要耗费大量的精力和开支,更别说会面临道德伦理和隐私等更严重的问题了。[4] 因此,提高算法从越来越少的标注样本中学习的能力具有巨大的实际意义。
有趣的是,半监督学习可能也是最接近人类学习方式的机器学习方式之一。小学生学习阅读和书写时,老师不必带他们出门旅行,让他们在路上看到成千上万个字母和数字的样本以后,再根据需要纠正他们——就像监督学习算法的运作方式一样。相反,只需要一组样本可供孩子学习字母和数字,然后不管何种字体、大小、角度、照明条件和许多其他条件下,他们能够识别出来。半监督学习旨在按照这种有效的方式教会机器。
作为可用于训练的附加信息的来源,生成模型已被证明有助于提高半监督模型的准确性。不出所料,GAN是最有前途的。2016年,Tim Salimans、Ian Goodfellow和他们在OpenAI的同事仅使用2000个带标签的样本就在街景房屋号码数据集(Street View House Numbers,SVHN)上获得了近94%的准确率。[5] 相比之下,当时在SVHN训练集中对所有73257张图像使用带标签的最佳全监督算法的准确率约为98.40%。[6] 换句话说,半监督GAN的总体准确率与全监督的基准测试非常接近,而训练时使用的标签不到3%。
下面我们来看看Salimans和他的同事是如何在如此短的时间内取得如此大的成就的。
7.1.1 什么是SGAN
半监督生成对抗网络(Semi-Supervised GAN,SGAN)是一种生成对抗网络,其鉴别器是多分类器。这里的鉴别器不只是区分两个类(真和假),而是学会区分
类,其中
是训练数据集中的类数,生成器生成的伪样本增加了一个类。
例如,MNIST手写数字数据集有10个标签(每个数字一个标签,从0到9),因此在此数据集上训练的SGAN鉴别器将预测10+1=11个类。在我们的实现中,SGAN鉴别器的输出将表示为10个类别的概率(之和为1.0)加上另一个表示图像是真还是假的概率的向量。
将鉴别器从二分类器转变为多分类器看似是一个微不足道的变化,但其含义比乍看之下更为深远。我们从一个图7.2所示的SGAN架构开始解释。
图7.2 此SGAN中,生成器输入随机噪声向量
并生成伪样本
。鉴别器接收3种数据输入:来自生成器的伪数据、真实的无标签数据样本
和真实的标签数据样本
,其中
是给定样本对应的标签;然后鉴别器输出分类,以区分伪样本与真实样本区,并为真实样本确定正确的类别。注意,标签数据比无标签数据少得多。实际情况中,这一对比甚至比本图所显示的更明显,标签数据仅占训练数据的一小部分(通常低至1%~2%)
如图7.2所示,与传统GAN相比,区分多个类的任务不仅影响了鉴别器本身,还增加了SGAN架构、训练过程和训练目标的复杂性。
7.1.2 结构
SGAN生成器的目的与原始GAN相同:接收一个随机数向量并生成伪样本,力求使伪样本与训练数据集别无二致。
但是,SGAN鉴别器与原始GAN实现有很大不同。它接收3种输入:生成器生成的伪样本
、训练数据集中无标签的真实样本
和有标签的真实样本
。其中
表示给定样本
的标签。
SGAN鉴别器的目标不是二分类,而是在输入样本为真的情况下,将其正确分类到相应的类中,或将样本作为假的(可以认为是特殊的附加类)排除。
有关这两个SGAN子网络的要点见表7.1。
表7.1 SGAN的生成器网络和鉴别器网络
| 生成器 | 鉴别器 |
输入 | 一个随机数向量
| 鉴别器接收3种输入;
|
输出 | 尽可能令人相信的伪样本
| 表示输入样本属于
个真实类别中的某一个或属于伪样本类别的可能性 |
目标 | 生成与训练数据集数据别无二致的伪样本,以欺骗鉴别器,使之将伪样本分到真实类别 | 学会将正确的类别标签分配给真实的样本,同时将来自生成器的所有样本判别为假 |
7.1.3 训练过程
回想一下,常规GAN通过计算
)和
的损失并反向传播总损失来更新鉴别器的可训练参数,以使损失最小,从而训练鉴别器。生成器通过反向传播鉴别器损失
并寻求使其最大化来进行训练,以便让鉴别器将合成的伪样本错误地分类为真。
为了训练SGAN,除了
)和
,我们还必须计算有监督训练样本的损失:
。这些损失与SGAN鉴别器必须达到的双重目标相对应:区分真伪样本;学习将真实样本正确分类。用原论文中的术语来说,双重目标对应于两种损失:有监督损失(supervised loss)和无监督损失(unsupervised loss)。[7]
7.1.4 训练目标
到目前为止,你看到的GAN变体都是生成模型。它们的目标是生成逼真的数据样本。正因如此,人们最感兴趣的一直是生成器。鉴别器网络的主要目的是帮助生成器提高生成图像的质量。在训练结束时,我们通常会忽略鉴别器,仅使用训练好的生成器来创建逼真的合成数据。
在SGAN中主要关心的反而是鉴别器。训练过程的目标是使该网络成为仅使用一小部分标签数据的半监督分类器,其准确率尽可能接近全监督的分类器(其训练数据集中的每个样本都有标签)。生成器的目标是通过提供附加信息(它生成的伪数据)来帮助鉴别器学习数据中的相关模式,从而提高其分类准确率。训练结束时,生成器将被丢弃,而训练有素的鉴别器将被用作分类器。
至此,我们已经介绍了是什么推动了SGAN的诞生以及它是如何工作的,接下来通过模型的实现来了解它的实际应用。
7.2 教程:SGAN的实现
本教程实现了一个SGAN模型。该模型仅使用100个训练样本即可对MNIST数据集中的手写数字进行分类。在教程的最后,我们将模型的分类准确率与其对应的全监督模型进行了比较,看看半监督学习所取得的进步。
7.2.1 架构图
本教程中实现的SGAN模型的高级示意如图7.3所示,它比本章开头介绍的一般概念图要复杂一些。关键在于(实现)细节。
为了解决区分真实标签的多分类问题,鉴别器使用了softmax函数,该函数给出了在给定数量的类别(本例中为10类)上的概率分布。给一个给定类别标签分配的概率越高,鉴别器就越确信该样本属于这一给定的类。为了计算分类误差,我们使用了交叉熵损失,以测量输出概率与目标独热编码标签之间的差异。
图7.3 本章教程中实现的SGAN的高级示意。生成器将随机噪声转换为伪样本;鉴别器输入有标签的真实图像
、无标签的真实图像
和生成器生成的伪图像
。为了区分真实样本和伪样本,鉴别器使用了sigmoid函数;为了区分真实标签的分类,鉴别器使用了softmax函数
为了输出样本是真还是假的概率,鉴别器使用了sigmoid激活函数,并通过反向传播二元交叉熵损失来训练其参数,这与第3章和第4章中实现的GAN相同。
7.2.2 实现
你可能会注意到,本书许多SGAN实现都是从第4章的DCGAN模型改编而来的。这并不是出于懒惰(嗯,也许是有一点……),而是为了更好地了解SGAN所需的不同修改,且不会干扰到网络无关部分中的实现细节。
本书Github仓库(https://Github.com/
GAN-in-Action/GAN-in-action)中的第7章文件夹提供了完整实现的Jupyter Notebook,其中包括训练进度的可视化等信息。代码是用Python3.6.0、Keras2.1.6和TensorFlow1.8.0版本测试的。为了加快训练时间,我们建议你在GPU上运行模型。
7.2.3 设置
首先导入运行模型需要的所有模块和库,如清单7.1所示。
清单7.1 导入声明
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from keras import backend as K
from keras.datasets import mnist
from keras.layers import (Activation, BatchNormalization, Concatenate, Dense,
Dropout, Flatten, Input, Lambda, Reshape)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.utils import to_categorical
指定输入图像的大小、噪声向量
的大小以及半监督分类的真实类别的数量(鉴别器将学习识别每个数字对应的类),如清单7.2所示。
清单7.2 模型输入维度
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels) ⇽--- 输入图像的维度
z_dim = 100 ⇽--- 噪声向量的大小,用作生成器的输入
num_classes = 10 ⇽--- 数据集中类别的数量
7.2.4 数据集
尽管MNIST训练数据集里有50000个有标签的训练图像,但我们仅将其中的一小部分(由num_labeled参数决定)用于训练,并假设其余图像都是无标签的。我们这样来实现这一点:取批量有标签数据时仅从前num_labeled个图像采样,而在取批量无标签数据时从其余(50000 – num_labeled)个图像中采样。
Dataset对象(清单7.3)提供了返回所有num_labeled训练样本及其标签的函数,以及能返回MNIST数据集中所有10000个带标签的测试图像的函数。训练后,我们将使用测试集来评估模型的分类在多大程度上可以推广到以前未见过的样本。
清单7.3 用于训练和测试的数据集
class Dataset:
def __init__(self, num_labeled):
self.num_labeled = num_labeled ⇽--- 训练中使用的有标签图像的数量
(self.x_train, self.y_train), (self.x_test, ⇽--- 加载MINST数据集
self.y_test) = mnist.load_data()
def preprocess_imgs(x):
x = (x.astype(np.float32) - 127.5) / 127.5 ⇽--- 灰度像素值从[0, 255]缩放到[–1, 1]
x = np.expand_dims(x, axis=3) ⇽--- 将图像尺寸扩展到宽×高×通道数
return x
def preprocess_labels(y):
return y.reshape(-1, 1)
self.x_train = preprocess_imgs(self.x_train) ⇽--- 训练
self.y_train = preprocess_labels(self.y_train)
self.x_test = preprocess_imgs(self.x_test) ⇽--- 测试
self.y_test = preprocess_labels(self.y_test)
def batch_labeled(self, batch_size):
idx = np.random.randint(0, self.num_labeled, batch_size) ⇽--- 获取随机批量的有标签图像及其标签
imgs = self.x_train[idx]
labels = self.y_train[idx]
return imgs, labels
def batch_unlabeled(self, batch_size):
idx = np.random.randint(self.num_labeled, self.x_train.shape[0], ⇽--- 获取随机批量的无标签图像
batch_size)
imgs = self.x_train[idx]
return imgs
def training_set(self):
x_train = self.x_train[range(self.num_labeled)]
y_train = self.y_train[range(self.num_labeled)]
return x_train, y_train
def test_set(self):
return self.x_test, self.y_test
在本教程中,我们假设只有100个有标签的MNIST图像用于训练:
num_labeled = 100 ⇽--- 要使用的有标签样本的数量(其余将作为无标签样本使用)
dataset = Dataset(num_labeled)
7.2.5 生成器
SGAN的生成器网络与第4章中DCGAN的相同,如清单7.4所示。生成器使用转置卷积层将输入的随机噪声向量转换为28×28×1图像。
清单7.4 SGAN生成器
def build_generator(z_dim):
model = Sequential()
model.add(Dense(256 * 7 * 7, input_dim=z_dim)) ⇽--- 通过一个全连接层改变输入为一个7×7×256的张量
model.add(Reshape((7, 7, 256)))
model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')) ⇽--- 转置卷积层,张量从7×7×256变为14×14×128
model.add(BatchNormalization()) ⇽--- 批归一化
model.add(LeakyReLU(alpha=0.01)) ⇽--- LeakyReLU激活
model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding='same')) ⇽--- 转置卷积层,张量从14×14×128变为14×14×64
model.add(BatchNormalization()) ⇽--- 批归一化
model.add(LeakyReLU(alpha=0.01)) ⇽--- LeakyReLU激活
model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same')) ⇽--- 转置卷积层,张量从14×14×64变为28×28×1
model.add(Activation('tanh')) ⇽--- 带tanh激活函数的输出层
return model
7.2.6 鉴别器
鉴别器是SGAN模型中最复杂的部分,它有如下双重目标。
(1)区分真实样本和伪样本。为此,SGAN鉴别器使用了sigmoid函数,输出用于二元分类的概率。
(2)对于真实样本,还要对其标签准确分类。为此,SGAN鉴别器使用了softmax函数,输出概率向量——每个目标类别对应一个。
1.核心鉴别器网络
我们先来定义核心鉴别器网络。清单7.5中的模型与第4章中实现的基于ConvNet的鉴别器相似。实际上,直到3×3×128卷积层,它的批归一化和LeakyReLU激活与之前的一直是完全相同的。
在该层之后添加了一个Dropout,这是一种正则化技术,通过在训练过程中随机丢弃神经元及其与网络的连接来防止过拟合。[8] 这就迫使剩余的神经元减少它们之间的相互依赖,并得到对基础数据更一般的表示形式。随机丢弃的神经元比例由比例参数指定,在本实现中将其设置为0.5,即model.add(Dropout(0.5))。由于SGAN分类任务的复杂性增加,我们使用了Dropout,以提高模型从只有100个有标签的样本中归纳的能力。
清单7.5 SGAN鉴别器
def build_discriminator_net(img_shape):
model = Sequential()
model.add( ⇽--- 卷积层,张量从14×14×64变为14×14×32
Conv2D(32,
kernel_size=3,
strides=2,
input_shape=img_shape,
padding='same'))
model.add(LeakyReLU(alpha=0.01)) ⇽--- LeakyReLU激活函数
model.add( ⇽--- 卷积层,张量从14×14×32变为7×7×64
Conv2D(64,
kernel_size=3,
strides=2,
input_shape=img_shape,
padding='same'))
model.add(BatchNormalization()) ⇽--- 批归一化
model.add(LeakyReLU(alpha=0.01)) ⇽--- LeakyReLU激活函数
model.add( ⇽--- 卷积层,张量从7×7×64变为3×3×128
Conv2D(128,
kernel_size=3,
strides=2,
input_shape=img_shape,
padding='same'))
model.add(BatchNormalization()) ⇽--- 批归一化
model.add(LeakyReLU(alpha=0.01)) ⇽--- LeakyReLU激活函数
model.add(Dropout(0.5)) ⇽--- Dropout
model.add(Flatten()) ⇽--- 将张量展平
model.add(Dense(num_classes)) ⇽--- 与num_classes神经元完全连接的层
return model
注意,Dropout层是在批归一化之后添加的。出于这两种技术之间的相互作用,这种方法已显示出优越的性能。[9]
另外,请注意前面的网络以一个具有10个神经元的全连接层结束。接下来,我们需要定义从这些神经元计算出的两个鉴别器输出:一个用于有监督的多分类(使用softmax),另一个用于无监督的二分类(使用sigmoid)。
2.有监督的鉴别器
清单 7.6 的代码采用了之前实现的核心鉴别器网络,并用于构建鉴别器模型的有监督部分。
清单7.6 SGAN有监督的鉴别器
def build_discriminator_supervised(discriminator_net):
model = Sequential()
model.add(discriminator_net)
model.add(Activation('softmax')) ⇽--- softmax激活函数输出真实类别的预测概率分布
return model
3.无监督的鉴别器
清单7.7在核心鉴别器网络的顶上实现了模型的无监督部分。注意,predict(x)这个函数将10个神经元(来自核心鉴别器网络)的输出转换成一个二分类的真假预测。
清单7.7 SGAN鉴别器无监督部分
def build_discriminator_unsupervised(discriminator_net):
model = Sequential()
model.add(discriminator_net)
def predict(x):
prediction = 1.0 - (1.0 / ⇽--- 将真实类别的分布转换为二元真-假概率
(K.sum(K.exp(x), axis=-1, keepdims=True) + 1.0))
return prediction
model.add(Lambda(predict)) ⇽--- 之前定义的真-假输出神经元
return model
7.2.7 搭建整个模型
接下来,我们将构建并编译鉴别器模型和生成器模型(清单 7.8)。注意,有监督损失和无监督损失分别使用categorical_crossentropy和binary_crossentropy损失函数。
清单7.8 构建模型
def build_gan(generator, discriminator):
model = Sequential()
model.add(generator) ⇽--- 合并生成器模型和鉴别器模型
model.add(discriminator)
return model
discriminator_net = build_discriminator_net(img_shape) ⇽--- 核心鉴别器网络:这些层在有监督和无监督训练中共享
discriminator_supervised = build_discriminator_supervised(discriminator_net) ⇽---
discriminator_supervised.compile(loss='categorical_crossentropy',
metrics=['accuracy'],
optimizer=Adam()) ⇽--- 构建并编译有监督训练鉴别器
discriminator_unsupervised = build_discriminator_unsupervised( ⇽---
discriminator_net)
discriminator_unsupervised.compile(loss='binary_crossentropy',
optimizer=Adam()) ⇽--- 构建并编译无监督训练鉴别器
generator = build_generator(z_dim) ⇽--- 构建生成器
discriminator_unsupervised.trainable = False ⇽--- 生成器训练时,鉴别器参数保持不变
gan = build_gan(generator, discriminator_unsupervised) ⇽---
gan.compile(loss='binary_crossentropy', optimizer=Adam()) ⇽--- 构建并编译鉴别器固定的GAN模型,以训练生成器。注意:鉴别器使用无监督版本
7.2.8 训练
以下伪代码概述了SGAN的训练算法。
SGAN训练算法
对每次训练迭代,执行以下操作。
(1)训练鉴别器(有监督)。
a.随机取小批量有标签的真实样本
。
b.计算给定小批量的
并反向传播多分类损失更新
,以使损失最小化。
(2)训练鉴别器(无监督)。
a.随机取小批量无标签的真实样本
。
b.计算给定小批量的
)并反向传播二元分类损失更新
,以使损失最小化。
c.随机取小批量的随机噪声
生成一小批量伪样本:
。
d.计算给定小批量的
并反向传播二元分类损失更新
以使损失最小化。
(3)训练生成器。
a.随机取小批量的随机噪声
生成一小批量伪样本:
。
b.计算给定小批量的
并反向传播二元分类损失更新
以使损失最大化。
结束
清单7.9实现了SGAN的训练算法。
清单7.9 SGAN的训练算法
supervised_losses = []
iteration_checkpoints = []
def train(iterations, batch_size, sample_interval):
real = np.ones((batch_size, 1)) ⇽--- 真实图像的标签:全为1
fake = np.zeros((batch_size, 1)) ⇽--- 伪图像的标签:全为0
for iteration in range(iterations):
imgs, labels = dataset.batch_labeled(batch_size) ⇽--- 获取有标签样本
labels = to_categorical(labels, num_classes=num_classes) ⇽--- 独热编码标签
imgs_unlabeled = dataset.batch_unlabeled(batch_size) ⇽--- 获取无标签样本
z = np.random.normal(0, 1, (batch_size, z_dim)) ⇽--- 生成一批伪图像
gen_imgs = generator.predict(z)
d_loss_supervised,
accuracy = discriminator_supervised.train_on_batch(imgs, labels) ⇽--- 训练有标签的真实样本
d_loss_real = discriminator_unsupervised.train_on_batch( ⇽--- 训练无标签的真实样本
imgs_unlabeled, real)
d_loss_fake = discriminator_unsupervised.train_on_batch(gen_imgs, fake) ⇽--- 训练伪样本
d_loss_unsupervised = 0.5 * np.add(d_loss_real, d_loss_fake)
z = np.random.normal(0, 1, (batch_size, z_dim)) ⇽--- 生成一批伪样本
gen_imgs = generator.predict(z)
g_loss = gan.train_on_batch(z, np.ones((batch_size, 1))) ⇽--- 训练生成器
if (iteration + 1) % sample_interval == 0:
supervised_losses.append(d_loss_supervised) ⇽---
iteration_checkpoints.append(iteration + 1) ⇽--- 保存鉴别器的有监督分类损失,以便绘制损失曲线
print( ⇽--- 输出训练过程
"%d [D loss supervised: %.4f, acc.: %.2f%%] [D loss" +
" unsupervised: %.4f] [G loss: %f]"
% (iteration + 1, d_loss_supervised, 100 * accuracy,
(d_loss_unsupervised, g_loss))
1.训练模型
之所以使用较小的批量,是因为只有100个有标签的训练样本。我们通过反复试验确定迭代次数:不断增加次数,直到鉴别器的有监督损失趋于平稳,但不要超过稳定点太远(以降低过拟合的风险)。训练模型的代码如清单7.10所示。
清单7.10 训练模型
iterations = 8000 ⇽--- 设置超参数
batch_size = 32
sample_interval = 800
train(iterations, batch_size, sample_interval) ⇽--- 按照指定的迭代次数训练SGAN
2.模型训练和测试准确率
现在,让我们看看SGAN作为分类器的表现吧!在训练过程中,SGAN达到了100%的有监督准确率。尽管这看似很好,但请记住只有100个有标签的样本用于有监督训练——也许模型只是记住了训练数据集。分类器能在多大程度上泛化到训练集中未见过的数据上才是重要的,如清单7.11所示。
清单7.11 测试准确率
x, y = dataset.test_set()
y = to_categorical(y, num_classes=num_classes)
_, accuracy = discriminator_supervised.evaluate(x, y) ⇽--- 在测试集上计算分类准确率
print("Test Accuracy: %.2f%%" % (100 * accuracy))
请尽情欢呼吧!SGAN能够准确分类测试集中大约89%的样本。为了解这有多了不起,我们对比一下SGAN和全监督分类器的性能。
本文摘自《GAN实战》
本书旨在引导对生成对抗网络(GAN)有兴趣的人从头开始学习。本书从最简单的例子开始,介绍一些最具创新性的GAN的实现和技术细节,进而对这些研究进展做出直观的解释,并完整地呈现所涉及的一切内容(不包括最基本的数学和原理),让最前沿的研究变得触手可及。
本书的最终目标是提供必要的知识和工具,让你不仅能全面了解对GAN迄今为止取得的成就,还能有能力自由选择开发新的应用。生成对抗这一模式充满潜力,等着像你这样怀有进取心、想在学术研究和实际应用中做出点成就的人去挖掘!欢迎你加入我们的GAN之旅。
以上是关于GAN实战:半监督生成对抗网络的主要内容,如果未能解决你的问题,请参考以下文章
简述一下生成对抗网络GAN(Generative adversarial nets)模型?