GAN-生成对抗网络(Pytorch)合集--pixtopix-CycleGAN
Posted 挂科难
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN-生成对抗网络(Pytorch)合集--pixtopix-CycleGAN相关的知识,希望对你有一定的参考价值。
pixtopix(像素到像素)
原文链接:https://arxiv.org/pdf/1611.07004.pdf
输入一个域的图片转换为另一个域的图片(白天照片转成黑夜)
如下图,输入标记图片,输出真实图片。缺点是训练集中两个域的图片必须一一对应,所以叫pixtopix(pix2pix)。
网络结构有点复杂,用到了语义分割的UNET网络结构
数据集:
地址忘了,也是官方的,想起来补
代码:这里是建筑物labels to facade的例子
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image
# *.jpg files are the original photos; *.png files are the segmentation maps.
# glob.glob returns matches in arbitrary order, so both lists are sorted to
# keep the photo at index i next to the label map at index i.
# NOTE(review): this assumes each photo/label pair shares a base file name.
images_path = sorted(glob.glob(r'base\\*.jpg'))
annos_path = sorted(glob.glob(r'base\\*.png'))
# Preprocessing applied to every image: convert to a tensor, resize to
# 256x256, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(256, 256)),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)
class CMP_dataset(data.Dataset):
    """Paired dataset of label maps and photos.

    Item ``i`` is built from ``annos_path[i]`` and ``imgs_path[i]``; both
    files are converted to RGB and passed through the module-level
    ``transform``.

    Args:
        imgs_path: sequence of photo file paths.
        annos_path: sequence of label-map file paths, index-aligned with
            ``imgs_path``.
    """

    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, item):
        """Return ``(annotation_tensor, image_tensor)`` for index ``item``."""
        # Force three channels on both files so every sample has the same
        # channel count; the context managers close the file handles as soon
        # as the pixel data has been converted.
        with Image.open(self.imgs_path[item]) as photo:
            pil_img = transform(photo.convert('RGB'))
        with Image.open(self.annos_path[item]) as label_map:
            pil_anno = transform(label_map.convert('RGB'))
        return pil_anno, pil_img

    def __len__(self):
        """Return the number of complete (annotation, image) pairs."""
        # Only indices present in both lists are valid; using the shorter
        # length avoids an IndexError when the lists differ in size.
        return min(len(self.imgs_path), len(self.annos_path))
# Build the training dataset and a shuffling loader, then preview up to three
# (label map, photo) pairs taken from one batch.
batchsize = 32
dataset = CMP_dataset(images_path, annos_path)
dataloader = data.DataLoader(dataset, shuffle=True, batch_size=batchsize)

annos_batch, images_batch = next(iter(dataloader))
for i, (anno, img) in enumerate(zip(annos_batch[:3], images_batch[:3])):
    # CHW -> HWC, and undo the (x - 0.5) / 0.5 normalization for display.
    anno = (anno.permute(1, 2, 0).numpy() + 1) / 2
    img = (img.permute(1, 2, 0).numpy() + 1) / 2
    panels = (('input_img', anno), ('output_img', img))
    for column, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(3, 2, i * 2 + column)
        plt.title(caption)
        plt.imshow(picture)
plt.show()
# Downsampling block
class Downsample(nn.Module):
    """Stride-2 3x3 convolution + LeakyReLU, optionally followed by batch norm.

    With kernel 3, stride 2 and padding 1 the spatial size is halved for
    even-sized inputs.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Apply conv + LeakyReLU; batch norm is skipped when ``is_bn`` is false."""
        activated = self.conv_relu(x)
        return self.bn(activated) if is_bn else activated
# Upsampling block
class Upsample(nn.Module):
    """Stride-2 3x3 transposed convolution + LeakyReLU + batch norm.

    With kernel 3, stride 2, padding 1 and output_padding 1 the spatial size
    is doubled. Dropout is applied on request.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=3, stride=2, padding=1,
                               output_padding=1),
            nn.LeakyReLU(inplace=True),
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Upsample ``x``; apply channel-wise dropout when ``is_drop`` is true."""
        normalized = self.bn(self.upconv_relu(x))
        if not is_drop:
            return normalized
        # F.dropout2d is called with its default training=True, so dropout is
        # applied here regardless of the module's train/eval mode.
        return F.dropout2d(normalized)
# Generator: 6 downsampling blocks, 5 upsampling blocks and 1 output layer
class Generator(nn.Module):
    """U-Net style encoder/decoder mapping a 3-channel image to a 3-channel image.

    Each decoder stage after the first receives the previous decoder output
    concatenated with the encoder feature map of the same resolution, which
    is why the decoder input widths are twice the matching encoder widths.
    The shape comments below are for a 3x256x256 input.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.down1 = Downsample(3, 64)       # 64 x 128 x 128
        self.down2 = Downsample(64, 128)     # 128 x 64 x 64
        self.down3 = Downsample(128, 256)    # 256 x 32 x 32
        self.down4 = Downsample(256, 512)    # 512 x 16 x 16
        self.down5 = Downsample(512, 512)    # 512 x 8 x 8
        self.down6 = Downsample(512, 512)    # 512 x 4 x 4
        # Decoder (input channels include the concatenated skip connection)
        self.up1 = Upsample(512, 512)        # 512 x 8 x 8
        self.up2 = Upsample(1024, 512)       # 512 x 16 x 16
        self.up3 = Upsample(1024, 256)       # 256 x 32 x 32
        self.up4 = Upsample(512, 128)        # 128 x 64 x 64
        self.up5 = Upsample(256, 64)         # 64 x 128 x 128
        # Output layer: back to 3 channels at double the resolution
        self.last = nn.ConvTranspose2d(
            128, 3, kernel_size=3, stride=2, padding=1, output_padding=1
        )

    def forward(self, x):
        """Return the generated image, squashed to [-1, 1] by tanh."""
        skip1 = self.down1(x)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        skip4 = self.down4(skip3)
        skip5 = self.down5(skip4)
        bottleneck = self.down6(skip5)

        # The first four decoder stages use dropout; the fifth does not.
        out = self.up1(bottleneck, is_drop=True)
        out = self.up2(torch.cat([out, skip5], dim=1), is_drop=True)
        out = self.up3(torch.cat([out, skip4], dim=1), is_drop=True)
        out = self.up4(torch.cat([out, skip3], dim=1), is_drop=True)
        out = self.up5(torch.cat([out, skip2], dim=1))
        out = torch.cat([out, skip1], dim=1)
        return torch.tanh(self.last(out))
# Discriminator: takes the annotation and an image as a pair
class Discriminator(nn.Module):
    """Conditional discriminator producing a map of per-patch probabilities.

    The annotation and the image are concatenated along the channel axis
    (3 + 3 = 6 channels) before being scored.
    """

    def __init__(self):
        super().__init__()
        self.down1 = Downsample(6, 64)      # 64 x 128 x 128 for 256x256 input
        self.down2 = Downsample(64, 128)    # 128 x 64 x 64
        self.conv1 = nn.Conv2d(128, 256, 3)
        self.bn1 = nn.BatchNorm2d(256)
        self.conv2 = nn.Conv2d(256, 1, 3)

    def forward(self, anno, img):
        """Return sigmoid scores for the ``(anno, img)`` pair."""
        pair = torch.cat([anno, img], dim=1)  # batch x 6 x H x W
        features = self.down1(pair, is_bn=False)
        features = self.down2(features)
        features = self.bn1(F.leaky_relu(self.conv1(features)))
        # F.dropout2d defaults to training=True, so it is always active here.
        features = F.dropout2d(features)
        # Two unpadded 3x3 convolutions shrink 64x64 to 60x60,
        # giving batch x 1 x 60 x 60 for 256x256 inputs.
        return torch.sigmoid(self.conv2(features))
# Pick the compute device and report which one is used.
if torch.cuda.is_available():
    device = 'cuda'
    print('using cuda:', torch.cuda.get_device_name(0))
else:
    device = 'cpu'
    print(device)

# Models (generator first, then discriminator) on the chosen device.
Gen = Generator().to(device)
Dis = Discriminator().to(device)

# One Adam optimizer per network, with identical hyper-parameters.
d_optimizer = torch.optim.Adam(
    Dis.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)
g_optimizer = torch.optim.Adam(
    Gen.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)

# Adversarial (conditional GAN) loss: binary cross-entropy on the
# discriminator's sigmoid outputs.
loss_fn = torch.nn.BCELoss()
# The L1 term is computed inline in the training loop as the mean absolute
# difference between the generated and the real image.
# Plotting helper
def generator_images(model, test_anno, test_real):
    """Show the input, ground truth and generated image for the first sample.

    Args:
        model: generator; called as ``model(test_anno)``.
        test_anno: batch of annotation tensors in NCHW layout.
        test_real: batch of ground-truth image tensors in NCHW layout.

    The data in this file is normalized with mean 0.5 / std 0.5 and the
    generator ends in tanh, so values lie in [-1, 1]. They are mapped back to
    [0, 1] before display, because ``imshow`` clips float RGB data outside
    that range.
    """

    def _to_display(batch):
        # NCHW tensor in [-1, 1] -> NHWC numpy array in [0, 1].
        array = batch.permute(0, 2, 3, 1).detach().cpu().numpy()
        return (array + 1) / 2

    # No gradients are needed for a preview; skip building the autograd graph.
    with torch.no_grad():
        prediction = _to_display(model(test_anno))
    test_anno = _to_display(test_anno)
    test_real = _to_display(test_real)

    plt.figure(figsize=(10, 10))
    display_list = [test_anno[0], test_real[0], prediction[0]]
    title = ['input', 'ground truth', 'output']
    for position, (caption, picture) in enumerate(zip(title, display_list),
                                                  start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        plt.imshow(picture)
        plt.axis('off')
    plt.show()
# Load the "extended" folder as the test split.
# glob.glob returns matches in arbitrary order, so both lists are sorted to
# keep photo i paired with label map i.
test_imgs_path = sorted(glob.glob('extended/*.jpg'))
test_annos_path = sorted(glob.glob('extended/*.png'))
test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
# NOTE: the name `test_daloader` (sic) is kept unchanged for compatibility.
test_daloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batchsize
)
# Fetch one fixed batch for the per-epoch preview. It comes from the test
# loader (the original read from the training loader, leaving the test split
# unused); fall back to the training loader when no test files were found.
preview_loader = test_daloader if len(test_dataset) > 0 else dataloader
annos_batch, images_batch = next(iter(preview_loader))
plt.figure(figsize=(6, 10))
for i, (anno, img) in enumerate(zip(annos_batch[:3], images_batch[:3])):
    # CHW -> HWC, and map [-1, 1] back to [0, 1] for display.
    anno = (anno.permute(1, 2, 0).numpy() + 1) / 2
    img = (img.permute(1, 2, 0).numpy() + 1) / 2
    plt.subplot(3, 2, i * 2 + 1)
    plt.title('input_img')
    plt.imshow(anno)
    plt.subplot(3, 2, i * 2 + 2)
    plt.title('output_img')
    plt.imshow(img)
plt.show()
annos_batch, images_batch = annos_batch.to(device), images_batch.to(device)
LAMBDA = 7  # weight of the L1 term in the generator loss
D_loss = []  # mean discriminator loss per epoch
G_loss = []  # mean generator loss per epoch
count = len(dataloader)  # batches per epoch; constant across epochs
for epoch in range(300):
    D_epoch_loss = 0
    G_epoch_loss = 0
    for annos, imgs in dataloader:
        imgs = imgs.to(device)
        annos = annos.to(device)

        # ---- Discriminator step ----
        d_optimizer.zero_grad()
        # Real pair: (annotation, real image) should be scored as 1.
        disc_real_output = Dis(annos, imgs)
        d_real_loss = loss_fn(disc_real_output,
                              torch.ones_like(disc_real_output))
        d_real_loss.backward()
        # Fake pair: (annotation, generated image) should be scored as 0.
        # detach() keeps this loss from propagating into the generator.
        gen_output = Gen(annos)
        dis_gen_output = Dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(dis_gen_output,
                              torch.zeros_like(dis_gen_output))
        d_fake_loss.backward()
        disc_loss = d_real_loss + d_fake_loss
        d_optimizer.step()

        # ---- Generator step ----
        # Gradients accumulate across backward() calls, so they must be
        # cleared before each generator update; without this, every step
        # would apply the sum of all gradients computed so far.
        g_optimizer.zero_grad()
        disc_gen_out = Dis(annos, gen_output)
        # The generator wants the discriminator to score its output as 1.
        gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                            torch.ones_like(disc_gen_out))
        # Mean absolute difference between generated and real image.
        gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))
        gen_loss = LAMBDA * gen_l1_loss + gen_loss_crossentropyloss
        gen_loss.backward()
        g_optimizer.step()

        # .item() returns plain Python floats, so no autograd guard is needed.
        D_epoch_loss += disc_loss.item()
        G_epoch_loss += gen_loss.item()

    # Average over the batches of this epoch, then log and preview.
    D_epoch_loss /= count
    G_epoch_loss /= count
    D_loss.append(D_epoch_loss)
    G_loss.append(G_epoch_loss)
    print('Epoch', epoch)
    generator_images(Gen, annos_batch, images_batch)
给动漫素描自动上色的(AI上色)移步我的kaggle
https://www.kaggle.com/code/jiyuanhai/pix2pix-test-pytorch
CycleGAN
这个厉害👍,我愿称之为最强,克服了pixtopix需要数据集一一对应的缺点
论文地址:https://arxiv.org/pdf/1703.10593.pdf
【推荐同济子豪兄】或者看论文详解:https://www.bilibili.com/video/BV1Ya411a78P?spm_id_from=333.999.0.0&vd_source=66d85dad339b02807124d27ef76332c9
B站也有很多讲的不错的视频。
创新性地提出了循环一致性损失,具体技术细节比较复杂,这里不多赘述。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image
import itertools
# Training images of domain A (apples).
apples_path = glob.glob(r'E:\\深度之眼深度学习\\GAN生成对抗网络实战(PyTorch版)\\00、代码+课件+数据\\cyclegan-pytorch参考代码—日月光华\\data\\trainA\\*.jpg')
# Optional preview of the first four images (disabled):
# plt.figure(figsize=(8, 8))
# for i, imh_path in enumerate(apples_path[:4]):
#     img = Image.open(imh_path)
#     np_image = np.array(img)
#     plt.subplot(2, 2, i+1)
#     plt.imshow(np_image)
#     plt.title(str(np_image.shape))
# plt.show()
# Training images of domain B (oranges).
oranges_path = glob.glob(r'E:\\深度之眼深度学习\\GAN生成对抗网络实战(PyTorch版)\\00、代码+课件+数据\\cyclegan-pytorch参考代码—日月光华\\data\\trainB\\*.jpg')
# plt.figure(figsize=(8, 8))
# for i, imh_path in enumerate(oranges_path[:4]):
#     img = Image.open(imh_path)
#     np_image = np.array(img)
#     plt.subplot(2, 2, i+1)
#     plt.imshow(np_image)
#     plt.title(str(np_image.shape))
# plt.show()
# Test images of domain A. The original pattern pointed at trainA again, which
# made the "test" list identical to the training list.
# NOTE(review): assumes the standard CycleGAN layout with a testA folder next
# to trainA; confirm the folder exists on disk.
apples_test_path = glob.glob(r'E:\\深度之眼深度学习\\GAN生成对抗网络实战(PyTorch版)\\00、代码+课件+数据\\cyclegan-pytorch参考代码—日月光华\\data\\testA\\*.jpg')
# The dataset images are already 256x256, so no resizing or cropping is done;
# images are only converted to tensors and normalized with mean 0.5, std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)
class AO_Dataset(data.Dataset):
    """Dataset of single images from one domain.

    Each item is the image at ``img_path[index]``, converted to RGB and passed
    through the module-level ``transform``.

    Args:
        img_path: sequence of image file paths.
    """

    def __init__(self, img_path):
        self.img_path = img_path

    def __getitem__(self, index):
        """Return the transformed image tensor for ``index``."""
        # Force three channels so every sample has the same channel count and
        # batches can be stacked; the context manager closes the file handle
        # once the pixel data has been converted.
        with Image.open(self.img_path[index]) as picture:
            pil_img = transform(picture.convert('RGB'))
        return pil_img

    def __len__(self):
        """Return the number of images."""
        return len(self.img_path)
apple_dataset = AO_Dataset(apples_path)
# NOTE(review): the source listing is cut off in the middle of the next
# statement; only the missing closing parenthesis has been restored.
orange_dataset = AO_Dataset(oranges_path)
GAN-生成对抗神经网络(Pytorch)-合集GAN-DCGAN-CGAN
GAN-生成对抗网络(Pytorch)合集--pixtopix-CycleGAN