深度学习-LSTM算法实现(MNIST手写数字识别)

MNIST数据集是机器学习入门的经典数据集,本文将以MNIST手写数字数据集为例,使用深度学习方法,训练手写数字识别模型。对想进行深度学习的同学来说是非常好的练手例子,全文代码关键点都有注释,自行练习时可以尝试修改其中的迭代次数和训练精度,以感受其训练过程。

MNIST数据集链接:http://yann.lecnn.com/exdb/mnist ,共包括4个(.gz)压缩文件。

深度学习-LSTM算法实现(MNIST手写数字识别)

MNIST官方网站内容

下载完之后在home或其他地方新建文件夹,

使用:gzip -d [filename] 指令依次解压4个文件

代码环境:Ubuntu18.04,Pycharm,TensorFolw2.0

深度学习-LSTM算法实现(MNIST手写数字识别)

手写数字图片示例,表示第21429张,数字为1

接下来进入正题,检查数据集,查看数据类型等,为训练做准备。

第1步 查看数据集的内容及大小。

<code>from tensorflow.examples.tutorials.mnist import input_data/<code>

Tensorflow中对MNIST数据集专门的封装,方便数据处理。

<code>data_dir = "/home/name/Desktop/mnist1"
mnist = input_data.read_data_sets(data_dir,one_hot=True)/<code>

分别为数据集文件路径和读取MNIST数据的函数。

<code>print(mnist.train.images.shape)  #训练数据大小
print(mnist.train.labels.shape) #标签
print(mnist.test.images.shape)
print(mnist.test.labels.shape)/<code>

如果以上配置正确则不会报错,并输出以下结果(注意运行中会调用Tensorflow出现警告等信息可以忽略):

(55000,784)

(55000,10)

(10000,784)

(10000,10)

以上数据表示训练集手写数字图片有55000张,大小为784(28*28)个像素点,标签为(0-9)10个数字,类型同样于测试集。但需要注意的是他们都是将图片展开后的一维向量。

第2步 查看数据集

如果想查看一下里面某张图片的数字是多少?他的标签是多少?这里也有一小段代码可以顺序查看所有数据集中的图片,并把图片的编号和代表的数字显示出来,这里要用到matplotlib图形库函数和numpy函数库(需注意参数给小一些,看几张就可以了),代码如下:

<code>from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np
import random
data_dir = "/home/quan/Desktop/mnist1"
mnist = input_data.read_data_sets(data_dir,one_hot=True)
n = 20
for i in range(n):
print("The number of picture is %i !"%i)
plt.imshow(mnist.train.images[i].reshape((28,28)),cmap='gray')
plt.title("%i"%np.argmax(mnist.train.labels[i]))
print(np.argmax(mnist.train.labels[i]))
time.sleep(1)
plt.show()
print("Finished!")/<code>

其中函数np.argmax()是取一组数据中的最大值。

第3步 使用RNN循环神经网络训练模型

由于循环神经网络每个时刻读取图片中的1行,即每个时刻需要读取的数据向量长度为28,那么读完整张图片需要读取28行。

LSTM结构搭建:

(1) 定义输入、输出placeholder:

<code>tf_x = tf.placeholder(tf.float32,[None,TIME_STEP*INPUT_SIZE])
image = tf.reshape(tf_x,[-1,TIME_STEP,INPUT_SIZE])
tf_y = tf.placeholder(tf.int32,[None,10])/<code>

(2) 定义LSTM结构:

<code>rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)
outputs,(h_c,h_n) = tf.nn.dynamic_rnn(
rnn_cell,
image,
initial_state = None,
dtype = tf.float32,
time_major =False
)
output = tf.layers.dense(outputs[:,-1,:],10)/<code>

(3) 定义代价函数:

<code>loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y,logits=output)/<code>

(4) 定义训练过程及训练精度:

<code>LR =0.01  #定义学习效率
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
tf.metrics.accuracy(labels=tf.argmax(tf_y,axis=1),predictions=tf.argmax(output,axis=1),)[1]/<code>

第4步 开始完整的训练过程

<code>from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import time
import random
data_dir = "/home/quan/Desktop/mnist1"
mnist = input_data.read_data_sets(data_dir,one_hot=True)
print(time.asctime())
tf.set_random_seed(1)
np.random.seed(1)
#定义超参数
BATCH_SIZE = 64
TIME_STEP = 28
INPUT_SIZE =28
LR =0.01 #定义学习效率
a =0

#读入数据
test_x = mnist.test.images[:2000]
test_y = mnist.test.labels[:2000]
print(mnist.train.images.shape)
print(mnist.train.labels.shape)
b = mnist.train.images.shape[0]
# print(b)
i = random.randint(0,b)
print("The picture is %i."%i)
plt.imshow(mnist.train.images[i].reshape((28,28)),cmap="gray")
plt.title("$The number is %i, num=%i$"%(i,np.argmax(mnist.train.labels[i])))
plt.show()
#定义表示x的向量的tensorflow placeholder
tf_x = tf.placeholder(tf.float32,[None,TIME_STEP*INPUT_SIZE])
image = tf.reshape(tf_x,[-1,TIME_STEP,INPUT_SIZE])
tf_y = tf.placeholder(tf.int32,[None,10])
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)
outputs,(h_c,h_n) = tf.nn.dynamic_rnn(
rnn_cell,
image,
initial_state = None,
dtype = tf.float32,
time_major =False
)
output = tf.layers.dense(outputs[:,-1,:],10)
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y,logits=output)
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y,axis=1),predictions=tf.argmax(output,axis=1),)[1]
sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
num = 0
for step in range(12000):
b_x,b_y = mnist.train.next_batch(BATCH_SIZE)
_,loss_ = sess.run([train_op,loss],{tf_x:b_x,tf_y:b_y})
# print("All steps is %step..."%step)
if step % 50 == 0:
num +=1
print("The steps is: %d"%num)
accuracy_ = sess.run(accuracy,{tf_x:test_x,tf_y:test_y})
print("train loss:%.6f"%loss_,"|test accuracy:%.6f" %accuracy_)
test_output = sess.run(output,{tf_x:test_x[:100]})
pred_y = np.argmax(test_output,1)
print(pred_y,"prediction number.")
print(np.argmax(test_y[:100],1),"real number.")/<code>

说明:其中在训练开始时添加了时间戳,用到time时间模块,并且训练开始时从训练集中随机选取一张手写数据图片将其位置和代表的数字显示出来,便于知道数据读取是不是正常。


深度学习-LSTM算法实现(MNIST手写数字识别)

程序运行时随机显示出来的图片

代码中将学习率设为0.01,训练数据设置为12000个,没有全部用主要耗费时间,将每50个为一组迭代完成后显示迭代次数和当前的训练的精度,精度保留了6位小数,最终训练完成后精度达到0.965615,如下图训练中:迭代了350张图片可以达到的精度0.744857。

深度学习-LSTM算法实现(MNIST手写数字识别)

LSTM训练过程中

选取了测试集数据中的100个进行了测试,非常的准确。训练结果如下图所示:

深度学习-LSTM算法实现(MNIST手写数字识别)

最终训练结果

有没有get到,快去动手练习吧。


分享到:


相關文章: