Guidance,让扩散模型的指标更能打

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: Guidance,让扩散模型的指标更能打

扩散模型的悲惨开始

搞AI的都知道,一个模型好不好空口无凭,我们必须要用数据说话。但是扩散模型刚出来的时候存在一个问题,就是图片我们用人眼看起来好像效果还是挺好的,但是在数值上就是打不过GAN生成的图像,所谓的数值上就是在评价指标上,比如FID score和IS score。就算我们直接把图片放到论文中,人家也可能会质疑说,你这些图片肯定是把效果好的放出来了,效果不好的根本就没让我们看吧。所以如果数值提不上去,论文很难有说服力,审稿人那一关就过不了。所以很重要的一个努力方向就是提高模型的指标。另外还有一个问题就是模型的采样是按照时间步来的嘛,这样比较慢。所以大家就在想办法改进这些地方。就引出了我们今天要讲的东西,借助一些技巧或者guidance来帮助引导模型训练和采样。

其实在之前的这篇文章【翻译】最近兴起的扩散模型提过一嘴引导扩散这个事,今天就把它展开讲讲。

从uncondition到condition

我们扩散模型的反向过程是给定$x_T$时刻的输入,从$x_T$生成$x_{T-1}$,这样逐步一直到恢复图像$x_0$。恢复过程每一步用的都是一个U-Net,不停地采样、生成、采样、生成,算算算。

image.png

我们将每一步的过程formulate一下就是$f_\theta(x_t,t)$,就是时间步t的生成需要两个输入信息:

  1. 当前时间步的输入$x_t$
  2. 当前时间步$t$

引入guidance之后给模型提供指导,这个指导我们用$\mathbf y$表示,那我们就可以得到修改之后的网络$f_\theta(x_t,t,\mathbf y)$。接下来我们就来看一下$\mathbf y$究竟可以变成什么。


classifier guided diffusion

image.png

classifier gudied diffusion 就是在我们训练模型的同时 再 额外训练一个图像分类器,在很多论文中的实现方法就是直接用ImageNet的数据集去训练一个图像分类器,由于扩散模型的特性是从每一步的噪声图像中恢复,所以训练分类器的过程中是不断对ImageNet的图像加噪之后再训练分类器。

这个分类器的作用是:当我们拿到图片$x_t$之后,我们可以知道它分类的对不对,通过使用交叉熵目标函数,会得到一些梯度,然后我们用这些梯度去帮助模型进行接下来的采样和生成。

分类器的梯度暗含当前图片的一些信息,比如是否含有某个物体,比如图片是否真实等 。这个梯度引导就是给U-Net传递信息,我现在生成的图片要像什么东西。

经过classifier gudied diffusion的引导之后,生成的图片保真性就提高了很多,在FID score和IS score等指标上大幅度提高,扩散模型也第一次在评价指标的数值上超越了big GAN。详细的可以看《diffusion model beat GAN》这篇论文。这个做法是牺牲了一部分的多样性来换取图片的真实性,但是这个取舍是值得的,它的多样性依旧是比GAN要好很多的。评价指标数值提上来了,多样性还能吊打你,这样就开始奠定了新一位大魔王的地位了。

GAN、VAE、Diffusion model等生成模型如果大家不了解可以浅浅看一下这个文章:图像生成模型简介 - 掘金 (juejin.cn)

简单classifier之后的思路

除了用分类器我们还能借助什么产生指导信号呢?比如换成CLIP模型,这样文本和图像还能联系起来,这样我们就不止可以使用梯度来引导图像生成了,还可以使用文本信息来引导图像生成。

还有用语言模型进行文本方向引导的,利用图片的一些小任务引导的等等。


通过上述两大类改进方法,我们可以归结起来引导扩散模型生成和采样的一大方向就是:

$$ p(x_{t-1}|x_t) = \| \epsilon - f_\theta(x_t,t,\mathbf y) \| $$

其中的$\mathbf y$就是我们选定的控制方法。

但是这个方法有一个缺陷,就是我们需要借助另外一个模型进行引导。简单的分类器都要用带噪声的ImageNet数据集进行训练,更别说CLIP或者其他预训练的大型语言模型了。这样成本比较高,训练过程也是不可控的。


classifier-free guidance

因为我们上边提到的一些缺陷,所以有研究人员开始考虑不需要额外模型的方法,也就是classifier-free。不使用classifier之后我们能不能找到一种指导信号去让模型生成的更好呢?

在模型训练阶段让其产生两个输出,一个是在有条件的情况下产生的输出,一个是在无条件下产生的输出。

$$ f_\theta(x_t,x,\mathbf y) - f_\theta(x_t,x,\phi) $$

比如你用文本控制图像生成,训练时候用的是图像文本对,文本作为guidance信号。此时$\mathbf y$就是文本,你在训练的时候用$\mathbf y$去生成图像$f_\theta(x_t,x,\mathbf y)$,然后在某些情况下你随机去掉这个信号,取而代之传入一个空集$\phi$,再去生成另外一个输出$f_\theta(x_t,x,\phi)$,这样你就可以在生成图像的分布空间中知道有条件和无条件图像的距离,我们就可以知道在这个分布空间上如何从无条件输出得到有条件的输出,通过训练我们就可以知道有条件、无条件的差距是多少,最后去做图像生成的时候,即使我们没有条件去做生成,也可以获得一个比较合理的输出结果。这样就摆脱了分类器。

但是这样存在一个问题就是训练代价比较昂贵,扩散模型的训练本来就很烧钱,现在居然要一次做两个输出,一个有条件一个无条件,这样又增加很多成本。

但是这个真的是很好用的方法,GLIDE、DALL·E 2、Imagen等模型都用到了这个技巧。

相关文章
|
7月前
|
机器学习/深度学习 数据采集 搜索推荐
多模型DCA曲线:如何展现和解读乳腺癌风险评估模型的多样性和鲁棒性?
多模型DCA曲线:如何展现和解读乳腺癌风险评估模型的多样性和鲁棒性?
165 1
|
机器学习/深度学习
评分是机器学习领域中的一种评估模型性能的指标
评分是机器学习领域中的一种评估模型性能的指标
98 1
|
2月前
|
机器学习/深度学习 自然语言处理 PyTorch
LLM-Mixer: 融合多尺度时间序列分解与预训练模型,可以精准捕捉短期波动与长期趋势
近年来,大型语言模型(LLMs)在自然语言处理领域取得显著进展,研究人员开始探索将其应用于时间序列预测。Jin等人提出了LLM-Mixer框架,通过多尺度时间序列分解和预训练的LLMs,有效捕捉时间序列数据中的短期波动和长期趋势,提高了预测精度。实验结果显示,LLM-Mixer在多个基准数据集上优于现有方法,展示了其在时间序列预测任务中的巨大潜力。
74 3
LLM-Mixer: 融合多尺度时间序列分解与预训练模型,可以精准捕捉短期波动与长期趋势
|
3月前
|
机器学习/深度学习 自然语言处理 并行计算
扩散模型
本文详细介绍了扩散模型(Diffusion Models, DM),一种在计算机视觉和自然语言处理等领域取得显著进展的生成模型。文章分为四部分:基本原理、处理过程、应用和代码实战。首先,阐述了扩散模型的两个核心过程:前向扩散(加噪)和逆向扩散(去噪)。接着,介绍了训练和生成的具体步骤。最后,展示了模型在图像生成、视频生成和自然语言处理等领域的广泛应用,并提供了一个基于Python和PyTorch的代码示例,帮助读者快速入门。
|
7月前
|
机器学习/深度学习 人工智能 算法
社交网络分析4(上):社交网络链路预测分析、Logistic回归模型、LLSLP方法(LightGBM 堆叠链路预测)、正则化方法、多重共线性
社交网络分析4(上):社交网络链路预测分析、Logistic回归模型、LLSLP方法(LightGBM 堆叠链路预测)、正则化方法、多重共线性
507 0
社交网络分析4(上):社交网络链路预测分析、Logistic回归模型、LLSLP方法(LightGBM 堆叠链路预测)、正则化方法、多重共线性
|
5月前
|
机器学习/深度学习 Serverless Python
`sklearn.metrics`是scikit-learn库中用于评估机器学习模型性能的模块。它提供了多种评估指标,如准确率、精确率、召回率、F1分数、混淆矩阵等。这些指标可以帮助我们了解模型的性能,以便进行模型选择和调优。
`sklearn.metrics`是scikit-learn库中用于评估机器学习模型性能的模块。它提供了多种评估指标,如准确率、精确率、召回率、F1分数、混淆矩阵等。这些指标可以帮助我们了解模型的性能,以便进行模型选择和调优。
|
7月前
|
机器学习/深度学习 人工智能
【机器学习】有哪些指标,可以检查回归模型是否良好地拟合了数据?
【5月更文挑战第16天】【机器学习】有哪些指标,可以检查回归模型是否良好地拟合了数据?
|
7月前
|
搜索推荐 机器人 开发者
视频扩散模型
视频扩散模型【2月更文挑战第26天】
41 1
|
7月前
|
数据采集 编解码
Sora:一个具有灵活采样维度的扩散变压器
Sora:一个具有灵活采样维度的扩散变压器
76 9
|
7月前
|
机器学习/深度学习 存储 编解码
重参架构的量化问题解决了 | 粗+细粒度权重划分量化让RepVGG-A1仅损失0.3%准确性
重参架构的量化问题解决了 | 粗+细粒度权重划分量化让RepVGG-A1仅损失0.3%准确性
88 0
重参架构的量化问题解决了 | 粗+细粒度权重划分量化让RepVGG-A1仅损失0.3%准确性

热门文章

最新文章