GAN-生成对抗网络(Pytorch)合集--pixtopix-CycleGAN

Posted JiYH

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN-生成对抗网络(Pytorch)合集--pixtopix-CycleGAN相关的知识,希望对你有一定的参考价值。

pix2pix(像素到像素,文中也写作 pixtopix)

原文链接:https://arxiv.org/pdf/1611.07004.pdf
作用是把一个域的图片转换为另一个域的图片(例如把白天的照片转成夜晚的照片)。
如下图所示,输入标注(语义分割)图片,输出真实图片。缺点是训练集中两个域的图片必须一一对应(成对数据),所以叫 pix2pix。

网络结构有点复杂,生成器用到了语义分割中的 U-Net 网络结构。

数据集:
地址忘了,也是官方的数据集,想起来再补。(从代码中的 base/extended 目录和 CMP_dataset 类名来看,应为 CMP Facade 建筑立面数据集。)
代码:这里是建筑物 labels → facade(由标注图生成建筑立面照片)的例子

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image

# Training pairs from the "base" folder (presumably the CMP Facade dataset):
# *.jpg are the real photos, *.png are the matching label/segmentation maps.
# glob() returns files in arbitrary, filesystem-dependent order, so both lists
# are sorted to keep images_path[i] and annos_path[i] pointing at the same
# sample. os.path.join keeps the pattern portable (a backslash is not a path
# separator on Linux/macOS).
images_path = sorted(glob.glob(os.path.join('base', '*.jpg')))
annos_path = sorted(glob.glob(os.path.join('base', '*.png')))

# ToTensor -> values in [0, 1]; resize to 256x256; Normalize(0.5, 0.5) then
# maps every channel to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(0.5, 0.5)
])


class CMP_dataset(data.Dataset):
    """Paired (label map, photo) dataset for pix2pix.

    Args:
        imgs_path: list of photo file paths (the generator's targets).
        annos_path: list of label-map file paths; annos_path[i] must be the
            label map belonging to imgs_path[i].

    Each item is an ``(anno, img)`` pair of tensors produced by the
    module-level ``transform``.
    """

    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, item):
        img_path = self.imgs_path[item]
        anno_path = self.annos_path[item]
        # Both images are forced to 3-channel RGB: the networks take 3-channel
        # inputs, and a grayscale/CMYK/palette file would otherwise yield a
        # tensor with a different channel count and break batching.
        # `with` closes the underlying file handle once the pixels are loaded.
        with Image.open(img_path) as pil_img:
            img = transform(pil_img.convert('RGB'))
        with Image.open(anno_path) as anno_img:
            anno = transform(anno_img.convert('RGB'))
        return anno, img

    def __len__(self):
        return len(self.imgs_path)


dataset = CMP_dataset(images_path, annos_path)
batchsize = 32
dataloader = data.DataLoader(dataset, batch_size=batchsize, shuffle=True)

# Preview: take one shuffled batch and draw up to three (label map, photo)
# pairs as a 3x2 grid. Tensors are normalized to [-1, 1], so (t + 1) / 2
# brings them back to [0, 1] for imshow.
annos_batch, images_batch = next(iter(dataloader))

preview_pairs = zip(annos_batch[:3], images_batch[:3])
for row, pair in enumerate(preview_pairs):
    for col, (panel_title, tensor) in enumerate(zip(('input_img', 'output_img'), pair)):
        plt.subplot(3, 2, row * 2 + col + 1)
        plt.title(panel_title)
        plt.imshow((tensor.permute(1, 2, 0).numpy() + 1) / 2)
plt.show()

# Downsampling block used by both the generator and the discriminator.
class Downsample(nn.Module):
    """Conv2d(k=3, stride=2, pad=1) -> LeakyReLU -> optional BatchNorm2d.

    The stride-2 convolution maps a spatial size H to ceil(H / 2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels,
                         kernel_size=3, stride=2, padding=1)
        self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Apply conv + LeakyReLU, then BatchNorm unless ``is_bn`` is False."""
        activated = self.conv_relu(x)
        return self.bn(activated) if is_bn else activated


# Upsampling block used by the generator's decoder.
class Upsample(nn.Module):
    """ConvTranspose2d(k=3, stride=2, pad=1, output_padding=1) -> LeakyReLU
    -> BatchNorm2d -> optional dropout2d.

    The transposed convolution maps a spatial size H to 2 * H.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                    kernel_size=3, stride=2, padding=1,
                                    output_padding=1)
        self.upconv_relu = nn.Sequential(deconv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Upsample, activate and normalize; apply channel dropout if ``is_drop``."""
        normalized = self.bn(self.upconv_relu(x))
        # F.dropout2d defaults to training=True, so when requested it is
        # applied regardless of the module's train/eval mode.
        return F.dropout2d(normalized) if is_drop else normalized


# Generator: U-Net style encoder/decoder with 6 downsampling blocks,
# 5 upsampling blocks and one output layer.
class Generator(nn.Module):
    """Map a 3-channel label map to a 3-channel image in [-1, 1] (tanh).

    Shape comments below assume a 3 x 256 x 256 input. Each decoder stage
    concatenates its output with the encoder feature map of the same
    resolution (skip connection), which is why the decoder input channel
    counts are doubled.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.down1 = Downsample(3, 64)      # 64 x 128 x 128
        self.down2 = Downsample(64, 128)    # 128 x 64 x 64
        self.down3 = Downsample(128, 256)   # 256 x 32 x 32
        self.down4 = Downsample(256, 512)   # 512 x 16 x 16
        self.down5 = Downsample(512, 512)   # 512 x 8 x 8
        self.down6 = Downsample(512, 512)   # 512 x 4 x 4
        # Decoder (input channels include the concatenated skip features)
        self.up1 = Upsample(512, 512)       # 512 x 8 x 8
        self.up2 = Upsample(1024, 512)      # 512 x 16 x 16
        self.up3 = Upsample(1024, 256)      # 256 x 32 x 32
        self.up4 = Upsample(512, 128)       # 128 x 64 x 64
        self.up5 = Upsample(256, 64)        # 64 x 128 x 128
        # Output layer: 128 (64 + 64 skip) -> 3 channels, 256 x 256
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)

    def forward(self, x):
        encoder = (self.down1, self.down2, self.down3,
                   self.down4, self.down5, self.down6)
        skips = []
        for down in encoder:
            x = down(x)
            skips.append(x)

        h = skips.pop()  # bottleneck (output of down6)
        # (block, use_dropout) pairs; dropout on every decoder stage but the last.
        decoder = ((self.up1, True), (self.up2, True), (self.up3, True),
                   (self.up4, True), (self.up5, False))
        for (up, use_drop), skip in zip(decoder, reversed(skips)):
            h = torch.cat([up(h, is_drop=use_drop), skip], dim=1)

        return torch.tanh(self.last(h))


# Discriminator: judges a (label map, image) pair.
class Discriminator(nn.Module):
    """Score a (label map, image) pair with a spatial map of probabilities.

    Both inputs are concatenated along the channel axis (3 + 3 = 6 channels).
    For 256 x 256 inputs the output is batch x 1 x 60 x 60, with sigmoid
    values in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.down1 = Downsample(6, 64)        # 64 x 128 x 128
        self.down2 = Downsample(64, 128)      # 128 x 64 x 64
        self.conv1 = nn.Conv2d(128, 256, 3)   # 256 x 62 x 62 (no padding)
        self.bn1 = nn.BatchNorm2d(256)
        self.conv2 = nn.Conv2d(256, 1, 3)     # 1 x 60 x 60

    def forward(self, anno, img):
        pair = torch.cat([anno, img], dim=1)  # batch x 6 x H x W
        h = self.down1(pair, is_bn=False)
        h = self.down2(h)
        h = self.bn1(F.leaky_relu(self.conv1(h)))
        h = F.dropout2d(h)
        return torch.sigmoid(self.conv2(h))


# Pick the GPU when one is available and report the chosen device.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'
if not has_cuda:
    print(device)
else:
    print('using cuda:', torch.cuda.get_device_name(0))

Gen = Generator().to(device)
Dis = Discriminator().to(device)

# Both networks share the same Adam settings.
adam_settings = dict(lr=1e-3, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(Dis.parameters(), **adam_settings)
g_optimizer = torch.optim.Adam(Gen.parameters(), **adam_settings)

# Adversarial (conditional GAN) loss: binary cross-entropy on the
# discriminator's sigmoid outputs. The L1 term is computed later in the
# training loop as the mean absolute difference between output and target.
loss_fn = torch.nn.BCELoss()
# Plotting helper.
def generator_images(model, test_anno, test_real):
    """Show input label map, ground truth and model output for sample 0.

    Args:
        model: generator mapping a batch of label maps to a batch of images.
        test_anno: batch of label maps; presumably (N, 3, H, W) in [-1, 1]
            as produced by the dataset transform.
        test_real: matching batch of ground-truth images, same layout.
    """
    # The whole batch is passed through the model, as before; only sample 0
    # is displayed.
    with torch.no_grad():
        prediction = model(test_anno)

    def to_display(t):
        # (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1]. imshow clips float
        # images outside [0, 1], which would otherwise black out every
        # negative value.
        arr = t.permute(1, 2, 0).detach().cpu().numpy()
        return ((arr + 1) / 2).clip(0, 1)

    plt.figure(figsize=(10, 10))
    display_list = [to_display(test_anno[0]),
                    to_display(test_real[0]),
                    to_display(prediction[0])]
    title = ['input', 'ground truth', 'output']
    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        plt.imshow(display_list[i])
        plt.axis('off')
    plt.show()

# Held-out data: the "extended" folder is used as the test set.
# Sorted for the same reason as the training lists: jpg/png files are
# paired by index, and glob() order is not guaranteed.
test_imgs_path = sorted(glob.glob(os.path.join('extended', '*.jpg')))
test_annos_path = sorted(glob.glob(os.path.join('extended', '*.png')))

test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
test_daloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batchsize
)

# Fixed batch used to visualise generator progress after every epoch. It is
# taken from the held-out test loader, so progress is shown on images the
# generator never trains on.
annos_batch, images_batch = next(iter(test_daloader))

plt.figure(figsize=(6, 10))
for i, (anno, img) in enumerate(zip(annos_batch[:3], images_batch[:3])):
    # [-1, 1] -> [0, 1] for display
    anno = (anno.permute(1, 2, 0).numpy()+1)/2
    img = (img.permute(1, 2, 0).numpy()+1)/2
    plt.subplot(3, 2, i*2+1)
    plt.title('input_img')
    plt.imshow(anno)

    plt.subplot(3, 2, i*2+2)
    plt.title('output_img')
    plt.imshow(img)
plt.show()

# Move the fixed visualisation batch to the training device once.
annos_batch, images_batch = annos_batch.to(device), images_batch.to(device)
LAMBDA = 7  # weight of the L1 reconstruction term relative to the adversarial loss

D_loss = []  # mean discriminator loss per epoch
G_loss = []  # mean generator loss per epoch
for epoch in range(300):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)
    for annos, imgs in dataloader:
        imgs = imgs.to(device)
        annos = annos.to(device)

        # ---- Discriminator step: real pairs -> 1, generated pairs -> 0 ----
        d_optimizer.zero_grad()
        disc_real_output = Dis(annos, imgs)  # real (label map, photo) pair
        d_real_loss = loss_fn(disc_real_output,
                              torch.ones_like(disc_real_output, device=device))
        d_real_loss.backward()

        gen_output = Gen(annos)
        # detach(): the discriminator loss must not backprop into the generator.
        dis_gen_output = Dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(dis_gen_output,
                              torch.zeros_like(dis_gen_output, device=device))
        d_fake_loss.backward()

        disc_loss = d_real_loss + d_fake_loss

        d_optimizer.step()

        # ---- Generator step: fool D and stay close to the target (L1) ----
        # Clear the generator's gradients first; without this they would
        # accumulate across every step and epoch.
        g_optimizer.zero_grad()
        disc_gen_out = Dis(annos, gen_output)
        gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                            torch.ones_like(disc_gen_out,
                                                            device=device))
        gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))
        gen_loss = LAMBDA * gen_l1_loss + gen_loss_crossentropyloss
        # This backward also writes gradients into Dis; they are discarded by
        # d_optimizer.zero_grad() at the start of the next step.
        gen_loss.backward()
        g_optimizer.step()

        D_epoch_loss += disc_loss.item()
        G_epoch_loss += gen_loss.item()
    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        print('Epoch', epoch)
        generator_images(Gen, annos_batch, images_batch)

给动漫素描自动上色(AI 上色)的例子请移步我的 Kaggle:
https://www.kaggle.com/code/jiyuanhai/pix2pix-test-pytorch

CycleGAN

这个很厉害👍,我愿称之为最强,它克服了 pix2pix 需要训练数据一一对应(成对)的缺点。
论文地址:https://arxiv.org/pdf/1703.10593.pdf
【推荐同济子豪兄】或者看论文详解:https://www.bilibili.com/video/BV1Ya411a78P?spm_id_from=333.999.0.0&vd_source=66d85dad339b02807124d27ef76332c9
B站也有很多讲得不错的视频。
CycleGAN 创新性地提出了循环一致性损失(cycle consistency loss),具体技术不多赘述了,有些复杂。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image
import itertools

# Domain A (apples) and domain B (oranges) training images. The two domains
# are unpaired, so no ordering/alignment between the lists is required.
# os.path.join keeps the patterns portable (a backslash is not a path
# separator on Linux/macOS). Assumes the usual trainA/trainB/testA/testB
# folder layout.
apples_path = glob.glob(os.path.join('trainA', '*.jpg'))

# Optional preview of a few images and their array shapes:
# plt.figure(figsize=(8, 8))
# for i, imh_path in enumerate(apples_path[:4]):
#     img = Image.open(imh_path)
#     np_image = np.array(img)
#     plt.subplot(2, 2, i+1)
#     plt.imshow(np_image)
#     plt.title(str(np_image.shape))
# plt.show()

oranges_path = glob.glob(os.path.join('trainB', '*.jpg'))

# plt.figure(figsize=(8, 8))
# for i, imh_path in enumerate(oranges_path[:4]):
#     img = Image.open(imh_path)
#     np_image = np.array(img)
#     plt.subplot(2, 2, i+1)
#     plt.imshow(np_image)
#     plt.title(str(np_image.shape))
# plt.show()

# Held-out apple images come from the testA split, not trainA, so the test
# set is disjoint from the training data.
apples_test_path = glob.glob(os.path.join('testA', '*.jpg'))

# The dataset images are already 256x256, so no resize/crop is needed.
# Normalize(0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])

class AO_Dataset(data.Dataset):
    """Single-domain image dataset (apples or oranges) for unpaired CycleGAN.

    Args:
        img_path: list of image file paths.

    Each item is one image tensor produced by the module-level ``transform``.
    """

    def __init__(self, img_path):
        self.img_path = img_path

    def __getitem__(self, index):
        imgpath = self.img_path[index]
        # Convert to RGB so every sample has 3 channels. A grayscale image
        # would otherwise give a 1-channel tensor and make the DataLoader
        # fail when stacking a batch. `with` closes the file handle.
        with Image.open(imgpath) as pil_img:
            return transform(pil_img.convert('RGB'))

    def __len__(self):
        return len(self.img_path)


# One dataset per domain, plus the held-out apple images.
apple_dataset, orange_dataset, apple_test_dataset = (
    AO_Dataset(paths)
    for paths in (apples_path, oranges_path, apples_test_path)
)

# Loader settings: batch size and number of worker processes (presumably
# passed to the DataLoaders built next).
BATHSIZE = 2
NUMWORKERS = 10

apple_dataloader = data……(原文代码到此被截断,后续的 CycleGAN 代码缺失)

GAN-生成对抗神经网络(Pytorch)-合集GAN-DCGAN-CGAN

GAN-生成对抗神经网络(Pytorch)-合集GAN-DCGAN-CGAN

GAN-生成对抗网络(Pytorch)合集--pixtopix-CycleGAN

GAN-生成对抗网络(Pytorch)合集--pixtopix-CycleGAN

PyTorch实现简单的生成对抗网络GAN

Pytorch Note45 生成对抗网络(GAN)