這次為大家詳細展示一個利用卷積神經網絡實現圖片自動分類的例程。
神經網絡的優點:自動從數據中學習經驗知識,無需複雜的模型和算法。
缺點:有監督學習,需要大量的帶標籤數據;參數量太少時容易過擬合,泛化能力差,參數量太大時訓練收斂很慢(有可能需要幾個月到幾年)。
為了克服上述缺點,人們發掘了各種計算資源,包括多核CPU、GPU、DSP、ASIC、FPGA,甚至使用模擬電路。
使用CPU實現卷積神經網絡比較方便調試,但性能太差,一般人們都選用更快的GPU實現。目前開源的框架大多都支持GPU,如伯克利大學Caffe和Google Convnet。
微軟在2015年2月宣佈使用Stratix V完成了CNN加速器,處理 CIFAR10 圖片速度可達每秒2300多張。
這裡我們也使用CIFAR10圖片數據,在Cyclone V板子上跑一個卷積神經網絡CNN demo。由於板子上計算資源太少(DSP Slice只有80多個),實現完整的網絡不太現實,只能在FPGA上實現基本計算單元,然後由HPS統一調度。性能預期不會太高,後面給出。
CIFAR10圖片都是什麼呢?先來張圖!
有興趣的朋友可以到官網下載(CIFAR10官網)。上面提到過,CNN是有監督學習系統,需要大量帶label的數據,CIFAR10就是這樣一個開放的數據庫,提供了60000張不同類別的圖片,分為10個類(如上圖左側所示),每個類別有600張圖。這個數據集不算特別大,適合在嵌入式平臺上實現。而更大的數據集有ImageNet-1000(ImageNet官網),擁有120多萬張高清無碼大圖,我下載到硬盤,佔用了近200GB空間(只能忍痛將其他rmvb和avi刪掉了)!
有朋友會問,不用這些數據行不行,我們的智能手機裡面照片能不能用於CNN做訓練?
答案是可以的,只是你的數據集很不“均勻”,採樣不夠“完備”,訓練出的模型是真實模型的“有偏估計”,而上述兩個數據集經過了種種考驗,已經是學術界公認的優質數據集,一年一度的ILSVRC比賽就採用了這些數據集。
說完數據,再說模型。先來看一張經典的CNN結構:
這是世界上第一個將CNN實用化的例子,實現了手寫體字母自動識別。在這個CNN模型中,可以看到輸入是一張32 x 32的二維圖像,經過卷積層(Convolution)、下采樣層(Subsampling,也稱Pooling)、全連接層(Full Connection,也稱Inner Product)後,得到一組概率密度,我們選其中概率最大的元素作為該模型對輸入圖像的分類結果。所以實現CNN時,只需要實現三種基本算法:卷積、下采樣、矩陣乘。除此之外,每層輸出都可選擇是否經過非線性變換,常用的非線性變換有ReLU和Sigmoid,前者計算較為簡單,使用較為廣泛。
Caffe框架中提供了專門為cifar10數據定製的模型,是用proto格式寫的,我們的demo也基於這個模型。內容如下:
- name: "CIFAR10_quick_test"
- input: "data"
- input_dim: 1
- input_dim: 3
- input_dim: 32
- input_dim: 32
- layers {
- name: "conv1"
- type: CONVOLUTION
- bottom: "data"
- top: "conv1"
- blobs_lr: 1
- blobs_lr: 2
- convolution_param {
- num_output: 32
- pad: 2
- kernel_size: 5
- stride: 1
- }
- }
- layers {
- name: "pool1"
- type: POOLING
- bottom: "conv1"
- top: "pool1"
- pooling_param {
- pool: MAX
- kernel_size: 3
- stride: 2
- }
- }
- layers {
- name: "relu1"
- type: RELU
- bottom: "pool1"
- top: "pool1"
- }
- layers {
- name: "conv2"
- type: CONVOLUTION
- bottom: "pool1"
- top: "conv2"
- blobs_lr: 1
- blobs_lr: 2
- convolution_param {
- num_output: 32
- pad: 2
- kernel_size: 5
- stride: 1
- }
- }
- layers {
- name: "relu2"
- type: RELU
- bottom: "conv2"
- top: "conv2"
- }
- layers {
- name: "pool2"
- type: POOLING
- bottom: "conv2"
- top: "pool2"
- pooling_param {
- pool: AVE
- kernel_size: 3
- stride: 2
- }
- }
- layers {
- name: "conv3"
- type: CONVOLUTION
- bottom: "pool2"
- top: "conv3"
- blobs_lr: 1
- blobs_lr: 2
- convolution_param {
- num_output: 64
- pad: 2
- kernel_size: 5
- stride: 1
- }
- }
- layers {
- name: "relu3"
- type: RELU
- bottom: "conv3"
- top: "conv3"
- }
- layers {
- name: "pool3"
- type: POOLING
- bottom: "conv3"
- top: "pool3"
- pooling_param {
- pool: AVE
- kernel_size: 3
- stride: 2
- }
- }
- layers {
- name: "ip1"
- type: INNER_PRODUCT
- bottom: "pool3"
- top: "ip1"
- blobs_lr: 1
- blobs_lr: 2
- inner_product_param {
- num_output: 64
- }
- }
- layers {
- name: "ip2"
- type: INNER_PRODUCT
- bottom: "ip1"
- top: "ip2"
- blobs_lr: 1
- blobs_lr: 2
- inner_product_param {
- num_output: 10
- }
- }
- layers {
- name: "prob"
- type: SOFTMAX
- bottom: "ip2"
- top: "prob"
- }
複製代碼
可見,上述模型經過了3個卷積層(conv1, conv2, conv3),每個卷積層後面都跟著下采樣層(pool1, pool2, pool3),之後有兩個全連接層(ip1, ip2),最後一層prob為SOFTMAX分類層,是計算概率密度的,這裡我們不需要關心。
下面三張圖分別統計了CNN模型各層的參數量、數據量和計算量。
可以看出,卷積層的參數量很少,但數據量很大;全連接層剛好相反,參數量較大,但數據量很少。
通過計算量統計發現conv2計算量最大,其次是conv3和conv1。全連接層的計算量相對卷積層較小,但不可忽略。其他層(pool1, pool2以及各級relu)由於計算量太小,本設計中沒有將其實現為Open CL kernel,而是直接CPU端實現。
綜上所述,我們重點實現兩個算法:卷積和矩陣乘,分別對應卷積層、全連接層的實現。
在DE1-SOC上我利用了友晶提供的Open CL BSP,支持C語言開發FPGA。
卷積層計算kernel函數如下:
- __attribute__((num_compute_units(4)))
- __kernel
- void conv(__global float * a, __global float * b, __global float * c, const int M, const int N, const int K)
- {
- int gx = get_global_id(0);
- int gy = get_global_id(1);
- float tmp=0.0f;
- for(int x = 0; x < K; x ++)
- {
- for(int y = 0; y < K; y ++)
- {
- tmp += a[(gx + x) * M + (gy + y)] * b[x * K + y];
- }
- }
複製代碼
全連接層計算採用矩陣乘實現,kernel函數如下:
- __attribute__((num_compute_units(4)))
- __kernel
- void gemm(__global float * a, __global float * b, __global float * c, const int M, const int N, const int K)
- {
- int gx = get_global_id(0);
- int gy = get_global_id(1);
- int sy = get_global_size(1);
- int sx = get_global_size(0);
- int s = sx * sy;
- for(int x = gx; x < M; x += sx)
- {
- for(int y = gy; y < N; y += sy)
- {
- float tmp=0.0f;
- for(int z = 0; z < K; z++)
- {
- tmp += a[z * M + x] * b[y * K + z];
- }
- c[y * M + x] = tmp;
- }
- }
- }
複製代碼
編譯kernel函數需要使用Altera SDK for OpenCL,我用的版本是14.0.0.200,申請了兩個月的license。編譯使用命令行aoc,得到*.aocx文件。
Open CL編譯輸出報告中給出了資源佔用情況:
- +--------------------------------------------------------------------+
- ; Estimated Resource Usage Summary ;
- +----------------------------------------+---------------------------+
- ; Resource + Usage ;
- +----------------------------------------+---------------------------+
- ; Logic utilization ; 83% ;
- ; Dedicated logic registers ; 46% ;
- ; Memory blocks ; 57% ;
- ; DSP blocks ; 25% ;
- +----------------------------------------+---------------------------;
複製代碼
可見,邏輯資源、存儲器資源消耗較為明顯,而DSP資源並未用盡,說明還有優化的空間。
編譯主程序需要使用SoCEDS,我用的版本為14.0.2.274,也是命令行方式,在工程目錄下執行make,結束後得到可執行文件cnn。
將這兩個文件拷貝到SD卡,按照前面的博客對板子進行設置,將CNN的模型、CIFAR10數據也拷貝到SD卡中,板子上電,mount SD卡到/mnt,執行cnn,得到輸出如下:
Please input the number of images(1~100):100
- Loading data...OK!
- Constructing CNN...OK!
- Begin calculation...Elapsed Time = 141.861 s.
- Real Label = 3(cat), Calc Label = 3(cat), error count = 0
- Real Label = 8(ship), Calc Label = 8(ship), error count = 0
- Real Label = 8(ship), Calc Label = 8(ship), error count = 0
- Real Label = 0(airplane), Calc Label = 0(airplane), error count = 0
- Real Label = 6(frog), Calc Label = 6(frog), error count = 0
- Real Label = 6(frog), Calc Label = 6(frog), error count = 0
- Real Label = 1(automobile), Calc Label = 1(automobile), error count = 0
- Real Label = 6(frog), Calc Label = 6(frog), error count = 0
- Real Label = 3(cat), Calc Label = 3(cat), error count = 0
- Real Label = 1(automobile), Calc Label = 1(automobile), error count = 0
- Real Label = 0(airplane), Calc Label = 0(airplane), error count = 0
- Real Label = 9(truck), Calc Label = 9(truck), error count = 0
- Real Label = 5(dog), Calc Label = 5(dog), error count = 0
- Real Label = 7(horse), Calc Label = 7(horse), error count = 0
- Real Label = 9(truck), Calc Label = 9(truck), error count = 0
- Real Label = 8(ship), Calc Label = 8(ship), error count = 0
- Real Label = 5(dog), Calc Label = 5(dog), error count = 0
- Real Label = 7(horse), Calc Label = 7(horse), error count = 0
- Real Label = 8(ship), Calc Label = 8(ship), error count = 0
- Real Label = 6(frog), Calc Label = 6(frog), error count = 0
- Real Label = 7(horse), Calc Label = 7(horse), error count = 0
- Real Label = 0(airplane), Calc Label = 2(bird), error count = 1
- Real Label = 4(deer), Calc Label = 4(deer), error count = 1
- Real Label = 9(truck), Calc Label = 9(truck), error count = 1
- Real Label = 5(dog), Calc Label = 4(deer), error count = 2
- Real Label = 2(bird), Calc Label = 3(cat), error count = 3
- Real Label = 4(deer), Calc Label = 4(deer), error count = 3
- Real Label = 0(airplane), Calc Label = 0(airplane), error count = 3
- Real Label = 9(truck), Calc Label = 9(truck), error count = 3
- Real Label = 6(frog), Calc Label = 6(frog), error count = 3
- Real Label = 6(frog), Calc Label = 6(frog), error count = 3
- Real Label = 5(dog), Calc Label = 5(dog), error count = 3
- Real Label = 4(deer), Calc Label = 4(deer), error count = 3
- Real Label = 5(dog), Calc Label = 5(dog), error count = 3
- Real Label = 9(truck), Calc Label = 9(truck), error count = 3
- Real Label = 2(bird), Calc Label = 3(cat), error count = 4
- Real Label = 4(deer), Calc Label = 7(horse), error count = 5
- Real Label = 1(automobile), Calc Label = 9(truck), error count = 6
- Real Label = 9(truck), Calc Label = 9(truck), error count = 6
- Real Label = 5(dog), Calc Label = 5(dog), error count = 6
- Real Label = 4(deer), Calc Label = 4(deer), error count = 6
- Real Label = 6(frog), Calc Label = 6(frog), error count = 6
- Real Label = 5(dog), Calc Label = 5(dog), error count = 6
- Real Label = 6(frog), Calc Label = 6(frog), error count = 6
- Real Label = 0(airplane), Calc Label = 0(airplane), error count = 6
- Real Label = 9(truck), Calc Label = 9(truck), error count = 6
- Real Label = 3(cat), Calc Label = 5(dog), error count = 7
- Real Label = 9(truck), Calc Label = 9(truck), error count = 7
- Real Label = 7(horse), Calc Label = 7(horse), error count = 7
- Real Label = 6(frog), Calc Label = 6(frog), error count = 7
- Real Label = 9(truck), Calc Label = 9(truck), error count = 7
- Real Label = 8(ship), Calc Label = 8(ship), error count = 7
- Real Label = 0(airplane), Calc Label = 2(bird), error count = 8
- Real Label = 3(cat), Calc Label = 3(cat), error count = 8
- Real Label = 8(ship), Calc Label = 8(ship), error count = 8
- Real Label = 8(ship), Calc Label = 8(ship), error count = 8
- Real Label = 7(horse), Calc Label = 7(horse), error count = 8
- Real Label = 7(horse), Calc Label = 7(horse), error count = 8
- Real Label = 4(deer), Calc Label = 3(cat), error count = 9
- Real Label = 6(frog), Calc Label = 3(cat), error count = 10
- Real Label = 7(horse), Calc Label = 7(horse), error count = 10
- Real Label = 3(cat), Calc Label = 5(dog), error count = 11
- Real Label = 6(frog), Calc Label = 6(frog), error count = 11
- Real Label = 3(cat), Calc Label = 3(cat), error count = 11
- Real Label = 6(frog), Calc Label = 6(frog), error count = 11
- Real Label = 2(bird), Calc Label = 2(bird), error count = 11
- Real Label = 1(automobile), Calc Label = 1(automobile), error count = 11
- Real Label = 2(bird), Calc Label = 2(bird), error count = 11
- Real Label = 3(cat), Calc Label = 3(cat), error count = 11
- Real Label = 7(horse), Calc Label = 9(truck), error count = 12
- Real Label = 2(bird), Calc Label = 2(bird), error count = 12
- Real Label = 6(frog), Calc Label = 6(frog), error count = 12
- Real Label = 8(ship), Calc Label = 8(ship), error count = 12
- Real Label = 8(ship), Calc Label = 8(ship), error count = 12
- Real Label = 0(airplane), Calc Label = 0(airplane), error count = 12
- Real Label = 2(bird), Calc Label = 2(bird), error count = 12
- Real Label = 9(truck), Calc Label = 0(airplane), error count = 13
- Real Label = 3(cat), Calc Label = 3(cat), error count = 13
- Real Label = 3(cat), Calc Label = 2(bird), error count = 14
- Real Label = 8(ship), Calc Label = 8(ship), error count = 14
- Real Label = 8(ship), Calc Label = 8(ship), error count = 14
- Real Label = 1(automobile), Calc Label = 1(automobile), error count = 14
- Real Label = 1(automobile), Calc Label = 1(automobile), error count = 14
- Real Label = 7(horse), Calc Label = 7(horse), error count = 14
- Real Label = 2(bird), Calc Label = 2(bird), error count = 14
- Real Label = 5(dog), Calc Label = 7(horse), error count = 15
- Real Label = 2(bird), Calc Label = 2(bird), error count = 15
- Real Label = 7(horse), Calc Label = 7(horse), error count = 15
- Real Label = 8(ship), Calc Label = 8(ship), error count = 15
- Real Label = 9(truck), Calc Label = 9(truck), error count = 15
- Real Label = 0(airplane), Calc Label = 0(airplane), error count = 15
- Real Label = 3(cat), Calc Label = 4(deer), error count = 16
- Real Label = 8(ship), Calc Label = 8(ship), error count = 16
- Real Label = 6(frog), Calc Label = 6(frog), error count = 16
- Real Label = 4(deer), Calc Label = 4(deer), error count = 16
- Real Label = 6(frog), Calc Label = 6(frog), error count = 16
- Real Label = 6(frog), Calc Label = 6(frog), error count = 16
- Real Label = 0(airplane), Calc Label = 2(bird), error count = 17
- Real Label = 0(airplane), Calc Label = 0(airplane), error count = 17
- Real Label = 7(horse), Calc Label = 7(horse), error count = 17
- Classify Score = 83 %.
上面的執行流程是這樣的,首先輸入測試樣本數目(1到100),由於DE1板子FPGA端SDRAM容量較小,難以加載全部測試數據(10000張圖片),故每次最多裝入100張圖片。之後載入數據到HPS內存,然後開始構建CNN模型,構建過程中也實現了Open CL的初始化。構建完畢,將輸入圖像依次通過CNN,得到一系列分類結果,與標籤進行對比,統計錯誤分類個數,計算分類準確率。
經過測試,分類準確率達到83%,與Caffe測試結果一致。
經過以上測試,可以得到結論:
(1)使用Open CL可以很方便地移植高級語言編寫的算法;
(2)CNN在移植過程中需要考慮實際硬件,定製合適的模型和數據;
(3)Cyclone 5邏輯資源較少(85K,Open CL kernel佔用了83%),如果希望進一步提高計算速度,一方面可以選用高性能器件(如Stratix V、Arria 10),另一方面可以使用RTL自己搭建計算系統。
以上圖文內容均是EEWORLD論壇網友
zhaoyongke原創,在此感謝。歡迎微博@EEWORLD
如果你也寫過此類原創乾貨請關注微信訂閱號(ID:eeworldbbs,將你的原創發至:[email protected],一經入選,我們將幫你登上頭條!
與更多行業內網友進行交流請登陸EEWORLD論壇。
閱讀更多 電子工程世界 的文章