Deep Learning Series 42: The Multimodal ruDALL-E Generative Model

Posted IE06



1. Introduction to ruDALL-E

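ruDALL-E is a Russian implementation of the DALL-E model ("ru" stands for Russia). Prompts must therefore be written in Russian. Even so, it is the most convenient open-source DALL-E code to use in practice. Install it with: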

pip install rudalle==1.0.0
pip install ruclip==0.0.1rc7

Four model families are currently available:
ruDALL-E Malevich XL and Malevich_v2 XL: medium-sized image models
ruDALL-E Surrealist XL: a surrealist image model
ruDALL-E Emojich XL: an emoji image model
ruDALL-E Kandinsky XXL: a large image model

2. Text-to-Image Generation

If you run on CPU, use device='cpu' and fp16=False:

import ruclip
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_ruclip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan
from rudalle.utils import seed_everything

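import torch

# prepare models (fp16 only on GPU; on CPU fall back to fp16=False):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=(device == 'cuda'), device=device)
tokenizer = get_tokenizer()
vae = get_vae(dwt=True).to(device)
realesrgan = get_realesrgan('x2', device=device)  # 2x super-resolution model
clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device=device)  # ruCLIP, used for ranking
clip_predictor = ruclip.Predictor(clip, processor, device, bs=8)
text = 'радуга на фоне ночного города'  # change the prompt here ("a rainbow against the backdrop of a night city")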
seed_everything(42)
pil_images = []
scores = []
for top_k, top_p, images_num in [(2048, 0.995, 24),]:
    _pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, bs=8, top_p=top_p)
    pil_images += _pil_images
    scores += _scores

show(pil_images, 6)
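The top_k and top_p arguments of generate_images control how each image token is sampled. The sketch below shows the generic top-k + nucleus (top-p) filtering technique using only the standard library. It assumes ruDALL-E follows the usual approach of filtering the next-token distribution before sampling. The function names are hypothetical and this is not rudalle's actual code. It also illustrates why seed_everything(42) makes runs repeatable: the same seed yields the same sequence of draws.

```python
import math
import random


def top_k_top_p_filter(logits, top_k, top_p):
    """Return a probability list after top-k and nucleus (top-p) filtering.
    Tokens that are filtered out get probability 0; the rest are renormalised."""
    # Softmax over the raw logits (subtract the max for numerical stability).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids sorted by probability, highest first, truncated to top_k.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    # Nucleus step: keep the smallest prefix whose cumulative mass reaches top_p.
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered


def sample_token(logits, top_k, top_p, rng):
    """Draw one token id from the filtered distribution with the given RNG."""
    probs = top_k_top_p_filter(logits, top_k, top_p)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

In this toy example, logits = [2.0, 1.0, 0.5, -1.0, -3.0] with top_k=3 and top_p=0.8 leave only tokens 0 and 1. The top two tokens already cover about 83% of the probability mass. A larger top_k/top_p (such as the 2048/0.995 used above) keeps more candidates and produces more varied images.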

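Alternatively, you can write the prompt in English and translate it into Russian automatically:
pip install translators==4.9.5

import translators

text = 'a rainbow against the backdrop of a night city'  # English prompt
text = translators.google(text, from_language='en', to_language='ru')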

Code for picking the best images (ranked by ruCLIP):

top_images, clip_scores = cherry_pick_by_ruclip(pil_images, text, clip_predictor, count=6)
show(top_images, 3)
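cherry_pick_by_ruclip keeps the count candidates that ruCLIP judges most consistent with the prompt. The sketch below assumes it works like standard CLIP ranking, scoring each image by the cosine similarity between its embedding and the text embedding. The embeddings are toy vectors and the helper names are hypothetical, not part of rudalle or ruclip.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def cherry_pick(images, image_embeddings, text_embedding, count):
    """Rank candidates by text-image similarity and keep the best `count`."""
    scores = [cosine(e, text_embedding) for e in image_embeddings]
    best = sorted(range(len(images)), key=lambda i: scores[i], reverse=True)[:count]
    return [images[i] for i in best], [scores[i] for i in best]


images = ['img0', 'img1', 'img2', 'img3']
embeddings = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]]
top, scores = cherry_pick(images, embeddings, text_embedding=[1.0, 0.1], count=2)
print(top)  # ['img0', 'img2']
```

Generating many candidates (24 above) and keeping a handful (6) is cheap insurance: the generator's own sampling is noisy, and the separate ranking model filters out images that drift away from the prompt.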

Super-resolution code:

sr_images = super_resolution(top_images, realesrgan)
show(sr_images, 3)

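After some testing, a few problems are quite obvious:

  1. Portraits are far from perfect: limbs and faces are usually distorted
  2. The model sometimes generates watermarks on its own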

3. Image Prompts

The following four sky photos are used as prompts to generate new images.

The code:

from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle.image_prompts import ImagePrompts
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle.utils import seed_everything
device = 'cuda'
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
realesrgan = get_realesrgan('x4', device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
ruclip = ruclip.to(device)

import requests
from PIL import Image
import torch

red_sky_url = 'https://azur.ru/data/newfotos1/big/8/97908.jpg'
sunny_sky_url = 'https://99px.ru/sstorage/53/2016/04/mid_162114_1196.png'
cloudy_sky_url = 'https://vesti-lipetsk.ru/images/news/2020/01/18/bxcz2tot4t0.jpg'
night_sky_url = 'https://i.pinimg.com/originals/3c/51/82/3c5182ee0773333a3e0dd67d5ac41598.jpg'

red_sky = Image.open(requests.get(red_sky_url, stream=True).raw).resize((256, 256))
sunny_sky = Image.open(requests.get(sunny_sky_url, stream=True).raw).resize((256, 256))
cloudy_sky = Image.open(requests.get(cloudy_sky_url, stream=True).raw).resize((256, 256))
night_sky = Image.open(requests.get(night_sky_url, stream=True).raw).resize((256, 256))

skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]
# keep the top 4 rows of image tokens from each sky photo as the prompt; the rest is generated
borders = {'up': 4, 'left': 0, 'right': 0, 'down': 0}
image_prompts = [ImagePrompts(sky, borders, vae, device, crop_first=True) for sky in skyes]
text = 'Храм Василия Блаженного'  # "St. Basil's Cathedral"
all_skyes_images = []

for image_prompt in image_prompts:
    seed_everything(42)
    pil_images = []
    for top_k, top_p, images_num in [
        (2048, 0.995, 3),
        (1536, 0.99, 3),
        (1024, 0.99, 3),
        (1024, 0.98, 3),
        (512, 0.97, 3),
        (384, 0.96, 3),
        (256, 0.95, 3),
        (128, 0.95, 3), 
    ]:
        _pil_images, _ = generate_images(
            text,
            tokenizer,
            dalle,
            vae,
            top_k=top_k,
            images_num=images_num,
            image_prompts=image_prompt,
            top_p=top_p,
            use_cache=False
        )
        pil_images += _pil_images
    top_images, _ = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=5)
    all_skyes_images += super_resolution(top_images, realesrgan)
    
show(all_skyes_images, 5)

Results:
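The borders dictionary decides which part of each prompt photo is kept. In the code above, that is a strip of sky at the top, and St. Basil's Cathedral is generated below it. The sketch below assumes the 256×256 image is encoded by the VAE into a 32×32 grid of tokens (8 px per token) and that borders are counted in token rows and columns. prompt_mask is a hypothetical helper that reproduces that mask in plain Python, not rudalle's implementation.

```python
def prompt_mask(borders, grid=32):
    """Return a grid x grid boolean mask; True marks token positions copied
    from the prompt image instead of being generated by the model."""
    mask = [[False] * grid for _ in range(grid)]
    for row in range(grid):
        for col in range(grid):
            if (row < borders['up'] or row >= grid - borders['down']
                    or col < borders['left'] or col >= grid - borders['right']):
                mask[row][col] = True
    return mask


borders = {'up': 4, 'left': 0, 'right': 0, 'down': 0}
mask = prompt_mask(borders)
fixed = sum(sum(row) for row in mask)
print(fixed)  # 128 = 4 rows x 32 columns, i.e. the top 32 px of a 256 px image
```

With 'up': 4, only 128 of the 1024 image positions come from the photo. This is enough to carry over the sky's colour and lighting while leaving the model free to draw the building.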

4. Wide-Aspect-Ratio Image Generation

Download the code: git clone https://github.com/shonenkov-AI/rudalle-aspect-ratio
Usage:

import sys
sys.path.insert(0, './rudalle-aspect-ratio')
from rudalle_aspect_ratio import RuDalleAspectRatio, get_rudalle_model
from rudalle import get_vae, get_tokenizer
from rudalle.pipelines import show

device = 'cuda'
dalle = get_rudalle_model('Surrealist_XL', fp16=True, device=device)
vae, tokenizer = get_vae().to(device), get_tokenizer()
rudalle_ar = RuDalleAspectRatio(
    dalle=dalle, vae=vae, tokenizer=tokenizer,
    aspect_ratio=32/9, bs=4, device=device
)
_, result_pil_images = rudalle_ar.generate_images('готический квартал', 1024, 0.975, 4)  # "Gothic quarter"
show(result_pil_images, 1)

Result:

For portrait (vertical) images, just change the aspect ratio:

rudalle_ar = RuDalleAspectRatio(
    dalle=dalle, vae=vae, tokenizer=tokenizer,
    aspect_ratio=9/32, bs=4, device=device
)
_, result_pil_images = rudalle_ar.generate_images('голубой цветок', 512, 0.975, 4)  # "blue flower"
show(result_pil_images, 4)

Results:
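The following sketch shows only how aspect_ratio translates into canvas size. It assumes the shorter side stays at ruDALL-E's native 32 tokens (256 px at 8 px per token) and that the longer side scales with aspect_ratio (width / height). token_grid is hypothetical arithmetic, not code from the rudalle-aspect-ratio repository, which may round or crop differently.

```python
def token_grid(aspect_ratio, short_side=32, px_per_token=8):
    """Hypothetical: (height, width) in tokens and in pixels for a width/height
    aspect ratio, keeping the shorter side at `short_side` tokens."""
    if aspect_ratio >= 1:
        h, w = short_side, round(short_side * aspect_ratio)
    else:
        h, w = round(short_side / aspect_ratio), short_side
    return (h, w), (h * px_per_token, w * px_per_token)


print(token_grid(32 / 9))  # ((32, 114), (256, 912))  landscape
print(token_grid(9 / 32))  # ((114, 32), (912, 256))  portrait
```

Under these assumptions, a 32/9 canvas needs roughly 3.5 times as many image tokens as a square one. This is why such images are produced in several overlapping passes rather than in a single 32×32 generation.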

5. Emoji Images

For emoji, just change the model name to Emojich:

from rudalle.pipelines import generate_images, show
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.utils import seed_everything

device = 'cuda'
dalle = get_rudalle_model('Emojich', pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae(dwt=True).to(device)

text = 'Дональд Трамп из лего'  # Donald Trump made of LEGO

seed_everything(42)
pil_images = []
for top_k, top_p, images_num in [
    (2048, 0.995, 16),
]:
    pil_images += generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p, bs=8)[0]

show(pil_images, 4)

Some generated examples:
