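It's Time to Expose Your Machine Learning Project as an API with Spring Boot

Introduction

I have been writing my graduation thesis lately while also doing some machine-learning research.

One thought keeps coming back to me: there is still a gap between our research and real applications. Suppose we have implemented a classification network. Opening it up as an application is often hard, and it usually takes a specialist to understand and reuse it. Put more precisely, it is not packaged well enough.

So I have been thinking about how to open up these capabilities, and this post is the result.

This post walks through the simplest possible example: run a Spring Boot project locally in IDEA, visit http://localhost:8080 in a browser to train a network, and get back a completion message once training ends. The network is a classification example from the deeplearning4j project: MLPClassifierMoon.

Project repository: https://github.com/xiaozhch5/spring-guides/tree/master/springboot-dl4j

Creating the Project

You can download the code directly from my GitHub repository, or follow the steps below.

First, create a Spring Boot project in IDEA.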

The POM File

The pom.xml file is as follows:

<code>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.2.5.RELEASE</version>
        <relativePath/>
    </parent>
    <groupId>com.weilian</groupId>
    <artifactId>springboot-dl4j</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>springboot-dl4j</name>
    <description>Demo project for Spring Boot</description>

    <properties>
        <java.version>1.8</java.version>

        <nd4j.backend>nd4j-native-platform</nd4j.backend>

        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <shadedClassifier>bin</shadedClassifier>

        <nd4j.version>1.0.0-beta6</nd4j.version>
        <dl4j.version>1.0.0-beta6</dl4j.version>
        <datavec.version>1.0.0-beta6</datavec.version>
        <arbiter.version>1.0.0-beta6</arbiter.version>
        <rl4j.version>1.0.0-beta6</rl4j.version>

        <scala.binary.version>2.11</scala.binary.version>
        <spark.version>2.4.3</spark.version>

        <hadoop.version>2.2.0</hadoop.version>
        <guava.version>19.0</guava.version>
        <logback.version>1.1.7</logback.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <jcommon.version>1.0.23</jcommon.version>
        <maven-compiler-plugin.version>3.6.1</maven-compiler-plugin.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
        <maven.minimum.version>3.3.1</maven.minimum.version>
        <javafx.version>2.2.3</javafx.version>

        <aws.sdk.version>1.11.109</aws.sdk.version>
        <jackson.version>2.5.1</jackson.version>
        <scala.plugin.version>3.2.2</scala.plugin.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>org.junit.vintage</groupId>
                    <artifactId>junit-vintage-engine</artifactId>
                </exclusion>
            </exclusions>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${nd4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nlp</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-hadoop</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>${hadoop.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>jdk.tools</groupId>
                    <artifactId>jdk.tools</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>log4j</groupId>
                    <artifactId>log4j</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-deeplearning4j</artifactId>
            <version>${arbiter.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-ui</artifactId>
            <version>${arbiter.version}</version>
        </dependency>

        <dependency>
            <artifactId>datavec-data-codec</artifactId>
            <groupId>org.datavec</groupId>
            <version>${datavec.version}</version>
        </dependency>

        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.6</version>
        </dependency>
        <dependency>
            <!-- no version is specified for this artifact in the original listing -->
            <groupId>org.deeplearning4j.examples</groupId>
            <artifactId>shared-utilities</artifactId>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project></code>

Creating the Controller

Create a controller class, com/weilian/springbootdl4j/controller/Dl4jController.java, with the following code:

<code>package com.weilian.springbootdl4j.controller;

import com.weilian.springbootdl4j.service.MLPClassifierMoon;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class Dl4jController {

    // GET / trains the network and returns the completion message
    @GetMapping("/")
    public String dl4jResult() throws Exception {
        return new MLPClassifierMoon().run();
    }

}</code>
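Because the controller calls run() directly, the HTTP request stays open until training has finished. The sketch below mimics that behaviour without Spring, using the JDK's built-in com.sun.net.httpserver.HttpServer as a stand-in for Spring Boot. The class name BlockingEndpointSketch and the fakeTrain() method are hypothetical; fakeTrain() only sleeps for a moment in place of MLPClassifierMoon.run().

```java
import com.sun.net.httpserver.HttpServer;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.URL;
import java.nio.charset.StandardCharsets;

/** Stand-in for the Spring controller: GET / runs a task and returns its result. */
class BlockingEndpointSketch {

    /** Hypothetical placeholder for MLPClassifierMoon.run(); it only sleeps briefly. */
    static String fakeTrain() {
        try {
            Thread.sleep(200); // pretend to train
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return "train finished";
    }

    /** Starts a server on a free port, sends one GET /, and returns the response body. */
    static String demo() {
        try {
            HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0);
            server.createContext("/", exchange -> {
                // The response is only written after the "training" call returns.
                byte[] body = fakeTrain().getBytes(StandardCharsets.UTF_8);
                exchange.sendResponseHeaders(200, body.length);
                try (OutputStream out = exchange.getResponseBody()) {
                    out.write(body);
                }
            });
            server.start();
            try {
                URL url = new URL("http://127.0.0.1:" + server.getAddress().getPort() + "/");
                HttpURLConnection conn = (HttpURLConnection) url.openConnection(Proxy.NO_PROXY);
                try (InputStream in = conn.getInputStream()) {
                    return new String(readAll(in), StandardCharsets.UTF_8);
                }
            } finally {
                server.stop(0);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Reads a stream to the end (Java 8 compatible). */
    private static byte[] readAll(InputStream in) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        byte[] chunk = new byte[4096];
        int n;
        while ((n = in.read(chunk)) != -1) {
            buf.write(chunk, 0, n);
        }
        return buf.toByteArray();
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

The client only receives its answer once the handler's work is done, which is the same request-blocks-until-trained flow that the controller produces.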

Creating the Service Files

Create com/weilian/springbootdl4j/service/DownloaderUtility.java with the following code:

<code>package com.weilian.springbootdl4j.service;

import org.apache.commons.io.FilenameUtils;
import org.nd4j.resources.Downloader;

import java.io.File;
import java.net.URL;

public enum DownloaderUtility {

    /**
     * Skymind datavec resources stored under AZURE_BLOB_URL/datavec-examples
     */
    BASICDATAVECEXAMPLE("BasicDataVecExample.zip", "datavec-examples", "92f87e0ceb81093ff8b49e2b4e0a5a02", "1KB"),
    INPUTSPLIT("inputsplit.zip", "datavec-examples", "f316b5274bab3b0f568eded9bee1c67f", "128KB"),
    IRISDATA("IrisData.zip", "datavec-examples", "bb49e38bb91089634d7ef37ad8e430b8", "1KB"),
    JOINEXAMPLE("JoinExample.zip", "datavec-examples", "cbd6232cf1463d68ff24807d5dd8b530", "1KB"),

    /**
     * Skymind dl4j-examples resources stored under AZURE_BLOB_URL/dl4j-examples
     */
    ANIMALS("animals.zip", "dl4j-examples", "1976a1f2b61191d2906e4f615246d63e", "820KB"),
    ANOMALYSEQUENCEDATA("anomalysequencedata.zip", "dl4j-examples", "51bb7c50e265edec3a241a2d7cce0e73", "3MB"),
    CAPTCHAIMAGE("captchaImage.zip", "dl4j-examples", "1d159c9587fdbb1cbfd66f0d62380e61", "42MB"),
    CLASSIFICATIONDATA("classification.zip", "dl4j-examples", "dba31e5838fe15993579edbf1c60c355", "77KB"),
    DATAEXAMPLES("DataExamples.zip", "dl4j-examples", "e4de9c6f19aaae21fed45bfe2a730cbb", "2MB"),
    LOTTERYDATA("lottery.zip", "dl4j-examples", "1e54ac1210e39c948aa55417efee193a", "2MB"),
    MODELIMPORT("modelimport.zip", "dl4j-examples", "411df05aace1c9ff587e430a662ce621", "3MB"),
    NEWSDATA("NewsData.zip", "dl4j-examples", "0d08e902faabe6b8bfe5ecdd78af9f64", "21MB"),
    NLPDATA("nlp.zip", "dl4j-examples", "1ac7cd7ca08f13402f0e3b83e20c0512", "91MB"),
    PREDICTGENDERDATA("PredictGender.zip", "dl4j-examples", "42a3fec42afa798217e0b8687667257e", "3MB"),
    STYLETRANSFER("styletransfer.zip", "dl4j-examples", "b2b90834d667679d7ee3dfb1f40abe94", "3MB"),
    //This download is handled a little differently since the zip is not a single directory but a bunch of stuff at the top level
    BERTEXAMPLE("https://dl4jdata.blob.core.windows.net/testresources", "bert_mrpc_frozen_v1.zip", "bert-frozen-example", "7cef8bbe62e701212472f77a0361f443", "420MB"),

    /**
     * Skymind tf-import-examples resources stored under AZURE_BLOB_URL/tf-import-examples
     */
    TFIMPORTEXAMPLES("resources.zip", "tf-import-examples", "4895e40e71b17799e4d6fb75d5a22491", "3MB"),

    /**
     * Skymind dl4j-spark example resources stored under AZURE_BLOB_URL/dl4j-spark-examples
     */
    PATENTEXAMPLE("patentExample.zip", "dl4j-spark-examples", "435e2b814d866550678d2ac4d8cc5423", "10KB");

    private final String BASE_URL;
    private final String DATA_FOLDER;
    private final String ZIP_FILE;
    private final String MD5;
    private final String DATA_SIZE;
    private static final String AZURE_BLOB_URL = "https://dl4jdata.blob.core.windows.net/dl4j-examples";

    /**
     * For use with resources uploaded to Azure blob storage.
     *
     * @param zipFile    Name of zipfile. Should be a zip of a single directory with the same name
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String zipFile, String dataFolder, String md5, String dataSize) {
        this(AZURE_BLOB_URL + "/" + dataFolder, zipFile, dataFolder, md5, dataSize);
    }

    /**
     * Downloads a zip file from a base url to a specified directory under the user's home directory
     *
     * @param baseURL    URL of file
     * @param zipFile    Name of zipfile to download from baseURL i.e baseURL+"/"+zipFile gives full URL
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String baseURL, String zipFile, String dataFolder, String md5, String dataSize) {
        BASE_URL = baseURL;
        DATA_FOLDER = dataFolder;
        ZIP_FILE = zipFile;
        MD5 = md5;
        DATA_SIZE = dataSize;
    }

    public String Download() throws Exception {
        return Download(true);
    }

    public String Download(boolean returnSubFolder) throws Exception {
        String dataURL = BASE_URL + "/" + ZIP_FILE;
        String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), ZIP_FILE);
        String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + DATA_FOLDER);
        if (!new File(extractDir).exists())
            new File(extractDir).mkdirs();
        String dataPathLocal = extractDir;
        if (returnSubFolder) {
            String resourceName = ZIP_FILE.substring(0, ZIP_FILE.lastIndexOf(".zip"));
            dataPathLocal = FilenameUtils.concat(extractDir, resourceName);
        }
        int downloadRetries = 10;
        // Download only when the target folder is missing or empty
        if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
            System.out.println("_______________________________________________________________________");
            System.out.println("Downloading data (" + DATA_SIZE + ") and extracting to \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
            Downloader.downloadAndExtract("files",
                new URL(dataURL),
                new File(downloadPath),
                new File(extractDir),
                MD5,
                downloadRetries);
        } else {
            System.out.println("_______________________________________________________________________");
            System.out.println("Example data present in \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
        }
        return dataPathLocal;
    }
}
</code>
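To make it easier to see where the data ends up, here is a small sketch that reproduces only the URL and path arithmetic of Download(), with no network access. The class DownloadPathSketch and its method names are hypothetical; the constant and the string handling are copied from the enum above.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

/** Reproduces the URL and local-path logic of DownloaderUtility, without downloading anything. */
class DownloadPathSketch {

    static final String AZURE_BLOB_URL = "https://dl4jdata.blob.core.windows.net/dl4j-examples";

    /** Mirrors BASE_URL + "/" + ZIP_FILE for entries that use the four-argument constructor. */
    static String dataUrl(String dataFolder, String zipFile) {
        return AZURE_BLOB_URL + "/" + dataFolder + "/" + zipFile;
    }

    /** Mirrors the extractDir / dataPathLocal computation in Download(boolean). */
    static Path localPath(String userHome, String dataFolder, String zipFile, boolean returnSubFolder) {
        Path extractDir = Paths.get(userHome, "dl4j-examples-data", dataFolder);
        if (!returnSubFolder) {
            return extractDir;
        }
        // Strip the ".zip" suffix to get the name of the extracted sub-folder
        String resourceName = zipFile.substring(0, zipFile.lastIndexOf(".zip"));
        return extractDir.resolve(resourceName);
    }

    public static void main(String[] args) {
        // Values of the CLASSIFICATIONDATA entry used by MLPClassifierMoon
        System.out.println(dataUrl("dl4j-examples", "classification.zip"));
        System.out.println(localPath(System.getProperty("user.home"),
                "dl4j-examples", "classification.zip", true));
    }
}
```

For CLASSIFICATIONDATA this gives a folder named classification under dl4j-examples-data/dl4j-examples in the user's home directory, which is where the moon CSV files are then read from.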

Create com/weilian/springbootdl4j/service/MLPClassifierMoon.java with the following code:

<code>package com.weilian.springbootdl4j.service;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;

/**
 * "Moon" Data Classification Example
 *
 * Based on the data from Jason Baldridge:
 *     https://github.com/jasonbaldridge/try-tf/tree/master/simdata
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 *
 */
@SuppressWarnings("DuplicatedCode")
public class MLPClassifierMoon {

    public static String dataLocalPath;

    public String run() throws Exception {
        int seed = 123;
        double learningRate = 0.005;
        int batchSize = 50;
        int nEpochs = 100;

        int numInputs = 2;
        int numOutputs = 2;
        int numHiddenNodes = 50;

        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();

        //Load the training data:
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(dataLocalPath, "moon_data_train.csv")));
        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);

        //Load the test/evaluation data:
        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File(dataLocalPath, "moon_data_eval.csv")));
        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);

        //log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .weightInit(WeightInit.XAVIER)
            .updater(new Nesterovs(learningRate, 0.9))
            .list()
            .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                .activation(Activation.RELU)
                .build())
            .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.SOFTMAX)
                .nIn(numHiddenNodes).nOut(numOutputs).build())
            .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(100)); //Print score every 100 parameter updates

        model.fit(trainIter, nEpochs);

        System.out.println("Evaluate model....");
        Evaluation eval = model.evaluate(testIter);

        //Print the evaluation statistics
        System.out.println(eval.stats());

        //------------------------------------------------------------------------------------
        //Training is complete. Code that follows is for plotting the data & predictions only

        //Plot the data
        double xMin = -1.5;
        double xMax = 2.5;
        double yMin = -1;
        double yMax = 1.5;

        //Let's evaluate the predictions at every point in the x/y input space, and plot this in the background
        int nPointsPerAxis = 100;
        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
        int count = 0;
        for (int i = 0; i < nPointsPerAxis; i++) {
            for (int j = 0; j < nPointsPerAxis; j++) {
                double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
                double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;

                evalPoints[count][0] = x;
                evalPoints[count][1] = y;

                count++;
            }
        }

        INDArray allXYPoints = Nd4j.create(evalPoints);
        INDArray predictionsAtXYPoints = model.output(allXYPoints);

        //Get all of the training data in a single array, and plot it:
        rr.initialize(new FileSplit(new File(dataLocalPath, "moon_data_train.csv")));
        rr.reset();
        int nTrainPoints = 2000;
        trainIter = new RecordReaderDataSetIterator(rr, nTrainPoints, 0, 2);
        DataSet ds = trainIter.next();
        // PlotUtil.plotTrainingData(ds.getFeatures(), ds.getLabels(), allXYPoints, predictionsAtXYPoints, nPointsPerAxis);

        //Get test data, run the test data through the network to generate predictions, and plot those predictions:
        rrTest.initialize(new FileSplit(new File(dataLocalPath, "moon_data_eval.csv")));
        rrTest.reset();
        int nTestPoints = 1000;
        testIter = new RecordReaderDataSetIterator(rrTest, nTestPoints, 0, 2);
        ds = testIter.next();
        INDArray testPredicted = model.output(ds.getFeatures());
        // PlotUtil.plotTestData(ds.getFeatures(), ds.getLabels(), testPredicted, allXYPoints, predictionsAtXYPoints, nPointsPerAxis);

        System.out.println("****************Example finished********************");

        return "train finished";
    }

}</code>
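For readers who want to see what the configured network actually computes, the following plain-Java sketch runs one forward pass through the same shape of network: 2 inputs, a 50-unit ReLU hidden layer, and a 2-unit softmax output. It is only an illustration and does not use DL4J; the class ForwardPassSketch, its method names, and the all-zero weights in main are assumptions made for the example.

```java
import java.util.Arrays;

/** Plain-Java sketch of a forward pass through a 2 -> 50 -> 2 network (ReLU, then softmax). */
class ForwardPassSketch {

    /** Fully connected layer: out[j] = b[j] + sum_i w[j][i] * in[i]. */
    static double[] dense(double[] in, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double sum = b[j];
            for (int i = 0; i < in.length; i++) {
                sum += w[j][i] * in[i];
            }
            out[j] = sum;
        }
        return out;
    }

    /** ReLU activation: negative values are clipped to zero. */
    static double[] relu(double[] v) {
        double[] out = new double[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = Math.max(0.0, v[i]);
        }
        return out;
    }

    /** Softmax: turns scores into probabilities that sum to 1. */
    static double[] softmax(double[] z) {
        double max = Arrays.stream(z).max().getAsDouble(); // subtract max for numerical stability
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Hidden layer with ReLU followed by an output layer with softmax. */
    static double[] forward(double[] x, double[][] w1, double[] b1, double[][] w2, double[] b2) {
        return softmax(dense(relu(dense(x, w1, b1)), w2, b2));
    }

    public static void main(String[] args) {
        // Hypothetical all-zero weights: the network has no preference between the two classes.
        double[] p = forward(new double[]{0.3, -0.7},
                new double[50][2], new double[50], new double[2][50], new double[2]);
        System.out.println(Arrays.toString(p)); // prints [0.5, 0.5]
    }
}
```

Training adjusts the weight matrices so that the two output probabilities separate the two moon-shaped classes; model.output() in the listing above performs this same kind of computation for a whole batch of points.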

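Running in IDEA

Once the code is in place, run the SpringbootDl4jApplication.java class in IDEA. The application listens on local port 8080, unless you have configured a different port.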

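Open the following address in a local browser: http://localhost:8080

In IDEA you can see that the network has started training, with the training log printed to the console.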

When training finishes, the browser displays train finished.
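The same check can be done from code instead of a browser. Below is a minimal client sketch using the JDK's HttpURLConnection; the class name TrainClientSketch and the ten-minute read timeout are assumptions, chosen only because the request does not return until training has finished.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;

/** Calls the training endpoint and returns the response body as text. */
class TrainClientSketch {

    /** Reads a stream to the end and decodes it as UTF-8 (Java 8 compatible). */
    static String readBody(InputStream in) {
        try {
            ByteArrayOutputStream buf = new ByteArrayOutputStream();
            byte[] chunk = new byte[4096];
            int n;
            while ((n = in.read(chunk)) != -1) {
                buf.write(chunk, 0, n);
            }
            return new String(buf.toByteArray(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Sends GET to the given URL and waits up to readTimeoutMillis for the answer. */
    static String fetch(String url, int readTimeoutMillis) {
        try {
            HttpURLConnection conn = (HttpURLConnection) new URL(url).openConnection();
            conn.setRequestMethod("GET");
            conn.setConnectTimeout(5_000);
            conn.setReadTimeout(readTimeoutMillis); // training keeps the request open
            try (InputStream in = conn.getInputStream()) {
                return readBody(in);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        // Assumes the Spring Boot application from this post is running on port 8080.
        System.out.println(fetch("http://localhost:8080/", 10 * 60 * 1000));
    }
}
```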

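With that, we have a simple API that trains a network through a Spring Boot endpoint.

Of course, combining Spring Boot with machine-learning projects makes many more things possible. I will keep sharing what I learn on the topic, so feel free to follow me.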
