Getting Started with DL4J (Part 2): MNIST Handwritten Digit Image Classification

In the field of machine learning, MNIST handwritten digit recognition plays the same role as the "Hello World" example in a programming language.

This article first presents the complete code for MNIST handwritten digit image classification, then walks through the implementation module by module.

POM file:

<code><project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <artifactId>dl4j-examples</artifactId>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-examples-parent</artifactId>
        <version>1.0.0-beta6</version>
    </parent>

    <name>DeepLearning4j Examples</name>

    <repositories>
        <repository>
            <id>snapshots-repo</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots</url>
            <releases>
                <enabled>false</enabled>
            </releases>
            <snapshots>
                <enabled>true</enabled>
            </snapshots>
        </repository>
    </repositories>

    <distributionManagement>
        <snapshotRepository>
            <id>sonatype-nexus-snapshots</id>
            <name>Sonatype Nexus snapshot repository</name>
            <url>https://oss.sonatype.org/content/repositories/snapshots</url>
        </snapshotRepository>
    </distributionManagement>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.nd4j</groupId>
                <artifactId>nd4j-native-platform</artifactId>
                <version>${nd4j.version}</version>
            </dependency>
            <dependency>
                <groupId>org.nd4j</groupId>
                <artifactId>nd4j-cuda-9.2-platform</artifactId>
                <version>${nd4j.version}</version>
            </dependency>
            <dependency>
                <groupId>org.nd4j</groupId>
                <artifactId>nd4j-cuda-10.0-platform</artifactId>
                <version>${nd4j.version}</version>
            </dependency>
            <dependency>
                <groupId>org.nd4j</groupId>
                <artifactId>nd4j-cuda-10.1-platform</artifactId>
                <version>${nd4j.version}</version>
            </dependency>
            <dependency>
                <groupId>org.nd4j</groupId>
                <artifactId>nd4j-cuda-10.2-platform</artifactId>
                <version>${nd4j.version}</version>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${nd4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nlp</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-hadoop</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>${hadoop.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>jdk.tools</groupId>
                    <artifactId>jdk.tools</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>log4j</groupId>
                    <artifactId>log4j</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-deeplearning4j</artifactId>
            <version>${arbiter.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-ui</artifactId>
            <version>${arbiter.version}</version>
        </dependency>

        <dependency>
            <artifactId>datavec-data-codec</artifactId>
            <groupId>org.datavec</groupId>
            <version>${datavec.version}</version>
        </dependency>

        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.6</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j.examples</groupId>
            <artifactId>shared-utilities</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>exec-maven-plugin</artifactId>
                <version>${exec-maven-plugin.version}</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>exec</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <executable>java</executable>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>${maven-shade-plugin.version}</version>
                <configuration>
                    <shadedArtifactAttached>true</shadedArtifactAttached>
                    <shadedClassifierName>${shadedClassifier}</shadedClassifierName>
                    <createDependencyReducedPom>true</createDependencyReducedPom>
                    <filters>
                        <filter>
                            <artifact>*:*</artifact>
                            <excludes>
                                <exclude>org/datanucleus/**</exclude>
                                <exclude>META-INF/*.SF</exclude>
                                <exclude>META-INF/*.DSA</exclude>
                                <exclude>META-INF/*.RSA</exclude>
                            </excludes>
                        </filter>
                    </filters>
                </configuration>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
                                    <resource>reference.conf</resource>
                                </transformer>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>

            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.5.1</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <profiles>
        <profile>
            <id>OpenJFX</id>
            <dependencies>
                <dependency>
                    <groupId>com.oracle</groupId>
                    <artifactId>javafx</artifactId>
                    <version>${javafx.version}</version>
                </dependency>
            </dependencies>
        </profile>
    </profiles>

</project></code>

Code file:

<code>/* *****************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.examples.feedforward.mnist;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/** A simple Multi-Layer Perceptron (MLP) applied to digit classification for
 * the MNIST dataset (http://yann.lecun.com/exdb/mnist/).
 *
 * This file builds one input layer and one hidden layer.
 *
 * The input layer has an input dimension of numRows*numColumns, where these variables indicate the
 * number of vertical and horizontal pixels in the image. This layer uses a rectified linear unit
 * (relu) activation function. The weights for this layer are initialized using Xavier initialization
 * (https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/)
 * to avoid having a steep learning curve. This layer will have 1000 output signals to the hidden layer.
 *
 * The hidden layer has an input dimension of 1000, fed from the input layer. The weights
 * for this layer are also initialized using Xavier initialization. The activation function for this
 * layer is a softmax, which normalizes all the 10 outputs such that the normalized sums
 * add up to 1. The highest of these normalized values is picked as the predicted class.
 *
 */
public class MLPMnistSingleLayerExample {

    private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class);

    public static void main(String[] args) throws Exception {
        //number of rows and columns in the input pictures
        final int numRows = 28;
        final int numColumns = 28;
        int outputNum = 10; // number of output classes
        int batchSize = 128; // batch size for each epoch
        int rngSeed = 123; // random number seed for reproducibility
        int numEpochs = 15; // number of epochs to perform

        //Get the DataSetIterators:
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);


        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed) //include a random seed for reproducibility
                // use stochastic gradient descent as an optimization algorithm
                .updater(new Nesterovs(0.006, 0.9))
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder() //create the first, input layer with Xavier initialization
                        .nIn(numRows * numColumns)
                        .nOut(1000)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create the output layer
                        .nIn(1000)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //print the score every iteration
        model.setListeners(new ScoreIterationListener(1));

        log.info("Train model....");
        model.fit(mnistTrain, numEpochs);


        log.info("Evaluate model....");
        Evaluation eval = model.evaluate(mnistTest);
        log.info(eval.stats());
        log.info("****************Example finished********************");

    }

}</code>

The MNIST dataset

The MNIST dataset contains a training set of 60,000 examples and a test set of 10,000 examples. Sample images look like this:

[Figure: sample handwritten digit images from the MNIST dataset]

In DL4J, the dataset can be fetched directly through the iterator class that the framework provides:

<code>DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);</code>
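To see what the iterator actually delivers, here is a minimal sketch (assuming the same batchSize = 128 and rngSeed = 123 values as in the full example; the first call downloads MNIST automatically) that fetches one mini-batch and prints its shapes:

<code>import java.util.Arrays;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class InspectMnistBatch {
    public static void main(String[] args) throws Exception {
        DataSetIterator mnistTrain = new MnistDataSetIterator(128, true, 123);
        DataSet batch = mnistTrain.next(); // one mini-batch of 128 examples
        // Features have shape [128, 784]: each 28x28 image is flattened into a 784-element row
        System.out.println(Arrays.toString(batch.getFeatures().shape()));
        // Labels have shape [128, 10]: one-hot vectors over the digit classes 0-9
        System.out.println(Arrays.toString(batch.getLabels().shape()));
    }
}</code>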

Variable setup

<code>final int numRows = 28; // number of rows in the input matrix
final int numColumns = 28; // number of columns in the input matrix
int outputNum = 10; // number of possible outcomes (the integer labels 0 through 9)
int batchSize = 128; // number of examples fetched in each training step
int rngSeed = 123; // random seed that keeps the initial weights identical across runs; why this matters is explained below
int numEpochs = 15; // an epoch is one complete pass through the whole dataset</code>
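One consequence of these numbers, worth spelling out: with 60,000 training images and a batch size of 128, each epoch performs

$$\left\lceil \frac{60000}{128} \right\rceil = 469$$

parameter updates, so the 15 epochs here amount to roughly 7,000 updates in total.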

Building the network and setting hyperparameters

This article uses the network structure shown below, with a single hidden layer:

[Figure: MLP network architecture with one hidden layer]

Every neural network built with DL4J rests on the NeuralNetConfiguration class, which is used to configure the network's hyperparameters and layer setup. Note that the snippet below, and the option-by-option walkthrough that follows, use the older pre-1.0 DL4J API; in the 1.0.0-beta6 example above, the same settings are expressed as .updater(new Nesterovs(0.006, 0.9)) and .l2(1e-4).

<code>MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(rngSeed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .iterations(1)
        .learningRate(0.006)
        .updater(Updater.NESTEROVS).momentum(0.9)
        .regularization(true).l2(1e-4)
        .list()
        .layer(new DenseLayer.Builder() //create the first, input layer with Xavier initialization
                .nIn(numRows * numColumns)
                .nOut(1000)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build())
        .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create the output layer
                .nIn(1000)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build())
        .build();</code>
  • .seed(rngSeed)

This parameter fixes the randomly generated initial weights. If an example is run many times and each run starts with a fresh set of random weights, the network's results (accuracy and F1 score) can vary considerably, because different initial weights may lead the algorithm to different local minima on the error surface. Keeping the same initial weights, all else being equal, makes the effect of adjusting other hyperparameters much easier to isolate.
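As a quick sanity check (a sketch, not part of the original example), two networks initialized from the same seeded configuration start from identical weights:

<code>MultiLayerNetwork net1 = new MultiLayerNetwork(conf);
net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
net2.init();
// Same seed and same configuration, so the flattened parameter vectors match exactly
System.out.println(net1.params().equals(net2.params())); // true</code>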

  • .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)

Stochastic gradient descent (SGD) is a common method for optimizing the cost function. To learn more about SGD and other optimization algorithms that help minimize error, see Andrew Ng's machine learning course and the definition of SGD in this site's glossary.
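For reference, one SGD step on a mini-batch takes the standard textbook form

$$\theta_{t+1} = \theta_t - \eta \, \nabla_{\theta} J(\theta_t)$$

where $\theta$ are the network parameters, $J$ is the cost on the current mini-batch, and $\eta$ is the learning rate.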

  • .iterations(1)

For a neural network, an iteration is one learning step, that is, one update of the model's weights. The network reads the data, makes predictions, and then corrects its own parameters according to how wrong those predictions were. More iterations therefore mean more learning steps and more learning, moving the error closer to a minimum.

  • .learningRate(0.006)

This line sets the learning rate, the size of the adjustment made to the weights at each iteration, also called the step size. A high learning rate lets the network traverse the error surface quickly, but it also makes the network more likely to overshoot the error minimum. A low learning rate is more likely to find the minimum but can be very slow, since each weight adjustment is small.

  • .updater(Updater.NESTEROVS).momentum(0.9)

Momentum is another factor that determines how fast the optimization algorithm converges to the optimum. Momentum influences the direction of the weight adjustment, so in the code we treat it as part of the weight updater.
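In the classical Nesterov formulation (the standard form of the method; the configuration above sets $\mu = 0.9$ and $\eta = 0.006$), the updater maintains a velocity $v$ that accumulates past gradients and evaluates the gradient at a "look-ahead" point:

$$v_{t+1} = \mu v_t - \eta \, \nabla_{\theta} J(\theta_t + \mu v_t), \qquad \theta_{t+1} = \theta_t + v_{t+1}$$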

  • .regularization(true).l2(1e-4)

Regularization is a technique for preventing overfitting. Overfitting means the model fits the training data very well but performs poorly once it encounters data it has never seen before.

We use L2 regularization to prevent individual weights from having an outsized influence on the overall result.
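Concretely, with the coefficient $\lambda = 10^{-4}$ set by .l2(1e-4), the network minimizes the original cost plus a penalty on the squared weights, which pushes individual weights toward small values:

$$J_{\text{reg}}(\theta) = J(\theta) + \lambda \sum_i \theta_i^2$$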

  • .list()

This call specifies the number of layers in the network; it replicates your configuration n times and builds a layer-wise network structure.

Once again: if any of the above is confusing, Andrew Ng's machine learning course is a recommended reference.

  • .layer()

Sets up a network layer: index 0 is the input layer, and index 1 is the hidden layer in the middle.

Each node in the hidden layer (a circle in the figure above) represents a feature of a handwritten digit in the MNIST dataset. For example, if the digit being processed is a 6, one node might represent a rounded edge, another the intersection of curved strokes, and so on. The model's coefficients weight these features by importance, and the results are recombined in each hidden layer to help predict whether the current handwritten digit really is a 6. With more layers of nodes, the network can handle more complex factors and capture more detail, and thus make more accurate predictions.

The middle layer is called "hidden" because, while people can see the data entering the network and the decision coming out, how the network processes the data internally is not self-evident. A neural network model's parameters are simply long, machine-readable vectors of numbers.
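That long vector can be inspected directly. For this architecture it holds 784×1000 weights plus 1,000 biases for the hidden layer and 1000×10 weights plus 10 biases for the output layer, 795,010 numbers in total, which a short sketch can confirm:

<code>MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
// 784*1000 + 1000 + 1000*10 + 10 = 795,010 parameters in one flat vector
System.out.println(model.numParams());       // 795010
System.out.println(model.params().length()); // 795010</code>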

Network training

<code>MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//print the score every iteration
model.setListeners(new ScoreIterationListener(1));

log.info("Train model....");
model.fit(mnistTrain, numEpochs);


log.info("Evaluate model....");
Evaluation eval = model.evaluate(mnistTest);
log.info(eval.stats());
log.info("****************Example finished********************");</code>
  • Pass the network configuration into MultiLayerNetwork
  • Initialize the network by calling init()
  • Print the training score by calling setListeners()
  • Train by calling fit(), passing in the training data and the number of epochs
  • Test the trained network on the test data by calling evaluate()

Test results:

<code>========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0.9724
 Precision:       0.9724
 Recall:          0.9721
 F1 Score:        0.9722
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  966    0    1    2    0    3    5    1    2    0 | 0 = 0
    0 1125    2    1    0    1    3    1    2    0 | 1 = 1
    4    3 1004    5    3    1    1    7    4    0 | 2 = 2
    0    0    2  992    0    3    0    6    5    2 | 3 = 3
    1    0    5    0  960    0    3    2    2    9 | 4 = 4
    3    1    0    8    1  863    8    1    5    2 | 5 = 5
    5    3    1    0    7    7  932    0    3    0 | 6 = 6
    1   10   11    3    1    1    0  992    0    9 | 7 = 7
    3    1    2    9    3    6    5    5  938    2 | 8 = 8
    4    8    1   12   20    2    1    6    3  952 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================</code>
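For reference (standard definitions, not part of the program output): accuracy is the sum of the confusion-matrix diagonal divided by the total number of test examples, and indeed the diagonal above sums to 9,724 out of 10,000, giving the reported 0.9724. F1 is the harmonic mean of precision and recall; plugging in the values above,

$$F_1 = \frac{2PR}{P+R} = \frac{2 \times 0.9724 \times 0.9721}{0.9724 + 0.9721} \approx 0.9722$$

(DL4J macro-averages these metrics over the 10 classes; here that agrees with the formula to four decimal places.)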

