深度学习系列41:多模态Dalle-min生成图像
Posted IE06
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习系列41:多模态Dalle-min生成图像相关的知识,希望对你有一定的参考价值。
1. dalle-min模型介绍
参考https://huggingface.co/flax-community/dalle-mini,可以用这个版本进行探索和学习。
dalle模型包括:
- 一个基于BART的编码器,将文本token转为图像token
- 一个基于VQGAN模型的编解码器,将图像token和图片之间互相转换
首先要训练VAGAN模型。开源的模型对于人脸重构效果不佳,期待有人做优化训练;此外还需要一个预训练好的BART模型。
训练模型包括如下几个部分:
1)将图片用VQGAN的编码器转为图像token
2)将文字用BART的编码器转为文字token
3)两者拼接后用BART的解码器转为图像toke
4)与第一步的图像token计算交叉熵,进行优化
2. 推理过程
使用训练好的BART模型将文字转为图片token,然后用训练的VQGAN模型解码器生成图片,然后用CLIP模型挑出最优的K张图片。
3. 如何使用
这里是在线测试地址:https://huggingface.co/spaces/dalle-mini/dalle-mini
这里是git地址:https://github.com/borisdayma/dalle-mini
首先要安装库:
pip install -q dalle-mini jax
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
然后执行下面的代码
import jax
import jax.numpy as jnp
# dalle模型,资源不够的话可以用下面那个
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None
# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel
# Load dalle-mini
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)
from dalle_mini import DalleBartProcessor
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
prompts = ["sunset over a lake in the mountains", "the Eiffel tower landing on the moon"]
tokenized_prompts = processor(prompts)
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange
print(f"Prompts: prompts\\n")
# generate images
images = []
n_predictions = 8
for i in trange(n_predictions):
# generate images
encoded_images = model.generate(**tokenized_prompts,params=params,condition_scale=10.0)
# remove BOS
encoded_images = encoded_images.sequences[..., 1:]
# decode images
decoded_images = vqgan.decode_code(encoded_images, params=vqgan_params)
decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
for decoded_img in decoded_images:
img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
images.append(img)
display(img)
print()
到这里为止已经可以生成一系列图片了
接下来用clip来评分:
# CLIP model
CLIP_REPO = "openai/clip-vit-base-patch32"
CLIP_COMMIT_ID = None
# Load CLIP
clip, clip_params = FlaxCLIPModel.from_pretrained(
CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)
clip_params = replicate(clip_params)
# score images
@partial(jax.pmap, axis_name="batch")
def p_clip(inputs, params):
logits = clip(params=params, **inputs).logits_per_image
return logits
from flax.training.common_utils import shard
# get clip scores
clip_inputs = clip_processor(
text=prompts * jax.device_count(),
images=images,
return_tensors="np",
padding="max_length",
max_length=77,
truncation=True,
).data
logits = p_clip(shard(clip_inputs), clip_params)
# organize scores per prompt
p = len(prompts)
logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()
#logits = rearrange(logits, '1 b p -> p b')
for i, prompt in enumerate(prompts):
print(f"Prompt: prompt\\n")
for idx in logits[i].argsort()[::-1]:
display(images[idx*p+i])
print(f"Score: jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f\\n")
print()
以上是关于深度学习系列41:多模态Dalle-min生成图像的主要内容,如果未能解决你的问题,请参考以下文章