Python+Android進行TensorFlow開發

Tensorflow是Google開源的一套機器學習框架,支持GPU、CPU、Android等多種計算平臺。本文將介紹在Tensorflow在Android上的使用。

Android使用Tensorflow框架需要引入兩個文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。這兩個文件可以使用官方預編譯的文件。如果預編譯的so不滿足要求(比如不支持訓練模型中的某些操作符運算),也可以自己通過bazel編譯生成這兩個文件。將libandroid_tensorflow_inference_java.jar放在app下的libs目錄下,so文件命名為libtensorflow_jni.so放在src/main/jniLibs目錄下對應的ABI文件夾下。目錄結構如下:

Python+Android進行TensorFlow開發

Android目錄結構

同時在app的build.gradle中的dependencies模塊下添加如下配置:

<code>dependencies {
...
compile files('libs/libandroid_tensorflow_inference_java.jar')
...
}
12345/<code>

使用tensorflow框架進行機器學習分為四個步驟:

  • 構造神經網絡
  • 訓練神經網絡模型
  • 將訓練好的模型輸出為pb文件
  • ndroid上加載pb模型進行計算

前三步是模型的構造,我們通過python實現,下面給出了一個二分類的簡單模型的構造過程,首先是訓練過程:

<code># -*-coding:utf-8 -*-
from __future__ import print_function
import os
import tensorflow as tf
from numpy.random import RandomState

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

"""
訓練模型
"""

def train():
# 定義訓練數據集batch大小為8
batch_size = 8

# 定義神經網絡參數,參數體現出神經網絡結構,一個輸入層,一個輸出層,一個隱藏層
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val")
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val")

# 定義輸入輸出格式
x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input')
y_ = tf.placeholder(tf.float32, shape=(None, 1))

# 定義神經網絡前向傳播過程
a = tf.matmul(x, w1)
y = tf.matmul(a, w2, name="cal_node")

# 定義交叉熵和反向傳播算法
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)

# 生成隨機訓練集
rdm = RandomState(1)
dataset_size = 128

# 定義映射關係
X = rdm.rand(dataset_size, 2)
Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]

with tf.Session() as sess:
# 初始化所有參數
init_op = tf.global_variables_initializer()
sess.run(init_op)

# print sess.run(w1)
# print sess.run(w2)

STEPS = 500
for i in range(STEPS):
start = (i * batch_size) % dataset_size
end = min(start + batch_size, dataset_size)


# 訓練神經網絡,更新神經網絡參數
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})

if i % 100 == 0:
total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))

print(sess.run(w1))
print(sess.run(w2))

# 保存check point
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, './model/checpt')
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465/<code>

上面的代碼首先定義神經網絡,初始化訓練數據,進行500次訓練過程,並將訓練結果checkpoints保存到model文件夾下,checkpoints包含了訓練模型得到的參數信息,共生成四個相關的文件,如下圖:

Python+Android進行TensorFlow開發

由於checkpoint文件眾多,為了方便使用,我們通過下面的代碼將它們生成一個pb文件,在android上只需要這個pb文件即可使用這個訓練好的模型:

<code>"""
存儲pb模型
"""
def dump_graph_to_pb(pb_path):
with tf.Session() as sess:
check_point = tf.train.get_checkpoint_state("./model/")
if check_point:
saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta')
saver.restore(sess, check_point.model_checkpoint_path)
else:
raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path))

graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(","))

with tf.gfile.GFile(pb_path, "wb") as f:
f.write(graph_def.SerializeToString())
12345678910111213141516/<code>

拿到生成的pb模型,我們可以在android上使用了。將pb文件在這main/assets下:

Python+Android進行TensorFlow開發

接下來就可以載入pb,進行計算了:

<code>public class MainActivity extends AppCompatActivity {
private Graph graph_;
private Session session_;
private AssetManager assetManager;

private static ExecutorService executorService;
private static Handler handler;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);

executorService = Executors.newFixedThreadPool(5);

// 初始化tensorflow
initTensorFlow("outmodel.pb");

// 使用tensorflow進行計算
runTensorFlow();
}
...
}
12345678910111213141516171819202122/<code>

通過如下方式載入pb模型,初始化tensorflow:

<code>private boolean initTensorFlow(String modelFile) {
assetManager = getAssets();
// 新建Graph
graph_ = new Graph();

InputStream is = null;
try {
// 讀取Assets pb文件
is = assetManager.open(modelFile);
} catch (IOException e) {
e.printStackTrace();
return false;
}

try {
// 加載pb到Graph

TensorUtil.loadGraph(is, graph_);
is.close();
} catch (IOException e) {
e.printStackTrace();
return false;
}
// 初始化session
session_ = new Session(graph_);
if (session_ == null) {
return false;
}

return true;
}
123456789101112131415161718192021222324252627282930/<code>

然後就可以使用tensorflow API進行運算了:

<code>private void runTensorFlow() {
executorService.execute(generatePredictRunnable(handler));
}

private Runnable generatePredictRunnable(Handler handler) {
return new Runnable() {
@Override
public void run() {
float[][] input = new float[1][2];

input[0][0] = 1;
input[0][1] = 2;

// 定義輸入tensor
Tensor inputTensor = Tensor.create(input);

// 指定輸入,輸出節點,運行並得到結果
Tensor resultTensor = session_.runner()
.feed("x_input", inputTensor)
.fetch("cal_node")
.run()
.get(0);

float[][] dst = new float[1][1];
resultTensor.copyTo(dst);

// 處理結果
ArrayList<float> resultList = new ArrayList<>();

for (float val : dst[0]) {
if (val != 0) {
resultList.add(val);
} else {
break;
}
}
}
};
}
1234567891011121314151617181920212223242526272829303132333435363738/<float>/<code>

上面就是通過python訓練機器學習模型,並在android平臺進行調用的完整流程。

原創作者:JackMeGo,原文鏈接:https://www.jianshu.com/p/eef4ab014a12

Python+Android進行TensorFlow開發

歡迎關注我的微信公眾號「碼農突圍」,分享Python、Java、大數據、機器學習、人工智能等技術,關注碼農技術提升•職場突圍•思維躍遷,20萬+碼農成長充電第一站,陪有夢想的你一起成長。


分享到:


相關文章: