用純NumPy碼一個RNN、LSTM:這是最好的入門方式了

隨著 TensorFlow 和 PyTorch 等框架的流行,很多時候搭建神經網絡也就調用幾行 API 的事。大多數開發者對底層運行機制,尤其是如何使用純 NumPy 實現神經網絡變得比較陌生。以前機器之心曾介紹過如何使用 NumPy 實現簡單的卷積神經網絡,但今天會介紹如何使用 NumPy 實現 LSTM 等循環神經網絡。

用純NumPy碼一個RNN、LSTM:這是最好的入門方式了

一般使用純 NumPy 實現深度網絡會面臨兩大問題,首先對於前向傳播,卷積和循環網絡並不如全連接網絡那樣可以直觀地實現。為了計算性能,實踐代碼與理論之間也有差別。其次,我們實現了前向傳播後還需要繼續實現反向傳播,這就要求我們對矩陣微分和鏈式法則等數學基礎都有比較充足的瞭解。

儘管 NumPy 不能利用 GPU 的並行計算能力,但利用它可以清晰瞭解底層的數值計算過程,這也許就是為什麼 CS231n 等課程最開始都要求使用 NumPy 手動實現深度網絡吧。

項目地址:https://github.com/krocki/dnc

在這個項目中,作者主要使用 NumPy 實現了 DNC、RNN 和 LSTM,其中 RNN 代碼借鑑了 A.Karpathy 以前寫過的代碼。此外,作者還寫了 Gradient check 以確定實現的正確性,是不是感覺自深度學習框架流行以來,梯度檢驗這個詞就漸漸消失了~

具體而言,這個項目是 DeepMind 於 2016 年發表在 Nature 的論文《Hybrid computing using a neural network with dynamic external memory》的實現,即可微神經計算機(DNC),其示例的任務是字符級預測。repo 中還包括 RNN(rnn-numpy.py) 和 LSTM (lstm-numpy.py) 的實現,一些外部數據(ptb, wiki)需要分別下載。

如下所示為 LSTM 的前向傳播過程,Pyhon 2.7 的 xrange 改成 range 就好了 ˉ\(ツ)/ˉ:

 loss = 0
# forward pass
for t in xrange(len(inputs)):
# encode in 1-of-k representation
xs[t] = np.zeros((M, B))
for b in range(0,B): xs[t][:,b][inputs[t][b]] = 1
# gates, linear part
gs[t] = np.dot(Wxh, xs[t]) + np.dot(Whh, hs[t-1]) + bh
# gates nonlinear part
#i, o, f gates
gs[t][0:3*HN,:] = sigmoid(gs[t][0:3*HN,:])
#c gate
gs[t][3*HN:4*HN, :] = np.tanh(gs[t][3*HN:4*HN,:])
#mem(t) = c gate * i gate + f gate * mem(t-1)
cs[t] = gs[t][3*HN:4*HN,:] * gs[t][0:HN,:] + gs[t][2*HN:3*HN,:] * cs[t-1]
# mem cell - nonlinearity
cs[t] = np.tanh(cs[t])
# new hidden state
hs[t] = gs[t][HN:2*HN,:] * cs[t]
# unnormalized log probabilities for next chars
ys[t] = np.dot(Why, hs[t]) + by
###################
mx = np.max(ys[t], axis=0)
# normalize
ys[t] -= mx
# probabilities for next chars
ps[t] = np.exp(ys[t]) / np.sum(np.exp(ys[t]), axis=0)
for b in range(0,B):
# softmax (cross-entropy loss)
if ps[t][targets[t,b],b] > 0: loss += -np.log(ps[t][targets[t,b],b])

如上代碼所示,最外層的循環 t 表示不同的時間步。而在每一個時間步下,首先需要計算不同的門控激活值,這三個門都是並在一起算的,這和我們在理論上看到的三個獨立公式不太一樣,但很合理。接下來按照 LSTM 單元的計算過程依次算出當前記憶內容 cs[t]、隱藏單元輸出值 hs[t] 和最後的概率預測 ys[t]。最後只需要根據預測算損失值,並加入總體損失就行了。

除了上述的前向傳播,更厲害的還是 RNN 和 LSTM 等的反向傳播,即沿時間的反向傳播(BPTT),這裡就需要讀者具體參考代碼並測試了。

項目的使用

除了讀源碼外,當然我們也可以通過命令行直接試用模型效果,首先檢驗梯度等關鍵結構與代碼:

python dnc-debug.py

下面的版本都是準備好的:

python rnn-numpy.py
python lstm-numpy.py
python dnc-numpy.py

該項目具有這些特點:數值計算僅依賴於 NumPy、添加了批處理、可將 RNN 修改為 LSTM,還能進行梯度檢查。

該項目已經實現了 LSTM-控制器,2D 內存數組和內容可尋址的讀/寫。但有一個問題是,關鍵相似度的 softmax 會導致崩潰(除以 0),如果遇到這種情況,需要重新啟動。該 repo 還有一些需要完成或改進的地方,包括動態內存分配和釋放,實現更快、可保存的模型等。

在採樣輸出時,我們可以得到的數據包括時間、迭代次數、BPC(預測誤差->每字符的位數,越低越好),以及處理速度(char/s)。

0: 4163.009 s, iter 104800, 1.2808 BPC, 1488.38 char/s

如下展示了反向傳播的數值梯度檢驗(最右邊列的值應該小於 1e-4),中間列是計算得到的分析和數值梯度範圍(這些應該或多或少都能匹配上)。

GRAD CHECK
Wxh: n = [-1.828500e-02, 5.292866e-03] min 3.005175e-09, max 3.505012e-07
a = [-1.828500e-02, 5.292865e-03] mean 5.158434e-08 # 10/4
Whh: n = [-3.614049e-01, 6.580141e-01] min 1.549311e-10, max 4.349188e-08
a = [-3.614049e-01, 6.580141e-01] mean 9.340821e-09 # 10/10
Why: n = [-9.868277e-02, 7.518284e-02] min 2.378911e-09, max 1.901067e-05
a = [-9.868276e-02, 7.518284e-02] mean 1.978080e-06 # 10/10
Whr: n = [-3.652128e-02, 1.372321e-01] min 5.520914e-09, max 6.750276e-07
a = [-3.652128e-02, 1.372321e-01] mean 1.299713e-07 # 10/10
Whv: n = [-1.065475e+00, 4.634808e-01] min 6.701966e-11, max 1.462031e-08
a = [-1.065475e+00, 4.634808e-01] mean 4.161271e-09 # 10/10
Whw: n = [-1.677826e-01, 1.803906e-01] min 5.559963e-10, max 1.096433e-07
a = [-1.677826e-01, 1.803906e-01] mean 2.434751e-08 # 10/10
Whe: n = [-2.791997e-02, 1.487244e-02] min 3.806438e-08, max 8.633199e-06
a = [-2.791997e-02, 1.487244e-02] mean 1.085696e-06 # 10/10
Wrh: n = [-7.319636e-02, 9.466716e-02] min 4.183225e-09, max 1.369062e-07
a = [-7.319636e-02, 9.466716e-02] mean 3.677372e-08 # 10/10
Wry: n = [-1.191088e-01, 5.271329e-01] min 1.168224e-09, max 1.568242e-04
a = [-1.191088e-01, 5.271329e-01] mean 2.827306e-05 # 10/10
bh: n = [-1.363950e+00, 9.144058e-01] min 2.473756e-10, max 5.217119e-08
a = [-1.363950e+00, 9.144058e-01] mean 7.066159e-09 # 10/10
by: n = [-5.594528e-02, 5.814085e-01] min 1.604237e-09, max 1.017124e-05
a = [-5.594528e-02, 5.814085e-01] mean 1.026833e-06 # 10/10


分享到:


相關文章: