一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

無論你是哪一類人,只要是接觸過深度學習方面的知識,就一定聽過GAN的大名。這兩天,“GAN之父” Lan Goodfellow 從Google跳槽到蘋果的消息,更是為人所津津樂道。其實他已於3月正式開始在蘋果上班,加盟CEO庫克直接領導的神秘特別項目小組。大家都在拭目以待他是否會在蘋果再創奇蹟。

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果


自從2014年GAN首篇論文發表以來,它就像是一座高峰一樣,只要你的腳踏進深度學習領域,就保準能看到它的身影,當然更多的可能是因為它的發音奇特記住了這個名字。

然而,GAN確實是近十年以來深度學習領域乃至機器學習領域最富有創意的想法之一,它所能做到的事情也十分的令人不可思議,例如利用色塊還原圖片的黑科技:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

下面就讓我們來徹底地瞭解一下它究竟是個什麼東西。


GAN的思想:陰陽互濟

要說這GAN的思想,其實相當簡單。首先我們需要知道GAN它是拿來做什麼的?

GAN的目標就是賦予機器類似於想象的天賦。其根本作用就是生成數據。

也就是說,我們利用GAN可以去生成我們需要的數據,這對整個深度學習領域來說簡直是一個大殺器。很多時候就是因為數據不夠我們才會提出更加複雜的算法和模型以期望能夠在比較小的數據集上訓練出好的效果,可反過來越是複雜的模型可能對數據量的要求反而越大,這就讓人很難辦了。

然而GAN卻是像用深度神經網絡,去生成數據。這在之前有人做嗎?有,而且很多,包括什麼變分自編碼器之類的,但那些都是偏向傳統機器學習的方法,它的數學味道很重。然而GAN不一樣,身為新時代(深度學習時代)下的新算法,它完美地繼承了深度學習黑盒(就是你根本不知道為什麼它能取得這麼好的效果,但是它就是做到了)的特點,在對抗中完成了數據生成工作。

說了這麼多,其實GAN的主要思想就是,我們用一個生成模型生成數據,用一個鑑別模型鑑別數據,生成模型要儘可能地生成和真實數據相近的數據,從而欺騙鑑別模型,讓它分不出來哪邊是真哪邊是假

它的流程就像下面這樣:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

這就像是陰陽的兩極,它們既相互對抗,又相輔相成。

因為在訓練的過程中,

  • 我們利用生成網絡的數據和真實數據來訓練鑑別網絡,讓它鑑別能力更強;
  • 可下一個時刻,我們就會拿鑑別網絡區訓練生成網絡,讓生成網絡的欺騙性更高。

於是二者就在這樣的相互對抗中共同成長。


GAN的本質:JS散度

接下來我們要說一些硬核的東西,關於GAN的數學性的證明。雖然前文我們說GAN完美沿襲了深度學習的優良傳統,但這僅限於它的生成網絡和鑑別網絡,而整體的對抗過程卻是有嚴格的數學證明的,接下來就讓我們來看看它的真面目吧。

根據論文,我們可以知道GAN的目標函數其實就是一個關於極小化極大的二人博弈問題:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

根據上文的分析我們可以知道,整個GAN架構最重要的目的是為了能夠在對抗訓練的過程中使得生成器G生成的樣本分佈能夠接近真實樣本的分佈。

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果


最終可以證明,GANs的最小化問題,如下所示:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果


其中,

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

整理公式可以得到:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果


即:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果


上面所有的證明換成一句話就是,GAN的最優化過程其實就是在尋找生成數分佈與真實數據分佈的最小JS散度。


GAN的實現:簡單的DCGAN

這一節我們參照書和開源代碼上的內容,利用keras簡單實現一個深度卷積生成式對抗網絡(DCGAN)。其實現過程就如上文所述的一樣,我們需要分別實現一個生成網絡和一個鑑別網絡,然後把它們放在一起訓練。

具體來說,其實分成了三部分:

  • 第1部分:生成網絡模型
  • 第2部分:鑑別網絡模型
  • 第3部分:將它們整合在一起的GAN模型

流程如下:

  1. 從潛在空間中隨機抽取一些點,也就是隨機噪聲;
  2. 利用1的隨機噪聲用生成網絡生成圖像;
  3. 將生成圖像與真實圖像混合;
  4. 使用這些混合後的圖像以及相應的標籤來訓練鑑別網絡;
  5. 然後再在隨機空間中進行採樣抽取新的點;
  6. 用新的點加上全部是真實標籤來訓練GAN,也就是生成網絡和鑑別網絡的整合網絡。這裡鑑別網絡將被固定,我們相當於是在告訴生成網絡,儘量把生成的圖像讓鑑別網絡識別為真。

接下來我們來看看具體的模型結構實現。

  • 生成網絡

首先是生成網絡部分,我們接收一個100維度的向量將其映射成(32,32,3)的數據。這其中比較重要的部分就在於上採樣的過程和轉置卷積。

生成網絡的模型結構如下圖所示:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

  • 鑑別網絡

接下來是關於鑑別網絡的結構:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

我們注意到上述鑑別網絡的結構中,沒有出現池化層,這是因為最大池化運算會導致梯度的稀疏性,而這將妨礙GAN的訓練,因此在鑑別網絡的下采樣過程中,我們不採用池化而採用步進大於1的卷積層進行下采樣。

  • 對抗網絡的設置

對抗網絡將上述的生成網絡和鑑別網絡整合在了一起,更準確地說它是用來訓練生成網絡的,其具體代碼如下:

discriminator.trainable = False #在訓練gan的時候固定鑑別網絡
gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input,gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr= 0.0004,clipvalue = 1.0,decay=1e-8)
gan.compile(optimizer=gan_optimizer,loss = 'binary_crossentropy')

結果

在這之後開始訓練GAN,在這個過程中慢慢地產生一些圖片,可以保存到本地進行查看,其最初生成的圖片如下所示:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

真實的圖片則是這樣的:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

可以看到圖片慢慢地向著真實圖片的樣子發展,或者說它至少在色塊和形狀上與真實圖片接近了。

然而GAN的訓練是困難的,任何參數的調整和修改甚至都可能使得GAN的訓練失敗,納什均衡的理論平衡點就像滄海中的浮萍那樣脆弱。

當然,這也是因為網絡過於簡單的緣故,受制於有限的設備,我們無法復現出比較複雜的GAN網絡,只能通過上述簡單的例子來感受GAN的具體內容。

最終,上面用來生成青蛙的GAN網絡的輸出如下所示:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

下面再使用DCGAN來嘗試生成手寫數字。我們只需要稍微修改生成網絡和鑑別網絡的部分參數,使得原來適用於32*32*3的網絡模型變成適用於28*28*1即可,一開始我們得到的手寫字生成圖片如下所示:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

我們可以看到,慢慢地生成網絡已經能夠開始生成手寫字的大致輪廓,只是難以形成系統的結構,比較模糊和分散,到訓練的後期生成的手寫字則已經可以很好地辨認了:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

真實的樣本數據如下所示:

一文玩轉GAN(對抗生成網絡),5歲紅爆AI界,其父從Google跳蘋果

與上面的生成樣本相比真實數據雖然更規整一點,但明顯生成樣本的手寫字也是真實可信的手寫字,如果將它們混起來,想來即便是通過肉眼也很難區分。


GANs的變種:各種GAN

通過前文我們已經瞭解到了GANs的基本思想和訓練方法。從原理上說GANs的思路非常巧妙也符合人們的直覺,兩個模型就像這世界上的其他事物一樣,在相互制約和對抗中演變進化;從理論上說,GANs的原生論文就已經給出了它的最優性證明,也就是上文第二小節的內容,證明了GANs本質上是最小化生成數據的分佈與真實數據分佈之間的JS散度,當算法收斂的時候,生成器刻畫的分佈就等於真實數據的分佈。

然而通過第三小節我們可以發現,在實際訓練的過程中,這個收斂非常地不穩定,往往一些參數的小變動就會導致兩個網絡訓練不平衡,最終一個網絡支配了另一個網絡使得訓練失敗。因此在GANs的發展中,衍生除了很多變種結構,想從各個方面來改進GANs的不足之處。

  • WGAN

我們知道在許多科幻小說中,對於三維空間的生物來說,二維空間的生物是不存在的;從數學上也是二維空間的物體是不存在三維空間中的體積這個概念的。遷移到我們當前的問題上來說,就是高維空間中並不是每一個點都能表達一個樣本,空間大部分都是多餘的,而真實的數據是蜷縮在低維子空間中的,它在高維空間中的表現就是一個曲面。而對於GANs來說是要在高維空間中尋找到這樣的分佈,使它和真實數據表現在高維空間中的曲面類似。而難點就在於,這個曲面就像一張極其單薄的紙飄在三維空間中一樣,根本難以察覺,如何抓住這個低維度的幽靈成為了GANs的一個研究方向。

針對前面的這個問題,WGAN於2017年被提出,它的想法就是不再使用KL散度或者說JS散度來作為目標函數,而是利用一個特殊的距離——Wasserstein距離。

從理論上解釋使用W距離更好是很複雜的,我們可以簡單理解成,現在的W距離與JS散度比較更加敏感,生成器自身分佈稍微變化一下,就能影響到它與真實分佈之間的W距離,而這對JS距離來說是比較困難的。因此使用W距離可以更加有效地鎖定捕捉到低維子空間中的真實數據分佈。

  • DCGAN

回顧上文第三節的內容,我們可以發現在實現GANs生成圖片的時候我們就已經使用了DCGAN的結構。但其實DCGAN並不是GANs結構一提出來時就有的。雖然我們現在提到圖像就會想到卷積神經網絡,但對於GANs來說利用卷積神經網絡生成圖像還走過了一段曲折的道路。

從上文可以知道,我們需要通過低維度的隨機向量生成一張具有更多數據點的圖像,也就是信息要從少到多,DCGAN的論文就提出了一些指導性的原則:

1. 去掉一切會損失位置信息的結構,比如池化層

2. 在生成網絡中需要有上採樣的結構,例如轉置卷積。

3. 生成網絡的最後不需要全連接層

4. 加入BN和ReLU激活函數

5. 如上文所說鑑別網絡需要去掉池化層,替換成步長大於1的卷積層;同時也需要用到BN和LReLU;LReLU相比ReLU,更不容易發生梯度消失。

至此,我們完成了GANs上的旅行,相信各位對GANs一定有了一個全新的瞭解。


分享到:


相關文章: