知識蒸餾(Distilling the Knowledge in a Neural Network)

作者:Geoffrey Hinton,Oriol Vinyals,Jeff Dean

發表信息:Machine Learning (cs.LG); Neural and Evolutionary Computing (cs.NE)

一、問題動機

神經網絡訓練階段從大量數據中獲取網絡模型,訓練階段可以利用大量的計算資源且不需要實時響應。然而到達使用階段,神經網絡需要面臨更加嚴格的要求包括計算資源限制,計算速度要求等等。

一個複雜的網絡結構模型是若干個單獨模型組成的集合,或者是一些很強的約束條件下(比如dropout率很高)訓練得到的一個很大的網絡模型。一旦複雜網絡模型訓練完成,我們便可以用另一種訓練方法:“蒸餾”,把我們需要配置在應用端的縮小模型從複雜模型中提取出來。“蒸餾”的難點在於如何縮減網絡結構但是把網絡中的知識保留下來。知識就是一幅將輸入向量導引至輸出向量的地圖。做複雜網絡的訓練時,目標是將正確答案的概率最大化,但這引入了一個副作用:這種網絡為所有錯誤答案分配了概率,即使這些概率非常小。我們將複雜模型轉化為小模型時需要注意保留模型的泛化能力,一種方法是利用由複雜模型產生的分類概率作為“軟目標”來訓練小模型。在轉化階段,我們可以用同樣的訓練集或者是另外的“轉化”訓練集。當複雜模型是由簡單模型複合而成時,我們可以用各自的概率分佈的代數或者幾何平均數作為“軟目標”。當“軟目標的”熵值較高時,相對“硬目標”,它每次訓練可以提供更多的信息和更小的梯度方差,因此小模型可以用更少的數據和更高的學習率進行訓練。

二、解決思路:

蒸餾大致描述如下圖:

知識蒸餾(Distilling the Knowledge in a Neural Network)

cumbersome model表示複雜的大模型,distilled model表示經過knowledge distillation後學習得到的小模型,hard targets表示輸入數據所對應的label ,例如[0,0,1,0]。soft targets表示輸入數據通過大模型(cumbersome model)所得到的softmax層的輸出,例如[0.01,0.02,0.98,0.17]。

Softmax公式:

知識蒸餾(Distilling the Knowledge in a Neural Network)

qi 表示第 i 類的輸出概率,zi、zj 表示 softmax 層的輸入(即 logits),T 為溫度係數,用來控制輸出概率的soft程度。

論文方法的關鍵之處便是利用soft target來輔助hard target一起訓練。

由於hard target 包含的信息量(信息熵)很低,而soft target包含的信息量大,擁有不同類之間關係的信息。比如同時分類驢和馬的時候,儘管某張圖片是馬,但是soft target就不會像hard target 那樣只有馬的index處的值為1,其餘為0,而是在驢的部分也會有概率。

這樣做的好處是,這個圖像可能更像驢,而不會去像汽車或者狗之類的,而這樣的soft信息存在於概率中,以及label之間的高低相似性都存在於soft target中。但是如果soft targe是像這樣的信息[0.98 0.01 0.01],就意義不大了,所以需要在softmax中增加溫度參數T(這個設置在最終訓練完之後的推理中是不需要的)

T 的意義可以用如下圖 來理解,圖中 紅,綠,藍 分別對用同一組z在T為(5,25,50)下的值,可以看出,T越大,值之間的差就越小(折線更平緩,即更加的 soft),但是相對的大小關係依然沒變。

知識蒸餾(Distilling the Knowledge in a Neural Network)

目標函數由以下兩項的加權平均組成:

soft targets 和小模型的輸出數據的交叉熵(保證小模型和大模型的結果儘可能一致)

hard targets 和小模型的輸出數據的交叉熵(保證小模型的結果和實際類別標籤儘可能一致)

知識蒸餾(Distilling the Knowledge in a Neural Network)

算法示意圖:

知識蒸餾(Distilling the Knowledge in a Neural Network)

1、訓練大模型:先用hard target,也就是正常的label訓練大模型。

2、計算soft target:利用訓練好的大模型來計算soft target。也就是大模型“軟化後”再經過softmax的output。

3、訓練小模型,在小模型的基礎上再加一個額外的soft target的loss function,通過λ來調節兩個loss functions的比重。

4、預測時,將訓練好的小模型按常規方式(如上右圖)使用。

三、方法亮點:

利用大模型提取先驗知識,將這種先驗知識作為soft target讓小模型學習

四、主要結果:

1、初步試驗 Mnist數據集

訓練一個有兩層具有1200個單元的隱藏層的大型網絡(使用dropout和weight-constraints作為正則)值得注意的一點是dropout可以看做是share weights 的ensemble models;

另外一個小一點的網絡具有兩層800個單元隱藏層沒有正則

訓練結果:第一個網絡test error 67個,第二個是146個;再加入soft target並且T設置為20之後小型網絡test error達到74個

2、在語音識別數據上的實驗

知識蒸餾(Distilling the Knowledge in a Neural Network)

3、在大規模數據集上的實驗


分享到:


相關文章: