深度学习系列33:有标签的GAN:CGAN

Posted IE06

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习系列33:有标签的GAN:CGAN相关的知识,希望对你有一定的参考价值。

1. 从GAN到CGAN

GAN的训练数据是没有标签的,如果我们要做有标签的训练,则需要用到CGAN。
对于图像来说,我们既要让输出的图片真实,也要让输出的图片符合标签c。Discriminator输入便被改成了同时输入c和x,输出要做两件事情,一个是判断x是否是真实图片,另一个是x和c是否是匹配的。
在下面两个情况中,左边虽然输出图片清晰,但不符合c;右边输出图片不真实。因此两种情况中D的输出都会是0。

我们来看下简单的示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils import data
import os
import glob
from PIL import Image
 
# 独热编码
# 输入x代表默认的torchvision返回的类比值,class_count类别值为10
def one_hot(x, class_count=10):
    return torch.eye(class_count)[x, :]  # 切片选取,第一维选取第x个,第二维全要
 
 
transform =transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(0.5, 0.5)])
 
dataset = torchvision.datasets.MNIST('data',
                                     train=True,
                                     transform=transform,
                                     target_transform=one_hot,
                                     download=False)
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)
 
 
# 定义生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(10, 128 * 7 * 7)
        self.bn1 = nn.BatchNorm1d(128 * 7 * 7)
        self.linear2 = nn.Linear(100, 128 * 7 * 7)
        self.bn2 = nn.BatchNorm1d(128 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(256, 128,
                                          kernel_size=(3, 3),
                                          padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 1,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)
 
    def forward(self, x1, x2):
        x1 = F.relu(self.linear1(x1))
        x1 = self.bn1(x1)
        x1 = x1.view(-1, 128, 7, 7)
        x2 = F.relu(self.linear2(x2))
        x2 = self.bn2(x2)
        x2 = x2.view(-1, 128, 7, 7)
        x = torch.cat([x1, x2], axis=1)
        x = F.relu(self.deconv1(x))
        x = self.bn3(x)
        x = F.relu(self.deconv2(x))
        x = self.bn4(x)
        x = torch.tanh(self.deconv3(x))
        return x
 
# 定义判别器
# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear = nn.Linear(10, 1*28*28)
        self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值
 
    def forward(self, x1, x2):
        x1 =F.leaky_relu(self.linear(x1))
        x1 = x1.view(-1, 1, 28, 28)
        x = torch.cat([x1, x2], axis=1)
        x = F.dropout2d(F.leaky_relu(self.conv1(x)))
        x = F.dropout2d(F.leaky_relu(self.conv2(x)))
        x = self.bn(x)
        x = x.view(-1, 128*6*6)
        x = torch.sigmoid(self.fc(x))
        return x
 
# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
 
# 损失计算函数
loss_function = torch.nn.BCELoss()
 
# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
 
 
# 定义可视化函数
def generate_and_save_images(model, epoch, label_input, noise_input):
    predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow((predictions[i] + 1) / 2, cmap='gray')
        plt.axis("off")
    plt.show()
noise_seed = torch.randn(16, 100, device=device)
 
label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)
print(label_seed)
# print(label_seed_onehot)
 
# 开始训练
D_loss = []
G_loss = []
# 训练循环
for epoch in range(150):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader.dataset)
    # 对全部的数据集做一次迭代
    for step, (img, label) in enumerate(dataloader):
        img = img.to(device)
        label = label.to(device)
        size = img.shape[0]
        random_noise = torch.randn(size, 100, device=device)
 
        d_optim.zero_grad()
 
        real_output = dis(label, img)
        d_real_loss = loss_function(real_output,
                                    torch.ones_like(real_output, device=device)
                                    )
        d_real_loss.backward() #求解梯度
 
        # 得到判别器在生成图像上的损失
        gen_img = gen(label,random_noise)
        fake_output = dis(label, gen_img.detach())  # 判别器输入生成的图片,f_o是对生成图片的预测结果
        d_fake_loss = loss_function(fake_output,
                                    torch.zeros_like(fake_output, device=device))
        d_fake_loss.backward()
 
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()  # 优化
 
        # 得到生成器的损失
        g_optim.zero_grad()
        fake_output = dis(label, gen_img)
        g_loss = loss_function(fake_output,
                               torch.ones_like(fake_output, device=device))
        g_loss.backward()
        g_optim.step()
 
        with torch.no_grad():
            d_epoch_loss += d_loss.item()
            g_epoch_loss += g_loss.item()
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        if epoch % 10 == 0:
            print('Epoch:', epoch)
            generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)

2. Pix2pixgan

Pix2pixgan本质上是一个cgan,图片x作为此cGAN的条件, 需要输入到G和D中。 G的输入是x(x 是需要转换的图片),输出是生成的图片G(x)。 D则需要分辨出x,G(x)和x, y。

这里的生成器模型我们采用U-Net:

在pix2pix中,作者就是把L1 loss 和GAN loss相结合使用,因为作者认为L1 loss 可以恢复图像的低频部分,而GAN loss可以恢复图像的高频部分。



我们看一些代码说明:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 512)
        self.down6 = Downsample(512, 512)
 
        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 256)
        self.up4 = Upsample(512, 128)
        self.up5 = Upsample(256, 64)
 
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)
 
    def forward(self, x):
        x1 = self.down1(x, is_bn=False)  # torch.Size([8, 64, 128, 128])
        x2 = self.down2(x1)  # torch.Size([8, 128, 64, 64])
        x3 = self.down3(x2)  # torch.Size([8, 256, 32, 32])
        x4 = self.down4(x3)  # torch.Size([8, 512, 16, 16])
        x5 = self.down5(x4)  # torch.Size([8, 512, 8, 8])
        x6 = self.down6(x5)  # torch.Size([8, 512, 4, 4])
 
        x6 = self.up1(x6, is_drop=True)  # torch.Size([8, 512, 8, 8])
        x6 = torch.cat([x5, x6], dim=1)  # torch.Size([8, 1024, 8, 8])
 
        x6 = self.up2(x6, is_drop=True)  # torch.Size([8, 512, 16, 16])
        x6 = torch.cat([x4, x6], dim=1)  # torch.Size([8, 1024, 16, 16])
 
        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x3, x6], dim=1)
 
        x6 = self.up4(x6)
        x6 = torch.cat([x2, x6], dim=1)
 
        x6 = self.up5(x6)
        x6 = torch.cat([x1, x6], dim=1)
 
        x6 = torch.tanh(self.last(x6))
        return x6

# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.conv = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn = nn.BatchNorm2d(512)
        self.last = nn.Conv2d(512, 1, 3, 1)
 
    def forward(self, anno, img):
        x = torch.cat([anno, img], dim=1)  # batch*6*H*W
        x = self.down1(x, is_bn=False)
        x = self.down2(x)
        x = F.dropout2d(self.down3(x))
        x = F.dropout2d(F.leaky_relu(self.conv(x)))
        x = F.dropout2d(self.bn(x))
        x = torch.sigmoid(self.last(x))
        return x

3. CycleGan

pix2pixGAN有一个明显的缺点就是,在进行训练的时候必须提供成对的数据集。比如当我们想生成梵高风格的画时,梵高本人画的作品肯定是相对较少的,这个时候就可以考虑使用cycleGAN。cycleGAN适用于非配对的图像到图像转换:

其原理可以概括为将一类图片转成成另一类图片,比如,现有两个样本空间X、Y,我们希望把X空间中的样本转换成Y空间中的样本。这种转换只是风格上的转换,实际X Y 的内容是不一样的。实际的目标就是学习从X到Y的映射,假设该映射为F,它就对应着GAN中的生成器,F就可以将X中的图片A转换为Y中的图片F(A)。
为了实现这个过程,我们需要两个生成器 G_AB 和 G_BA:

首先是单向loss的组成:
判别 loss: 判别器 D_B 是用来判断输入的图片是否是真实的 B 图片,这个流程和GAN是一致的。

生成 loss:生成器用来重建图片 a,目的是希望生成的图片 G_BA(G_AB(a)) 和原图 a 尽可能的相似,那么可以很简单的采取 L1 loss 或者 L2 loss。除了GAN loss,还包含如下loss:
① cycle-loss:也就是循环一致损失。因为网络需要保证生成的图像必须保留有原 始图像的特性,所以如果我们使用生成器GA-B生成一张假图像,那么要能够使用另外一个生成器 GB-A来努力恢复成原始图像。此过程必须满足循环一致性

② 等价loss:我们要求 G A B ( b ) = b G_AB(b)=b GAB(b)=b,以及 G B A ( a ) = a G_BA(a)=a GBA(a)=a

下面来看下示例代码:
获取苹果橙子数据:

# 加载训练数据
apples_path = glob.glob('data/trainA/*.jpg')
oranges_path = glob.glob('data/trainB/*.jpg')
 
 
transform = transforms.Compose([transforms.ToTensor(),  # 0-1归一化
                                transforms.Normalize(0.5, 0.5),  # -1,1])
 
class AppleOrangeDataset(data.Dataset):
    def __init__(self, img_path):
        self.img_path = img_path
 
    def __getitem__(self, index):
        img_path = self.img_path[index]
        pil_img = Image.open(img_path)
        pil_img = transform(pil_img)
        return pil_img
    def __len__(self):
        return len(self.img_path)
 
apple_dataset = AppleOrangeDataset(apples_path)
oranges_dataset = AppleOrangeDataset(oranges_path)

基于Unet结构定义上 / 下采样模块,接着定义生成器:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 512)
        self.down6 = Downsample(512, 512)
 
        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 256)
        self.up4 = Upsample(512, 128)
        self.up5 = Upsample(256, 64)
 
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)
 
    def forward(self, x):
        x1 = self.down1(x, is_bn=False)  # torch.Size([8, 64, 128, 128])
        x2 = self.down2(x1)  # torch.Size([8, 128, 64, 64])
        x3 = self.down3(x2)  # torch.Size([8, 256, 32, 32])
        x4 = self.down4(x3)  # torch.Size([8, 512, 16, 16])
        x5 = self.down5(x4)  # torch.Size([8, 512, 8, 8])
        x6 = self.down6(x5)  # torch.Size([8, 512, 4, 4])
 
        x6 = self.up1(x6, is_drop=True)  # torch.Size([8, 512, 8, 8])
        x6 = torch.cat([x5, x6], dim=1)  # torch.Size([8, 1024, 8, 8])
 
        x6 = self.up2(x6, is_drop=True)  # torch.Size([8, 512, 16, 16])
        x6 = torch.cat([x4, x6], dim=1)  # torch.Size([8, 1024, 16, 16])
 
        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x3, x6], dim=1)
 
        x6 = self.up4(x6)
        x6 = torch.cat([x2, x6], dim=1)
 
        x6 = self.up5(x6)
        x6 = torch.cat([x1, x6], dim=1)
 
        x6 = torch.tanh(self.last(x6))
        return x6

接下来是鉴别器:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(3, 64)             # 128
        self.down2 = Downsample(64, 128)           # 64
        self.last = nn.Conv2d(128, 1, 3)
 
    def forward(self, img):
        x = self.down1(img)
        x = self.down2(x)
        x = torch.sigmoid(self.last(x))
        return x

我们需要定义两个生成器和两个鉴别器:


gen_AB = Generator().to(device)
gen_BA = Generator().to(device)
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)

# 同时对两个生成器进行优化
gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()),
                                 lr=2e-4, betas=(0.5, 0.999))
dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))

训练过程如下:

D_loss = []  # 记录训练过程中判别器loss变化
G_loss = []  # 记录训练过程中生成器loss变化
 
# 开始训练
for epoch in range(50):
    D_epoch_loss = 0
    G_epoch_loss = 0
    for step, (real_A, real_B) in enumerate(zip(apples_dl, oranges_dl)):
        # GAN 训练
        gen_optimizer.zero_grad()
 
        # identity loss
        same_B = gen_AB(real_B)
        identity_B_loss = l1loss_fn(same_B, real_B)
        same_A = gen_BA(real_A)
        identity_A_loss = l1loss_fn(same_A, real_A)
 
        # GAN loss
        fake_B = gen_AB(real_A)
        D_pred_fake_B = dis_B(fake_B)
        gan_loss_AB = bceloss_fn(D_pred_fake_B,
                                 torch.ones_like(D_pred_fake_B, device=device))
 
        fake_A = gen_BA(real_B)
        D_pred_fake_A = dis_A(fake_A)
        gan_loss_BA = bceloss_fn(D_pred_fake_A,
                                 torch.ones_like(D_pred_fake_A, device=device))
 
        # cycle consistanse loss
        recovered_A = gen_BA(fake_B)
        cycle_loss_ABA = l1loss_fn(recovered_A, real_A)
 
        recovered_B = gen_AB(fake_A)
        cycle_loss_BAB = l1loss_fn(recovered_B, real_B)
 
        # total_loss
        g_loss = (identity_B_loss + identity_A_loss + gan_loss_AB + gan_loss_BA
                  + cycle_loss_ABA + cycle_loss_BAB)
 
        g_loss.backward()
        gen_optimizer.step()
 
        # dis_A 训练
        dis_A_optimizer.zero_grad()
        dis_A_real_output = dis_A(real_A)  # 判别器输入真实图片
        dis_A_real_loss = bceloss_fn(dis_A_real_output,
                                     torch.ones_like(dis_A_real_output, device=device))
 
        dis_A_fake_output = dis_A(fake_A.detach())  # 判别器输入生成图片
        dis_A_fake_loss = bceloss_fn(dis_A_fake_output,
                                     torch.zeros_like(dis_A_fake_output, device=device))
 
        dis_A_loss = (dis_A_real_loss + dis_A_fake_loss) * 0.5
 
        dis_A_loss.backward()
        dis_A_optimizer.step()
 
        # dis_B 训练
        dis_B_optimizer.zero_grad()
        dis_B_real_output = dis_B(real_B)  # 判别器输入真实图片
        dis_B_real_loss = bceloss_fn(dis_B_real_output,
                                     torch.ones_like(dis_B_real_output, device=device))
 
        dis_B_fake_output = dis_B(fake_B.detach())  # 判别器输入生成图片
        dis_B_fake_loss = bceloss_fn(dis_B_fake_output,
                                     torch.zeros_like(dis_B_fake_output, device=device))
 
        dis_B_loss = (dis_B_real_loss + dis_B_fake_loss) * 0.5
 
        dis_B_loss.backward()
        dis_B_optimizer.step()

以上是关于深度学习系列33:有标签的GAN:CGAN的主要内容,如果未能解决你的问题,请参考以下文章

深度学习生成对抗网络GAN|GANWGANWGAN-UPCGANCycleGANDCGAN

深度学习8 GAN生成对抗网络

深度学习8 GAN生成对抗网络

GAN-生成对抗神经网络(Pytorch)-合集GAN-DCGAN-CGAN

GAN-生成对抗神经网络(Pytorch)-合集GAN-DCGAN-CGAN

GAN-生成对抗神经网络(Pytorch)-合集GAN-DCGAN-CGAN