DALL-E是OpenAI的多模态与训练模型,有120亿参数,在2.5亿图像文本对上寻来呢完成,主要用于文生图后续有DALL-E2和DALL-E3,其中DALL-E和DALL-E2是可以找到源代码和权重的,后面的DALL-E3是闭源的。学习原理的话我们从DALL-E入手。
从GPT开始,zero-shot的想法开始成为自然语言处理的主流,而视觉领域能否zero-shot呢(训练一个大模型,后续使用不需要微调)?CLIP和DALL-E告诉我们视觉也可以zero-shot,所以文章的Tile是Zero-Shot Image Generation,DALL-E的名字灵感来自于著名画家Salvador Dali和墨西哥的墙Wall-E。
参考链接:
Paper:https://arxiv.org/pdf/2102.12092.pdf
Code: https://github.com/lucidrains/DALLE-pytorch
(DALL-E2: https://github.com/lucidrains/dalle2-pytorch)
OpenAI Blog:https://openai.com/research/dall-e
(DALL-E2:https://openai.com/dall-e-2)
(DALL-E3:https://openai.com/dall-e-3)
Explain Video: https://www.youtube.com/watch?v=j4xgkjWlfL4
知乎:
https://zhuanlan.zhihu.com/p/625975291
https://zhuanlan.zhihu.com/p/604902250
结构
DALL-E是基于变分自编码器(dVAE)和Transformer架构的生成模型。它是一个两阶段的模型:它的第一个阶段是离散变分自编码器,用于生成图像token,它的第二个阶段时混合了图像和文本特征的,1以transformer为基础的生成模型。
它首先使用一个VAE将图像编码为离散的潜在表示,然后使用一个大型Transformer模型学习从自然语言描述道这些离散潜在表示的映射。训练完成后,DALL-E可以根据输入的文本描述生成一组与描述相符的图像。
dVAE
dVAE的作用是将图像信息压缩,减少输入size(提取为特征token图),方便与语音token一同处理。DALL-E把文本token和图像token当成一个数据序列,通过Transformer进行自回归。如果把一个pixel当成一个token来处理,会导致计算过于庞大,于是DALL-E引入dVAE模型来提取图像信息导导编码之后的特征token表示。
整体流程
第一阶段,先训练一个dVAE把每张256×256的RGB图片压缩成32×32的图片token,每个位置有8192种可能的取值(dVAE encoder输出的toker取值唯独为32x32x8192 logits),然后通过logits索引codebook进行特征组合,codebook的embedding是可学习的。
第二阶段,用BPE Encoder对文本进行编码,得到最多256个文本token,token数不满256的话padding到256,然后将256个文本token和1024个图像token进行拼接,得到长度为1280的数据,最后将拼接的数据输入Transformer中进行自回归训练,典型的teacher forcing 划窗方式进行生成
训练阶段,先训练dVAE模型,然后固定dVAE模型再来训练自回归的Transformer
推理阶段,使用BPE编码器,可以根据文本输入生成很多图像code串,然后使用dVAE的decoder生成很多可选的256×256大小的图像
rerank阶段,通过输入不同的首个图像的token可以生成很多各种类型的图片,需要根据CLIP来对得到的图文进行重排
以上流程中,dVAE、Transformer和CLIP三个模型都是不同阶段独立训练的。
关键算法
目标函数:
DALL-E的目标函数是最小化损失函数L,其中L是重构损失 L_r 和 KL散度Loss 的加权和
L=L_r+\beta L_{kl}
重构损失衡量DALL-E生成图像和原始图像之间的差异,由MSE(均方误差)或者NLL(负对数似然)来表达,其中xi是原始图像,G(zi,ci)是由编码器生成的图像,zi是隐变量,ci是输入文本的编码表示,N是图像数量。
L_r=\frac{1}{N}\sum_{i=1}^N||x_i-G(z_i,c_i)||^2
KL散度Loss,度量隐变量分布与标准正态分布之间的差异,以鼓励模型学习更有意义的隐空间表示。
L_{kl}=\frac{1}{2}\sum_{i=1}^N(\mu_i^2+\sigma_i^2-log(\sigma_i^2)-1)