深度学习系列39:Imagen模型

Posted IE06

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习系列39:Imagen模型相关的知识,希望对你有一定的参考价值。

Dalle2的SOTA被google家的Imagen模型给破了。

1. 模型介绍

模型相当简单,使用了一个文字转图片的diffusion模型,然后使用了2个超分diffusion模型:

2. 安装与训练

安装:pip install imagen-pytorch

2.1 构建模型

import torch
from imagen_pytorch import Unet, Imagen

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    beta_schedules = ('cosine', 'linear'),
    timesteps = 1000,
    cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
text_masks = torch.ones(4, 256).bool().cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = imagen(images, text_embeds = text_embeds, text_masks = text_masks, unet_number = i)
    loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 2.)

images.shape # (3, 3, 256, 256)

以上是关于深度学习系列39:Imagen模型的主要内容,如果未能解决你的问题,请参考以下文章

Keras深度学习实战(39)——音乐音频分类

24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)

深度学习系列卷积神经网络模型(ResNetResNeXtDenseNetDenceUnet)

深度学习系列分割网络模型(FCNUnetUnet++SegNetRefineNet)

深度学习系列经典博客收藏

深度学习系列42:多模态ruDalle生成模型