首先一如既往地是我们的约定环节:
- MCMC:Markov Chain Monte Carlo,马尔可夫蒙特卡洛算法;
- VAE:Variational AutoEncoder,变分自编码器;
第10章 GAN模型
机器学习和深度学习领域不仅有判别模型,而且有生成模型,获得每个标签所对应的数据分布,GAN(Generative Adversarial Network)即是这样一种判别模型,利用深度学习强大的非线性拟合能力模拟真实的数据分布。
10.1 原理、特点及应用
对条件概率
直接建模得到的模型就是判别模型,对训练数据特征和标签的联合概率
建模而得到的模型就是生成模型。
思考:这里的生成模型,和NLP中,利用深度学习模型生成的声学模型、语言模型对应的分布有什么关系?NLP中使用的是生成模型吗?
常用的生成模型有两大类:基于显式概率密度函数的生成模型和基于隐式概率密度函数的生成模型。
- 基于显式概率密度函数的生成模型首先建立所需拟合的数据分布的概率密度函数,然后用有限的样本不断训练该模型。假设该模型
的目标优化函数为
,最优参数为
。当该目标优化函数可解时,直接用优化算法求解该模型的最优参数;
- 当它不可解时,需要用近似的方法得到上述模型的最优参数。变分近似(Variational Approximations)和马尔可夫蒙特卡洛(Markov Chain Monte Carlo,简称MCMC)变分近似中,我们首先得到原始目标优化函数的一个可解的下限函数,然后通过求解该下限函数的最大值间接得到原始目标优化函数的最大值;
不同于基于显式概率密度函数的生成模型,基于隐式概率密度函数的生成模型可以通过直接采样的方式训练模型参数,而不需要提前对真实分布建立模型。而GAN是一种全新的基于隐式概率密度函数的生成模型。
10.1.1 原理
由两个子墨性——生成器和判别器组成。G的输入为随机噪声向量,输出为与真实数据维度相同的数据。D的输入为真实数据或者G生成的数据,输出为其对输入数据的分类。对抗过程可以表示为:
实际中在每一次迭代中先训练k次判别器
,再训练一次生成器
,具体步骤如下:
- 训练判别器。从噪声输入向量z的分布
中随机采用m个样本组成一组用于输入到生成器的批训练数据
。计算判别器的损失函数在当前的批训练数据上的平均值:
,然后求解
对判别器参数
的导数,用梯度上升法(
注意这里是梯度上升?)更新; - 重复步骤1. k次;
- 训练生成器。从p_z(z)中随机采样得到一组批训练数据
,计算生成器的损失函数在当前的批训练数据上的平均值
,然后求解
对生成器参数
的导数,然后用梯度下降法更新
。
- 重复上述步骤直至收敛;
判别器的优化目标:
为求
的最小值,将
看作一个未知量,令
,即
,可得出
为最优解。当
时,目标函数可以进一步表示为
这里的
表示Jensen-Shanon散度,为了解决KL散度的非对称问题,但是和KL共同的问题在于两个分布距离较远时表达不佳,JS散度在距离较远时为常数,而KL散度无意义。 当散度下降时,
逼近真实数据概率分布。
10.1.2 特点
GAN模型可以逼近任何分布,适合生成非结构化高维数据。
- 和VAE相比,GAN模型没有引入近似条件和额外假设,因此能够保证生成的效果更好,通常比VAE生成的图像更加清晰(VAE生成图像模糊好像是业界的共识);
- 与MCMC相比,GAN的模型驯良不依赖MCMC,计算复杂度不高,训练速度较快,生成数据维度更高;
- 以FVBN(Fully visible belief networks,比如WaveNet也是基于此模型)为代表的生成模型必须串行地生成数据在各个维度上的值,计算速度受限,而GAN模型的训练和推理计算都很容易并行化,从而能够充分利用GPU等高性能计算设备加速。
GAN的缺点:
- 在目标函数中,当生成器生成的数据和真实的数据的分布之间没有交集时,JS散度为常数,此时如果判别器也为最优,就会出现梯度消失现象,继而难以根据目标函数更新生成器模型的参数;
- 由于灵活性太高,GAN经常会出现模式崩溃(mode collapse)问题,即生成器和判别器的参数优化过程停滞不前、无法收敛的问题。
通过WGAN、minibatch GAN等改进解决这些问题。
10.1.3 应用
近年来利用GAN强大的数据生成能力获得更好的训练结果,也开始应用于自然语言处理等领域。
- 图像翻译:包含两个方面,1)抽取文字描述的语义特征,2)根据语义特征生成一副相应的图像。在这样的GAN模型中,生成器的输入包括噪声向量和文字的编码向量,这两个向量同时输入到生成器中用于图像生成,使得图像语义与对应文字的语义产生关联。判别器不仅需要判断图像看起来是否真实,而且需要判断图像是否与文字编码向量互相对应;
- 图像编辑:对图像中指定的某些内容进行语义上的修改。在GAN模型中,通过改变生成器输入向量中的某些维度的值来控制生成器输出图像的局部语义信息。为了使用户可以编辑模型的隐藏向量,进而控制图像的语义信息,我们通常需要将编码器与GAN模型结合使用。在训练完GAN模型后,再训练两个编码器,分别将待编辑的图像转换为一个噪声向量和一个隐藏向量。隐藏向量上的每一个维度的值都表示一些高级语义特性。将这两个向量同时作为GAN生成器的输入,就可以通过修改隐藏向量来控制GAN模型生成器所生成图像的某些局部语义信息;
- 图像超分辨率:高维空间中数据分布非常稀疏,仅靠MSE训练出的模型生成的图像在某些细节处看起来不太真实。GAN对真实数据有更好地拟合;
- 半监督学习:同时利用带标签的数据和不带标签的数据进行模型训练的方法。利用GAN产生逼近真实的数据,将不带标签的数据看作第
类,并将其作为训练数据的一部分。训练数据得到了扩充;
- 基于GAN的强化学习:传统的强化学习大多是单任务系统,智能体只能通过奖励函数一直执行一个单一的、不变的任务,通过GAN为智能体持续地生成难度适宜的多个目标任务。
10.2 GAN模型的改进
GAN的主要问题:生成图像的分辨率不高、学习特征不可控、训练过程不稳定等。
10.2.1 CGAN模型
原始的GAN框架不需要对数据分布做假设,而直接采用多层神经网络模拟数据分布。这种训练方式太过自由,在实际中数据表前或其他辅助数据可被用作训练GAN的约束条件。CGAN在判别器和生成器的输入中分别引入一个条件向量
,用于约束生成图像的某些属性。最优化问题可以从之前的输出的分布改写为新的条件分布。
以手写字体为例,条件向量可以是数字的值、笔迹的宽度等。
10.2.2 LAPGAN模型
LAPGAN(Laplacian Pyramid GAN,拉普拉斯金字塔生成式对抗网络,和高斯金字塔配合使用,高斯金字塔通过高斯滤波和降采样获得不同尺度的图像金字塔,拉普拉斯金字塔表示上采样/高斯滤波后与上一级图像的差,表示降采样过程中丢失的高频信息,由拉普拉斯金字塔可以重构出原始图像)模型。
LAPGAN的原理是在每个尺度下都有一个生成器,用于生成该尺度下的拉普拉斯金字塔图像。例如,在
尺度下,噪声向量
输入到生成器
中得到生成图像
。在
尺度下,生成器
的输入为噪声向量
,以及
升采样后的图像
(与CGAN类似,该图像作为条件变量引导
的输出),根据拉普拉斯金字塔图像和高斯金字塔图像之间的关系,
的输出与
相加后即可得到
尺度下的生成图像
。以此类推直到生成最终图像
(
既另一种形式的CGAN)。
10.2.3 DCGAN模型
(注意!这里的C不是CGAN里的Conditional,而是深度卷积的意思)DCGAN模型的整体架构与原始的GAN模型一致,区别在于Generator和Discriminator的具体实现方面的细节:
- 判别器中的池化层被逮捕长的卷积层取代。这是的CNN能够自动学喜特征途的降采样模式,而非按照固定的最大池化或平均池化做降采样。同时这种降采样方式也使得整个DCGAN模型变为完全可微的(啥意思?应该是指可以全局进行反向传播梯度计算,而LAPGAN的降采样方式将不同生成器的梯度更新割裂开来,无法构成端到端应用),对于训练的稳定性有好处;
- Discriminator(除了输出层之外)和Generator(除了输入层之外)的卷积层都经过BN层做归一化,基于BN层的归一化使得训练精度对于参数初始化不敏感,并且能够加快训练速度
10.2.4 InfoGAN模型
在InfoGAN模型中,在损失函数中加入了输入语义特征向量
和输出图像之间的互信息的相反数,互信息越打,相关性越高,损失越低。然而直接优化新的损失函数意味着需要对隐变量
的后验概率分布
进行采样,InfoGAN中引入变分法,采用一个辅助分布
来近似后验概率分布,简化了计算。实际中可以采用深度学习模型来表示
。
10.2.5 LSGAN模型
上述几种GAN模型目标函数大部分都是基于sigmoid交叉熵,会导致梯度消失的问题,当生成器G生成的数据处于判别器D的分类面正确的一侧,但又离真实数据分布比较远时,很难通过该目标函数计算相应的梯度更新G的模型参数,难以继续优化G。LSGAN(Least Squares GAN)是基于最小二乘法的目标函数,分别优化生成器和判别器的GAN模型。
基于sigmoid交叉熵的目标函数容易出现饱和,从而导致无法继续训练。实验效果表明LSGAN模型生成的图像分辨率更高,训练过程更加稳定。
10.2.6 WGAN模型
将JS散度替换为EMD(Earth Mover's Distance,也就是Wasserstein距离)来定义生成数据和真实数据分布之间的差异,相比于JS散度,它可以做到处处连续且可微。
10.3 最佳实践
参见这里。
10.4 小结
GAN模型接住了深度学习模型对数据分布强大的建模能力,使得生成逼真的图像等高维数据成为可能。