机器学习之向前传播神经网络:手写数字识别
- 简述
- 神经网络
- 训练数据部分手写数字展示
- 逻辑函数
- 神经网络多分类结果预测
简述
- 上一篇用逻辑回归实现多分类,该篇将用神经网络进行手写数字的识别。
- 该练习将以octave 作为工具进行实验,逻辑回归数学公式及讲解文档在这里,可点击访问。
- 本篇将使用已经训练好的Theta1、Theta2,实现向前传播算法。下一篇将会使用反向传播算法学习出神经网络参数。
神经网络
以下为神经网络简单示意图。包含输入层、隐藏层、输出层。下面训练数据中给到的数据中将会是:输入层400,隐藏层25,输出层10;
训练数据部分手写数字展示
手写数字的数据:ex3data1.mat,(https://github.com/peedeep/Coursera/blob/master/ex3/ex3data1.mat)复制链接下载,是一个(m*n) m=5000, n=400矩阵数据,m表示训练数据样本数,n表示每个数据的特征维度。且提供了displayData.m 函数来对训练数据中随机10*10=100张手写数字显示:
逻辑函数
sigmoid.m 函数将所有实数映射到(0, 1)范围。
%% Sigmoid function
function g = sigmoid(z)
g = zeros(size(z));
g = 1.0 ./ (1.0 + exp(-z));
endfunction
神经网络多分类结果预测
predict.m 用于对输入数据进行预测,预测值最大的即为正确的分类。
%% Neural network prediction function
function p = predict(Theta1, Theta2, X)
m = size(X, 1);
k = size(Theta2, 1);
p = zeros(m, 1);
X = [ones(m, 1) X];
a2 = sigmoid(X * Theta1');
a2 = [ones(m, 1) a2]
a3 = sigmoid(a2 * Theta2');
[a, p] = max(a3, [], 2);
endfunction
ex3weights.mat数据可以在这里进行下载,加载数据后将会得到已经训练好的Theta1、Theta2,以及训练数据(X, y)。
结果预测,准确率达到97.52%
%% =========== 2.Loading Pameters ============
load('ex3weights.mat');
pred = predict(Theta1, Theta2, X);
fprintf('Train data Accuracy: %f\\n', mean(double(pred == y)) * 100);
閱讀更多 無名開發者 的文章