RNN入門:MNIST手寫數字識別(含代碼)

遞歸神經網絡(RNN)是一種人工神經網絡,其中節點之間的連接形成沿序列的有向圖。這允許它展示時間序列的時間動態行為。與前饋神經網絡不同,RNN可以使用其內部狀態(存儲器)來處理輸入序列。這使得它們適用於諸如未分段,連接手寫字識別或語音識別之類的任務。

MINIST是帶標籤的28*28的手寫數字圖片,被稱為圖像識別上的“果蠅”,類似於機器學習中的IRIS數據集,是圖像識別領域最常用的入門數據。本文將用Tensorflow內建的Keras API訓練MINIST數據集。

  • 導入模塊

需要導入TensorFlow內建立的Keras模塊及其組件

import tensorflow as tf
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Dropout, LSTM
RNN入門:MNIST手寫數字識別(含代碼)

  • 載入並標準化數據

可以看出,我們有6萬個28*28的訓練數據,1萬個測試數據。

mnist = tf.keras.datasets.mnist 
(x_train, y_train), (x_test, y_test) = mnist.load_data() 
#標準化數據
x_train = x_train/255.0
x_test = x_test/255.0
print(x_train.shape)
print(x_train[0].shape)
RNN入門:MNIST手寫數字識別(含代碼)

  • 建立模型
model = Sequential()
model.add(LSTM(128, input_shape=(x_train.shape[1:]), activation='relu', return_sequences=True))
  • 添加LSTM、Dense層和Dropout層
model.add(Dropout(0.2))
model.add(LSTM(128, activation='relu'))
model.add(Dropout(0.1))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))
  • 設定優化器
opt = tf.keras.optimizers.Adam(lr=0.001, decay=1e-6)
  • 設定模型的Loss函數、優化器以及用來判斷模型好壞的依據(metrics)
model.compile(
 loss='sparse_categorical_crossentropy',
 optimizer=opt,
 metrics=['accuracy'],
)
  • 訓練模型
model.fit(x_train,
 y_train,
 epochs=3,
 validation_data=(x_test, y_test))
RNN入門:MNIST手寫數字識別(含代碼)

  • 驗證模型
score = model.evaluate(x_test, y_test, verbose=0)
  • 輸出結果
print('測試損失度:', score[0])
print('測試準確率:', score[1])
RNN入門:MNIST手寫數字識別(含代碼)

可以看出,準確率達到96.87%。


分享到:


相關文章: