一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

无论你是哪一类人,只要是接触过深度学习方面的知识,就一定听过GAN的大名。这两天,“GAN之父” Lan Goodfellow 从Google跳槽到苹果的消息,更是为人所津津乐道。其实他已于3月正式开始在苹果上班,加盟CEO库克直接领导的神秘特别项目小组。大家都在拭目以待他是否会在苹果再创奇迹。

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果


自从2014年GAN首篇论文发表以来,它就像是一座高峰一样,只要你的脚踏进深度学习领域,就保准能看到它的身影,当然更多的可能是因为它的发音奇特记住了这个名字。

然而,GAN确实是近十年以来深度学习领域乃至机器学习领域最富有创意的想法之一,它所能做到的事情也十分的令人不可思议,例如利用色块还原图片的黑科技:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

下面就让我们来彻底地了解一下它究竟是个什么东西。


GAN的思想:阴阳互济

要说这GAN的思想,其实相当简单。首先我们需要知道GAN它是拿来做什么的?

GAN的目标就是赋予机器类似于想象的天赋。其根本作用就是生成数据。

也就是说,我们利用GAN可以去生成我们需要的数据,这对整个深度学习领域来说简直是一个大杀器。很多时候就是因为数据不够我们才会提出更加复杂的算法和模型以期望能够在比较小的数据集上训练出好的效果,可反过来越是复杂的模型可能对数据量的要求反而越大,这就让人很难办了。

然而GAN却是像用深度神经网络,去生成数据。这在之前有人做吗?有,而且很多,包括什么变分自编码器之类的,但那些都是偏向传统机器学习的方法,它的数学味道很重。然而GAN不一样,身为新时代(深度学习时代)下的新算法,它完美地继承了深度学习黑盒(就是你根本不知道为什么它能取得这么好的效果,但是它就是做到了)的特点,在对抗中完成了数据生成工作。

说了这么多,其实GAN的主要思想就是,我们用一个生成模型生成数据,用一个鉴别模型鉴别数据,生成模型要尽可能地生成和真实数据相近的数据,从而欺骗鉴别模型,让它分不出来哪边是真哪边是假

它的流程就像下面这样:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

这就像是阴阳的两极,它们既相互对抗,又相辅相成。

因为在训练的过程中,

  • 我们利用生成网络的数据和真实数据来训练鉴别网络,让它鉴别能力更强;
  • 可下一个时刻,我们就会拿鉴别网络区训练生成网络,让生成网络的欺骗性更高。

于是二者就在这样的相互对抗中共同成长。


GAN的本质:JS散度

接下来我们要说一些硬核的东西,关于GAN的数学性的证明。虽然前文我们说GAN完美沿袭了深度学习的优良传统,但这仅限于它的生成网络和鉴别网络,而整体的对抗过程却是有严格的数学证明的,接下来就让我们来看看它的真面目吧。

根据论文,我们可以知道GAN的目标函数其实就是一个关于极小化极大的二人博弈问题:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

根据上文的分析我们可以知道,整个GAN架构最重要的目的是为了能够在对抗训练的过程中使得生成器G生成的样本分布能够接近真实样本的分布。

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果


最终可以证明,GANs的最小化问题,如下所示:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果


其中,

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

整理公式可以得到:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果


即:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果


上面所有的证明换成一句话就是,GAN的最优化过程其实就是在寻找生成数分布与真实数据分布的最小JS散度。


GAN的实现:简单的DCGAN

这一节我们参照书和开源代码上的内容,利用keras简单实现一个深度卷积生成式对抗网络(DCGAN)。其实现过程就如上文所述的一样,我们需要分别实现一个生成网络和一个鉴别网络,然后把它们放在一起训练。

具体来说,其实分成了三部分:

  • 第1部分:生成网络模型
  • 第2部分:鉴别网络模型
  • 第3部分:将它们整合在一起的GAN模型

流程如下:

  1. 从潜在空间中随机抽取一些点,也就是随机噪声;
  2. 利用1的随机噪声用生成网络生成图像;
  3. 将生成图像与真实图像混合;
  4. 使用这些混合后的图像以及相应的标签来训练鉴别网络;
  5. 然后再在随机空间中进行采样抽取新的点;
  6. 用新的点加上全部是真实标签来训练GAN,也就是生成网络和鉴别网络的整合网络。这里鉴别网络将被固定,我们相当于是在告诉生成网络,尽量把生成的图像让鉴别网络识别为真。

接下来我们来看看具体的模型结构实现。

  • 生成网络

首先是生成网络部分,我们接收一个100维度的向量将其映射成(32,32,3)的数据。这其中比较重要的部分就在于上采样的过程和转置卷积。

生成网络的模型结构如下图所示:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

  • 鉴别网络

接下来是关于鉴别网络的结构:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

我们注意到上述鉴别网络的结构中,没有出现池化层,这是因为最大池化运算会导致梯度的稀疏性,而这将妨碍GAN的训练,因此在鉴别网络的下采样过程中,我们不采用池化而采用步进大于1的卷积层进行下采样。

  • 对抗网络的设置

对抗网络将上述的生成网络和鉴别网络整合在了一起,更准确地说它是用来训练生成网络的,其具体代码如下:

discriminator.trainable = False #在训练gan的时候固定鉴别网络
gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input,gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr= 0.0004,clipvalue = 1.0,decay=1e-8)
gan.compile(optimizer=gan_optimizer,loss = 'binary_crossentropy')

结果

在这之后开始训练GAN,在这个过程中慢慢地产生一些图片,可以保存到本地进行查看,其最初生成的图片如下所示:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

真实的图片则是这样的:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

可以看到图片慢慢地向着真实图片的样子发展,或者说它至少在色块和形状上与真实图片接近了。

然而GAN的训练是困难的,任何参数的调整和修改甚至都可能使得GAN的训练失败,纳什均衡的理论平衡点就像沧海中的浮萍那样脆弱。

当然,这也是因为网络过于简单的缘故,受制于有限的设备,我们无法复现出比较复杂的GAN网络,只能通过上述简单的例子来感受GAN的具体内容。

最终,上面用来生成青蛙的GAN网络的输出如下所示:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

下面再使用DCGAN来尝试生成手写数字。我们只需要稍微修改生成网络和鉴别网络的部分参数,使得原来适用于32*32*3的网络模型变成适用于28*28*1即可,一开始我们得到的手写字生成图片如下所示:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

我们可以看到,慢慢地生成网络已经能够开始生成手写字的大致轮廓,只是难以形成系统的结构,比较模糊和分散,到训练的后期生成的手写字则已经可以很好地辨认了:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

真实的样本数据如下所示:

一文玩转GAN(对抗生成网络),5岁红爆AI界,其父从Google跳苹果

与上面的生成样本相比真实数据虽然更规整一点,但明显生成样本的手写字也是真实可信的手写字,如果将它们混起来,想来即便是通过肉眼也很难区分。


GANs的变种:各种GAN

通过前文我们已经了解到了GANs的基本思想和训练方法。从原理上说GANs的思路非常巧妙也符合人们的直觉,两个模型就像这世界上的其他事物一样,在相互制约和对抗中演变进化;从理论上说,GANs的原生论文就已经给出了它的最优性证明,也就是上文第二小节的内容,证明了GANs本质上是最小化生成数据的分布与真实数据分布之间的JS散度,当算法收敛的时候,生成器刻画的分布就等于真实数据的分布。

然而通过第三小节我们可以发现,在实际训练的过程中,这个收敛非常地不稳定,往往一些参数的小变动就会导致两个网络训练不平衡,最终一个网络支配了另一个网络使得训练失败。因此在GANs的发展中,衍生除了很多变种结构,想从各个方面来改进GANs的不足之处。

  • WGAN

我们知道在许多科幻小说中,对于三维空间的生物来说,二维空间的生物是不存在的;从数学上也是二维空间的物体是不存在三维空间中的体积这个概念的。迁移到我们当前的问题上来说,就是高维空间中并不是每一个点都能表达一个样本,空间大部分都是多余的,而真实的数据是蜷缩在低维子空间中的,它在高维空间中的表现就是一个曲面。而对于GANs来说是要在高维空间中寻找到这样的分布,使它和真实数据表现在高维空间中的曲面类似。而难点就在于,这个曲面就像一张极其单薄的纸飘在三维空间中一样,根本难以察觉,如何抓住这个低维度的幽灵成为了GANs的一个研究方向。

针对前面的这个问题,WGAN于2017年被提出,它的想法就是不再使用KL散度或者说JS散度来作为目标函数,而是利用一个特殊的距离——Wasserstein距离。

从理论上解释使用W距离更好是很复杂的,我们可以简单理解成,现在的W距离与JS散度比较更加敏感,生成器自身分布稍微变化一下,就能影响到它与真实分布之间的W距离,而这对JS距离来说是比较困难的。因此使用W距离可以更加有效地锁定捕捉到低维子空间中的真实数据分布。

  • DCGAN

回顾上文第三节的内容,我们可以发现在实现GANs生成图片的时候我们就已经使用了DCGAN的结构。但其实DCGAN并不是GANs结构一提出来时就有的。虽然我们现在提到图像就会想到卷积神经网络,但对于GANs来说利用卷积神经网络生成图像还走过了一段曲折的道路。

从上文可以知道,我们需要通过低维度的随机向量生成一张具有更多数据点的图像,也就是信息要从少到多,DCGAN的论文就提出了一些指导性的原则:

1. 去掉一切会损失位置信息的结构,比如池化层

2. 在生成网络中需要有上采样的结构,例如转置卷积。

3. 生成网络的最后不需要全连接层

4. 加入BN和ReLU激活函数

5. 如上文所说鉴别网络需要去掉池化层,替换成步长大于1的卷积层;同时也需要用到BN和LReLU;LReLU相比ReLU,更不容易发生梯度消失。

至此,我们完成了GANs上的旅行,相信各位对GANs一定有了一个全新的了解。


分享到:


相關文章: