Unet 背景介紹

Unet 發表於 2015 年,屬於 FCN 的一種變體,想了解 FCN 可以看我的另一篇 FCN 全卷積網絡論文閱讀及代碼實現 。Unet 的初衷是為了解決生物醫學圖像方面的問題,由於效果確實很好後來也被廣泛的應用在語義分割的各個方向,比如衛星圖像分割,工業瑕疵檢測等。

Unet 跟 FCN 都是 Encoder-Decoder 結構,結構簡單但很有效。Encoder 負責特徵提取,你可以將自己熟悉的各種特徵提取網絡放在這個位置。由於在醫學方面,樣本收集較為困難,為了解決這個問題,應用了圖像增強的方法,在數據集有限的情況下獲得了不錯的精度。

Unet 網絡結構與細節

  • Encoder

如上圖,Unet 網絡結構是對稱的,形似英文字母 U 所以被稱為 Unet。整張圖都是由藍/白色框與各種顏色的箭頭組成,其中,藍/白色框表示 feature map;藍色箭頭表示 3x3 卷積,用於特徵提取;灰色箭頭表示 skip-connection,用於特徵融合;紅色箭頭表示池化 pooling,用於降低維度;綠色箭頭表示上採樣 upsample,用於恢復維度;青色箭頭表示 1x1 卷積,用於輸出結果。

可能你會問為啥是 5 層而不是 4 層或者 6 層,emmm,這應該去問本人,可能對於當時拿到的數據集來說,這個層數的表現更好,但不代表所有的數據集這個結構都適合。我們該多關注這種 Encoder-Decoder 的設計思想,具體實現則應該因數據集而異。

Encoder 由卷積操作和下采樣操作組成,文中所用的卷積結構統一為 3x3 的卷積核,padding 為 0 ,striding 為 1。沒有 padding 所以每次卷積之後 feature map 的 H 和 W 變小了,在 skip-connection 時要注意 feature map 的維度(其實也可以將 padding 設置為 1 避免維度不對應問題),pytorch 代碼:

<code>nn.Sequential(nn.Conv2d(in_channels, out_channels,  3), 
              nn.BatchNorm2d(out_channels), 
              nn.ReLU(inplace=True))
複製代碼/<code>

上述的兩次卷積之後是一個 stride 為 2 的 max pooling,輸出大小變為 1/2 *(H, W):


Unet 背景介紹


pytorch 代碼:

<code>nn.MaxPool2d(kernel_size=2, stride=2)
複製代碼/<code>

上面的步驟重複 5 次,最後一次沒有 max-pooling,直接將得到的 feature map 送入 Decoder。

  • Decoder

feature map 經過 Decoder 恢復原始分辨率,該過程除了卷積比較關鍵的步驟就是 upsampling 與 skip-connection。

Upsampling 上採樣常用的方式有兩種:1.FCN 中介紹的反捲積;2. 插值。 這裡介紹文中使用的插值方式。在插值實現方式中,bilinear 雙線性插值的綜合表現較好也較為常見 。

雙線性插值的計算過程沒有需要學習的參數,實際就是套公式,這裡舉個例子方便大家理解(例子介紹的是參數 align_corners 為 Fasle 的情況)。


Unet 背景介紹


例子中是將一個 2x2 的矩陣通過插值的方式得到 4x4 的矩陣,那麼將 2x2 的矩陣稱為源矩陣,4x4 的矩陣稱為目標矩陣。雙線性插值中,目標點的值是由離他最近的 4 個點的值計算得到的,我們首先介紹如何找到目標點周圍的 4 個點,以 P2 為例。


Unet 背景介紹


第一個公式,目標矩陣到源矩陣的座標映射:

為了找到那 4 個點,首先要找到目標點在源矩陣中的相對位置,上面的公式就是用來算這個的。P2 在目標矩陣中的座標是 (0, 1),對應到源矩陣中的座標就是 (-0.25, 0.25)。座標裡面居然有小數跟負數,不急我們一個一個來處理。我們知道雙線性插值是從座標周圍的 4 個點來計算該座標的值,(-0.25, 0.25) 這個點周圍的 4 個點是(-1, 0), (-1, 1), (0, 0), (0, 1)。為了找到負數座標點,我們將源矩陣擴展為下面的形式,中間紅色的部分為源矩陣。


Unet 背景介紹


我們規定 f(i, j) 表示 (i, j)座標點處的像素值,對於計算出來的對應的座標,我們統一寫成 (i+u, j+v) 的形式。那麼這時 i=-1, u=0.75, j=0, v=0.25。把這 4 個點單獨畫出來,可以看到目標點 P2 對應到源矩陣中的相對位置


Unet 背景介紹


第二個公式,也是最後一個。

f(i + u, j + v) = (1 - u) (1 - v) f(i, j) + (1 - u) v f(i, j + 1) + u (1 - v) f(i + 1, j) + u v f(i + 1, j + 1)

目標點的像素值就是周圍 4 個點像素值的加權和,明顯可以看出離得近的權值比較大例如 (0, 0) 點的權值就是 0.75x0.75,離得遠的如 (-1, 1) 權值就比較小,為 0.25*0.25,這也比較符合常理吧。把值帶入計算就可以得到 P2 點的值了,結果是 12.5 與代碼吻合上了,nice。

pytorch 裡使用 bilinear 插值:

<code>nn.Upsample(scale_factor=2, mode='bilinear')
複製代碼/<code>

CNN 網絡要想獲得好效果,skip-connection 基本必不可少。Unet 中這一關鍵步驟融合了底層信息的位置信息與深層特徵的語義信息,pytorch 代碼:

<code>torch.cat([low_layer_features, deep_layer_features], dim=1)
複製代碼/<code>

這裡需要注意的是,FCN 中深層信息與淺層信息融合是通過對應像素相加的方式,而 Unet 是通過拼接的方式。

那麼這兩者有什麼區別呢,其實 在 ResNet 與 DenseNet 中也有一樣的區別,Resnet 使用了對應值相加,DenseNet 使用了拼接。個人理解在相加的方式下,feature map 的維度沒有變化,但每個維度都包含了更多特徵,對於普通的分類任務這種不需要從 feature map 復原到原始分辨率的任務來說,這是一個高效的選擇;而拼接則保留了更多的維度/位置 信息,這使得後面的 layer 可以在淺層特徵與深層特徵自由選擇,這對語義分割任務來說更有優勢。

小結

Unet 基於 Encoder-Decoder 結構,通過拼接的方式實現特徵融合,結構簡明且穩定,如果你有語義分割的問題,尤其在樣本數據量不大的情況下,十分推薦一試。

http://kmanong.top

http://jinritoutiao.cloud


分享到:


相關文章: