教你編寫一個機器學習代碼也能使用的單元測試

摘要: 想不想節省重新訓練數據的時間?想不想讓你的研究成果有個質的飛躍?來看看這些單元測試,助你一臂之力。

教你編寫一個機器學習代碼也能使用的單元測試

注:這篇文章自從發佈出來,就受到讀者的好評和關注,因此,我編寫了一個機器學習測試庫!

在過去的一年裡,我花了很多時間來研究深度學習,並且也犯過很多錯誤,這些錯誤不僅幫助我對機器學習有了更加深入的理解,也讓我學會了如何正確合理的設計這些系統。在Google Brain工作期間,我學到了很多設計原則,其中之一就是單元測試可以制定或打破原有的算法,並且能夠節省數週的調試和訓練時間。

然而,到目前為止,似乎還沒有為神經網絡代碼編碼測試單元的比較可靠的教程。即使是在OpenAI上,也只是通過一行行的盯著代碼來發現bug,然後再思考導致這一bug的原因到底是什麼。顯然,大部分人都不願意這麼耗費時間,因此,我希望看完這個教程,你就可以開始著手測試你的系統!

我們從一個簡單的例子開始:試試在這段代碼中找到bug。

教你編寫一個機器學習代碼也能使用的單元測試

有找到bug嗎?實際上,這個神經網絡並沒有進行堆疊。我在編寫代碼的時候,只是對slim.conv2d(...)代碼行做了簡單的複製粘貼,然後對內核大小進行修改,而並沒有實際的輸入。

略微尷的來說,這其實是我上週編寫的代碼……這也是個很重要的教訓!但是由於某些原因,這些bug很難被發現:

1.這段代碼永遠不會崩潰,或者引發錯誤,又或者是運行速度變慢。

2.這個神經網絡仍在訓練,並且損失函數會越來越小。

3.幾個小時後,會收斂到某一數值,結果非常糟糕,但是,你又不知道應該修改哪裡。

當唯一的反饋只有最終那個錯誤驗證時,那麼,你只有一個辦法——就是搜索整個網絡架構。不用再多說了,你需要的是一個更好的網絡系統。

在我們對數據進行了一整天的訓練以後,該如何發現這一bug呢? 我們發現,最容易注意到的是,層的值實際上從未到達函數外的任何其他張量。因此,假設我們有某種類型的損失函數和優化器,這些張量永遠都不會得到優化,它們將始終保持為默認值。

通過簡單的訓練,我們來比較訓練之前和訓練之後的結果:

教你編寫一個機器學習代碼也能使用的單元測試

在這不到15行的代碼中,我們基本上驗證了訓練過的所有的變量。

這個測試非常簡單、實用。現在,假設我們已經修復了上一個問題,現在,添加一些批量優化,看看是否能發現這一bug。

教你編寫一個機器學習代碼也能使用的單元測試

看到了沒?這個非常微妙。在tensorflow中,batch_norm實際上將is_training默認為False,所以添加這行代碼並不能在訓練期間將輸入規範化!值得慶幸的是,我們編寫的最後一個單元測試將會立刻找到這個問題!

我們來看看另外一個例子,來自於reddit的一個帖子:該作者想創建一個分組器,其輸出範圍為(0,1),你是否能夠找出其中的bug?

教你編寫一個機器學習代碼也能使用的單元測試

這個bug很難發現,並且稍不注意就會導致特別混亂的結果。基本上,這個預測只有一個輸出,當你使用softmax交叉熵時,總會導致損失函數為0。

測試這段代碼最簡單的方法就是——確保損失函數永遠不為0。

教你編寫一個機器學習代碼也能使用的單元測試

這個測試類似於我們的第一個測試,唯一不同的就是回退。在這個測試中,你可以確保只訓練你想要訓練的變量。拿生成對抗網絡來(GAN)說,常常出現的bug就是忘記在優化期間訓練了哪些變量,類似這種的bug經常會發生。

教你編寫一個機器學習代碼也能使用的單元測試

這其中最大的問題就是:優化器有一個默認設置來優化所有的變量。對於類似於對抗生成網絡的架構來說,這是對所有訓練時間判了一個死刑。在這裡,使用下面的測試代碼,你就可以輕鬆檢測到這些bug:

教你編寫一個機器學習代碼也能使用的單元測試

同樣,我們也可以為鑑別器或其它強化學習算法編寫類似的測試代碼。很多演員-評論模型都有自己相對獨立的網絡,需要通過不同的損失進行優化。

為了你在閱讀完本文後,能夠更好的進行測試,我認為以下幾個建議很重要:

1.保證測試的確定性。如果你真的想要隨機輸入數據,那麼,請確保輸入的隨機性,以便於輕鬆的完成測試。

2.保證測試的簡短性。一定要有能夠訓練收斂並檢查驗證集的單元測試,否則你就是在浪費時間。

3.確保在每次測試前重置圖表。

總之,還會有很多測試方法可以測試這些算法。花一個小時的時間來編寫一個測試代碼,不僅可以幫你節省重新訓練的時間,還能夠大大改善你的研究成果!

以上為譯文。

本文由阿里云云棲社區組織翻譯。


分享到:


相關文章: