生成对抗神经网络(Generative Adversarial Nets,GAN)是一种深度学习的框架,它是通过一个相互对抗的过程来完成模型训练的。典型的GAN包含两个部分,一个是生成模型(Generative Model,简称G),另一个是判别模型(Discriminative Model,简称D)。生成模型负责生成与样本分布一致的数据,目标是欺骗判别模型,让判别模型认为生成的数据是真实的;判别模型试图将生成的数据与真实的样本区分开。生成模型与判别模型相互对抗、相互促进,最终生成模型能够生 成以假乱真的数据,判别模型无法区分是生成的数据还是真实的样本,如此一来,就可以利用生 成模型去生成非常逼真的数据。
由于GAN能够生成复杂的高维度数据,因此被广泛应用于学术研究和工程领域。GAN的主要应用包括图像处理、序列数据生成、半监督学习、域自适应(Domain Adaptation)。图像处理是GAN应用最多的领域,包括图像合成、图像转换、图像超分辨率、对象检测、对象变换等;序列数据生成包括音乐生成、语音生成等。
GAN架构
生成模型的输入是低维度的随机噪声(如向量),输出是高维度的张量(如图像或音乐)。判别模型的输入是高维度的张量(如图像或音乐);输出是低维度的张量,如代表输入张量是否来源于真实样本的热向量(one-hot)。在训练阶段,生成模型输出的高维度张量也会输入给判别模 型,由判别模型判断生成的数据是否已经足够像真实的样本数据。模型训练完成之后,在评估阶段就可以通过给生成模型输入低维度的随机噪声,让生成模型输出高维度的张量数据(图像或音乐)。
关于 GAN 的生成和对抗,最早的 GAN 是作者通过警察(判别模型)和造假币者(生成模 型)来举例的。造假币者试图造出非常逼真的假币,警察试图将假币和真币区分开。造假者不断提升造假币的能力,以试图欺骗警察;警察也不断提高自己的辨别能力,将假币尽可能地 识别出来。二者相互对抗、相互竞争,造假者的造假水平和警察的辨别能力都不断地提高,直到最终造假币者能够造出以假乱真的假币,这就是生成对抗的原理。GAN 的架构如图 1-1 所示。
判别模型
判别模型的输入是一个高维度的张量(如图像或音乐),输出是一个低维度的张量,一般是向量(如图像所属类别)。这个转换的过程是典型的降采样(Down Sampling)过程,即将高维度、大尺寸的输入张量逐步转换成低维度、小尺寸的输出张量,最终输出向量的过程。这一降采样的过程与卷积神经网络的过程十分类似。实际上,GAN网络架构中采用卷积神经网络作为判别模型是十分常见的。判别模型的网络架构如图1-2所示。
生成模型
生成模型的输入是一个代表随机噪声的低维度张量,输出是一个代表高维度的张量(如图像 或音乐)。生成模型的转换过程是一个典型的升采样(Up Sampling)的过程,这一过程与反卷积神经网络的操作过程非常类似。实际上,采用反卷积神经网络作为生成模型的情况也是十分常见的。典型的生成模型的网络架构如图1-3所示。
训练方法
由GAN的原理可知,生成模型(G)和判别模型(D)相互对抗,用L (G,D)代表损失函数,其中 判别模型试图最小化误差,生成模型试图最大化误差。最终的误差函数如下:
由式(1-1)可知,生成模型和判别模型对误差都有影响,其中任何一个变动都会导致误差变动。所以,GAN是采用交替训练的方法来训练的,即固定一个模型,训练另外一个模型。
GAN模型的训练过程可以通过以下步骤来完成。
(1)固定生成模型的参数,优化判别模型的参数。首先生成一批样本数据G (z),将它们标记 为生成的样本,然后与真实的样本数据 x(标记为真实样本)一起输入判别模型。由于判别模型 的目标是将二者区分开,因此这是一个典型的分类预测问题,这也是卷积神经网络非常擅长的。实际上,目前主流的 GAN网络的判别模型往往都是卷积神经网络。经过训练,如果判别模型具 备足够的容量,就能够将真实样本与生成的数据区分开,于是可以得到一个判别模型D1。
(2)固定判别模型D1的参数,优化生成模型的参数。生成模型的优化目标是降低判别模型 的准确率,所以应根据判别模型的辨别结果调整生成模型的参数,直到生成模型能够产生让判别模型D1无法区分的生成数据。至此,可以得到一个生成模型G1。
(3)循环执行(1)和(2),交替训练并且升级生成模型和判别模型,每经过一轮训练,就会提高 一些模型的准确率,升级一次模型,最终得到生成模型 G2、G3、G4、⋯、Gn 和对应的判别模型 D2、D3、D4、⋯、Dn。经过以上 n轮训练,不管是生成模型还是判别模型的性能都会得到极大的提升,判别模型能够区分稍有瑕疵的生成数据。为了能够欺骗判别模型,生成模型必须能够生成 几乎没有瑕疵,或者说是能够以假乱真的数据,最终 GAN具备了生成足够逼真的高维度数据的能力。
为什么要学习GAN?
因为GAN功能强大、应用广泛,并且无须限定样本数据分布,就能够生成锐利而清晰的数据。
GAN 的应用场景十分广泛,包括图像生成、图像处理、序列数据生成、半监督学习、域自适应,以及其他相关领域,如医学图像细分(通过图像细分算法精确定位病灶)、隐写术(一种加密技 术,通过将加密信息写入肉眼可视的图像中实现)、持续学习(深度生成重放)。
图像处理是GAN应用最广泛的领域,包括图像生成、图像转换、图像超分辨率、对象检测、对 象变换、视频合成等场景,其中图像生成是 GAN模型的最原始的应用场景。图像转换是指将一 个领域(x)中的图像转换成另一个领域(y)中的图像,如将真人模特的照片转换成动漫卡通人物 的角色;图像超分辨率是指将低分辨率的图像转换成高分辨率图像的场景;对象检测是指检测图 像中是否包含指定的对象(如图像中是否包含狗);对象变换是指将图像中的对象替换成其他对 象,并且在不改变对象背景的前提下,让变换后的图像看起来足够真实;视频合成是指根据当前 视频的内容,预测未来一段时间的视频内容。
序列数据生成是指生成序列化数据的场景,包括语音对话或音乐合成。
半监督学习是指样本数据中只有少量的样本是有标记的,大量的样本数据是没有标记的,这 种类型的数据在生活中广泛存在。GAN能够通过充分利用标记的样本数据所属类别的信息,从理论角度来说,GAN的识别准确率可以达到非常高。
域自适应是迁移学习的一种,是指将在一个领域学习得到的模型应用在另一个领域中,其应用也十分广泛。例如,根据黄种人的人脸数据集训练一个人脸识别模型,如果该模型直接应用于 非黄种人的人脸(如白种人或黑种人)识别,那么识别的准确率可能会很低。域自适应能够提高 模型的适应能力,保证模型在应用于新领域时的性能。
以上内容来自《GAN生成对抗神经网络原理与实践》,还想学习更多 GAN 的知识?
机会来了~
在评论区留言你对 GAN 或者 AI 学习的看法
AI科技大本营将选出三名优质留言
携手【北京大学出版社】送出
《GAN生成对抗神经网络原理与实践》一本
截至6月18日14:00点
更多精彩推荐后疫情时代,RTC期待新的场景大爆发
Python + 爬虫:可视化大屏帮你选粽子
二次元会让人脸识别失效吗?
点分享点收藏点点赞点在看