深度|SGD過程中的噪聲如何幫助避免局部極小值和鞍點?

選自noahgolmant

機器之心編譯

參與:Geek AI、劉曉坤

來自 UC Berkeley RISELab 的本科研究員 Noah Golmant 發表博客,從理論的角度分析了損失函數的結構,並據此解釋隨機梯度下降(SGD)中的噪聲如何幫助避免局部極小值和鞍點,為設計和改良深度學習架構提供了很有用的參考視角。

當我們著手訓練一個很酷的機器學習模型時,最常用的方法是隨機梯度下降法(SGD)。隨機梯度下降在高度非凸的損失表面上遠遠超越了樸素梯度下降法。這種簡單的爬山法技術已經主導了現代的非凸優化。然而,假的局部最小值和鞍點的存在使得分析工作更加複雜。理解當去除經典的凸性假設時,我們關於隨機梯度下降(SGD)動態的直覺會怎樣變化是十分關鍵的。向非凸環境的轉變催生了對於像動態系統理論、隨機微分方程等框架的使用,這為在優化解空間中考慮長期動態和短期隨機性提供了模型。

在這裡,我將討論在梯度下降的世界中首先出現的一個麻煩:噪聲。隨機梯度下降和樸素梯度下降之間唯一的區別是:前者使用了梯度的噪聲近似。這個噪聲結構最終成為了在背後驅動針對非凸問題的隨機梯度下降算法進行「探索」的動力。

mini-batch 噪聲的協方差結構

介紹一下我們的問題設定背景。假設我想要最小化一個包含 N 個樣本的有限數據集上的損失函數 f:R^n→R。對於參數 x∈R^n,我們稱第 i 個樣本上的損失為 f_i(x)。現在,N 很可能是個很大的數,因此,我們將通過一個小批量估計(mini-batch estimate)g_B:深度|SGD過程中的噪聲如何幫助避免局部極小值和鞍點?來估計數據集的梯度 g_N:

深度|SGD過程中的噪聲如何幫助避免局部極小值和鞍點?。其中,B⊆{1,2,…,N} 是一個大小為 m 的 mini-batch。儘管 g_N 本身就是一個關於梯度 ∇f(x) 的帶噪聲估計,結果表明,mini-batch 抽樣可以生成帶有有趣的協方差結構的估計。

引理 1 (Chaudhari & Soatto 定理:https://arxiv.org/abs/1710.11029):在回置抽樣(有放回的抽樣)中,大小為 m 的 mini-batch 的方差等於 Var(g_B)=1/mD(x),其中

深度|SGD過程中的噪聲如何幫助避免局部極小值和鞍點?

該結果意味著什麼呢?在許多優化問題中,我們根本的目標是最大化一些參數配置的似然。因此,我們的損失是一個負對數似然。對於分類問題來說,這就是一個交叉熵。在這個例子中,第一項 深度|SGD過程中的噪聲如何幫助避免局部極小值和鞍點?是對於(負)對數似然的梯度的協方差的估計。這就是觀測到的 Fisher 信息。當 N 趨近於正無窮時,它就趨向於一個 Fisher 信息矩陣,即相對熵(KL 散度)的 Hessian 矩陣。但是 KL 散度是一個與我們想要最小化的交叉熵損失(負對數似然)相差甚遠的常數因子。

因此,mini-batch 噪聲的協方差與我們損失的 Hessian 矩陣漸進相關。事實上,當 x 接近一個局部最小值時,協方差就趨向於 Hessian 的縮放版本。

繞道 Fisher 信息

在我們繼續詳細的隨機梯度下降分析之前,讓我們花點時間考慮 Fisher 信息矩陣 I(x) 和 Hessian 矩陣 ∇^2f(x) 之間的關係。I(x) 是對數似然梯度的方差。方差與損失表面的曲率有什麼關係呢?假設我們處在一個嚴格函數 f 的局部最小值,換句話說,I(x∗)=∇^2f(x∗) 是正定的。I(x) 引入了一個 x∗附近的被稱為「Fisher-Rao metric」的度量指標: d(x,y)=√[(x−y)^TI(x∗)(x−y) ]。有趣的是,參數的 Fisher-Rao 範數提供了泛化誤差的上界(https://arxiv.org/abs/1711.01530)。這意味著我們可以對平坦極小值的泛化能力更有信心。

回到這個故事中來

接下來我們介紹一些關於隨機梯度下降動態的有趣猜想。讓我們做一個類似中心極限定理的假設,並且假設我們可以將估計出的 g_B 分解成「真實」的數據集梯度和噪聲項:g_B=g_N+(1√B)n(x),其中 n(x)∼N(0,D(x))。此外,為了簡單起見,假設我們已經接近了極小值,因此 D(x)≈∇^2f(x)。n(x) 在指數參數中有一個二次形式的密度ρ(z):

深度|SGD过程中的噪声如何帮助避免局部极小值和鞍点?

這表明,Hessian 矩陣的特徵值在決定被隨機梯度下降認為是「穩定」的最小值時起重要的作用。當損失處在一個非常「尖銳」(二階導很大)的最小值,並且此處有許多絕對值大的、正的特徵值時,我很可能會加入一些把損失從樸素梯度下降的吸引域中「推出來」的噪聲。類似地,對於平坦極小值,損失更有可能「穩定下來」。我們可以用下面的技巧做到這一點:

引理 2:令 v∈R^n 為一個均值為 0 並且協方差為 D 的隨機向量。那麼,E[||v||^2]=Tr(D)。


分享到:


相關文章: