詳解 GAN 生成對抗網絡

GAN : Generative adversarial network 生成對抗網絡

詳解 GAN 生成對抗網絡

https://www.kdnuggets.com/2017/01/generative-adver

Yan Lecun 給這個模型很高評價,認為它是機器學習領域緊十年來最酷的模型。

關於 GAN 的論文就有好多,下面這個repo裡面比較全的列出了相關論文:

https://github.com/hindupuravinash/the-gan-zoo/blob/master/README.md

從2017年開始關於它的論文每個月都在不斷大幅增長:

https://deephunt.in/the-gan-zoo-79597dc8c347


GAN 主要是用來生成東西,

在圖像領域是生成圖像,給它一個隨機的向量,這個向量的每個元素一般來說是代表圖像的一種特徵,輸入給模型後,它可以生成一張圖片,也就是一個高維向量,向量的每個維度對應一個像素的顏色,

在自然語言處理領域是生成文字,比如說寫詩寫文章,給它一個隨機向量,它可以輸出一句話。


GAN 模型包括一個 generator 生成器和一個 discriminator 辨別器,生成器和辨別器之間的關係就好像是被捕食者和捕食者的關係。:


詳解 GAN 生成對抗網絡


首先看 discriminator,實際上它是一個神經網絡。

它的輸入是一張圖片,輸出是一個標量,這個值代表圖片的質量,數值越大,生成的結果質量越高,所以,當輸入圖片很真實的時候,discriminator 給它的得分越高,相反則得分越低。

再看 generator,它也是一個神經網絡,開始它的輸入是隨機的,因為它也不知道要怎麼生成圖片,所以一開始的輸出也是比較模糊的東西。

生成器生成的結果,辨別器要做的就是判斷這張圖片是由生成器生成的還是像是真實的圖片,辨別器會給評分,評分低的話會被督促著進化,生成更好地結果,就像被捕食者為了不被滅族就要進化,但是捕食者為了不被餓著也會進化,就這樣互相督促著一點點改進,最後會生成非常好的結果。

所以第一代生成器生成第一代結果,第一代辨別器評分,

然後第二代生成器要做的事,就是想辦法騙過第一代的辨別器。例如,第一代的辨別器通過是否有顏色這個特徵來區分真實圖片和第一代生成器生成的圖片,那麼第二代生成器為了騙過第一代的辨別器就會給圖片加上了顏色。

同樣第二代的辨別器也跟著進化,它要判斷真實圖片和第二代生成器生成的圖片,這時候不能根據是否有顏色了,而是通過其他特徵,例如是否有嘴巴。

就這樣生成器和辨別器之間的關係就像是相互對抗的天敵,經過不斷地進化,生成器就可以生成更高質量更接近真實的圖片。


GAN 模型的算法過程

生成器和辨別器都是神經網絡,訓練模型之前先隨機生成它們的參數,然後進行迭代去訓練生成器和辨別器。

在每個迭代中有兩個步驟:

第一步,固定生成器的參數,只去訓練辨別器的參數。

  • 具體做法是將一些隨機向量投給生成器,生成器就會生成一些效果不好的圖片,
  • 然後從真實圖片庫中採樣一些樣本,
  • 接著要去訓練辨別器的參數,

方法就是,如果這個圖片是從真實數據集合中採用出來的,就給高分,如果是生成器生成的,就給低分,這可以是一個分類問題。

第二步,固定辨別器,只去訓練生成器的參數。

  • 先把一個向量輸入給生成器,會生成一個圖片,
  • 接著將這個圖片輸入給辨別器,辨別器會給這個圖片一個分數,
  • 因為生成器的目的是要騙過辨別器,所以希望得到的分數可以越高越好,相當於生成的圖片過了辨別器這一關,生成了比較真實的圖片,也就是這時候要固定辨別器的參數,去調節生成器的參數。

在實際訓練時,會將生成器和辨別器放在一起,組成一個大的神經網絡。

例如,生成器和辨別器都有五層,將它們連在一起成為一個十層的網絡,

輸入是一個向量,輸出是一個值,中間有一層輸出代表一個圖片,這一層會特別寬,和圖片的展開緯度是一樣的。

在訓練的時候,先固定後面五層隱藏層,只去訓練前面五層,就是在訓練生成器,目標就是要讓整個網絡的輸出值越大越好。


詳解 GAN 生成對抗網絡


---

下面這個圖就是 GAN 的詳細算法:


詳解 GAN 生成對抗網絡


接下來對照算法詳細講解,

生成器的參數是 theta g,辨別器的參數是 theta d。

  • 在每次迭代中,先從數據庫中採樣出 m 個圖片,
  • 再從一個分佈中採樣出 m 個噪音樣本向量,這個分佈可以是高斯分佈,
  • x ~ 表示生成器生成的圖片,
  • 然後去調整辨別器,
  • 前一部分是訓練辨別器,目的是要讓這個目標函數越大越好,

目標函數的意義是,首先拿出 m 張真實的圖片,給辨別器得到一個分數,取 log 對數,在做平均,因為目標是要讓這個目標函數越大越好,就是讓第一項越大越好,也就是讓他給真實的圖片的得分越大越好。

目標函數第二項的含義是將生成器生成的這些假的圖片傳遞給辨別器,經過 sigmoid 函數,得到 0~1 之間的數,同樣也是取對數,求平均。

因為整體是希望這個目標函數越大越好,那麼第二項中的 D(X) 就是需要越小越好,也就是生成器生成的圖片得分越小越好。

至於如何讓這個目標函數越大越好,就可以用梯度算法等優化算法來更新參數。

算法的上面這部分代碼是要訓練辨別器,下面這部分是要訓練生成器。

  • 首先採樣出 m 個向量,這些向量和前面的採樣的向量是不一樣的,
  • 生成器的目標是想辦法騙過辨別器,V ~ 就是生成器的目標函數,

把 m 個向量輸入給生成器,生成一張圖片就是 G(Z),

把這個圖片丟給辨別器,得到的是 D(G(Z)),同樣取對數,求平均值。

最終的目標是希望這個函數越大越好,意思是生成器生成的圖片輸入給辨別器之後得到的分數可以越大越好。

同樣可以用梯度的算法來調節參數,使得目標函數越大越好。

這樣上面一部分是訓練辨別器,下面是訓練生成器,就這樣反覆交替地執行這兩個步驟。


推薦學習資料:

https://youtu.be/DQNNMiAP5lw

本文是 李宏毅 GAN Lecture 1 Introduction 的學習筆記。


分享到:


相關文章: