Image Segmentation Algorithm Code Walkthrough | DCAN Model

This post mainly walks through the code of the DCAN model's Image Generator and Segmentation Network, along with some of the important building-block modules.

If you don't remember the contents of the DCAN paper, refer to the detailed walkthrough in the article below:

Network architecture:

[Figure] DCAN model architecture

Code walkthrough:

Key module 1: use AdaIN to align the mean and variance of the features, so that the data share a similar distribution

Formula:

AdaIN(x, y) = σ(y) · ((x − μ(x)) / σ(x)) + μ(y)

where μ(·) and σ(·) are the per-channel mean and standard deviation computed over the spatial dimensions (this matches the code below).

Code: computing the mean and standard deviation

# Compute per-channel mean and standard deviation
def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero
    size = feat.data.size()
    assert (len(size) == 4)  # size: (N, C, H, W)
    N, C = size[:2]
    # Variance over the spatial dimensions
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    # Standard deviation
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    # Mean over the spatial dimensions
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    # Return the mean and the standard deviation (not the variance)
    return feat_mean, feat_std
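As a quick sanity check, calc_mean_std can be run on a random tensor (the sizes below are arbitrary); it returns one mean and one standard deviation per channel of each sample:

```python
import torch

def calc_mean_std(feat, eps=1e-5):
    # Per-channel mean/std over the spatial dimensions, as above
    N, C = feat.size()[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

x = torch.randn(2, 3, 8, 8)
mean, std = calc_mean_std(x)
print(mean.shape, std.shape)  # torch.Size([2, 3, 1, 1]) torch.Size([2, 3, 1, 1])
```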

Code: use AdaIN to align the content features' mean and variance to the style features', so that the data share a similar distribution

def adaptive_instance_normalization(content_feat, style_feat):
    # Check that content_feat and style_feat have matching N and C
    assert (content_feat.data.size()[:2] == style_feat.data.size()[:2])
    size = content_feat.data.size()
    # Mean and standard deviation of the style features
    style_mean, style_std = calc_mean_std(style_feat)
    # Mean and standard deviation of the content features
    content_mean, content_std = calc_mean_std(content_feat)
    # Normalize the content features, then rescale with the style statistics
    normalized_feat = (content_feat - content_mean.expand(
        size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
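A small self-contained check (random tensors, arbitrary sizes) confirming what AdaIN does: after alignment, the output's per-channel statistics match those of the style features:

```python
import torch

def calc_mean_std(feat, eps=1e-5):
    N, C = feat.size()[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat, style_feat):
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized * style_std.expand(size) + style_mean.expand(size)

torch.manual_seed(0)
content = torch.randn(1, 3, 16, 16)
# Style features with a clearly different distribution
style = 2.0 * torch.randn(1, 3, 16, 16) + 5.0
out = adaptive_instance_normalization(content, style)
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style)
# The output's per-channel statistics now match the style's
print(torch.allclose(out_mean, style_mean, atol=1e-4))  # True
print(torch.allclose(out_std, style_std, atol=1e-3))    # True
```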

Key module 2: the Gram matrix, commonly used to represent an image's style features

G = F · Fᵀ / (C · H · W)

where F is the feature map reshaped to (C, H·W); dividing by C·H·W normalizes for the feature-map size.

Code: computing the Gram matrix

def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    # Batched matrix multiplication
    gram = features.bmm(features_t) / (ch * h * w)
    return gram
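A quick check of gram_matrix on a random tensor: the result has shape (B, C, C) and, like any Gram matrix, is symmetric:

```python
import torch

def gram_matrix(y):
    b, ch, h, w = y.size()
    features = y.view(b, ch, w * h)
    # Batched matrix multiplication, normalized by feature-map size
    gram = features.bmm(features.transpose(1, 2)) / (ch * h * w)
    return gram

y = torch.randn(2, 4, 8, 8)
g = gram_matrix(y)
print(g.shape)  # torch.Size([2, 4, 4])
# A Gram matrix is symmetric: G == G^T
print(torch.allclose(g, g.transpose(1, 2)))  # True
```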

Network code: Image Generator

import torch.nn as nn
from torch.autograd import Variable

# Shorthand for the AdaIN function defined above
adain = adaptive_instance_normalization


class Net(nn.Module):
    def __init__(self, encoder, decoder):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())

        self.enc_1 = nn.Sequential(*enc_layers[:4])     # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])   # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()  # mean-squared-error loss

    # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
    def encode_with_intermediate(self, input):
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    # extract relu4_1 from input image
    def encode(self, input):
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input

    def calc_content_loss(self, input, target):
        assert (input.data.size() == target.data.size())
        assert (target.requires_grad is False)
        # Mean-squared error between the features
        loss_c = self.mse_loss(input, target)
        return loss_c

    def calc_style_loss(self, input, target):
        assert (input.data.size() == target.data.size())
        assert (target.requires_grad is False)
        # Gram matrices of input and target
        input_gram = gram_matrix(input)
        target_gram = gram_matrix(target)
        # Style loss
        return self.mse_loss(input_gram, target_gram)

    def forward(self, content, style):
        # style_feats holds the features from relu1_1, relu2_1, relu3_1, relu4_1 in order
        style_feats = self.encode_with_intermediate(style)
        # Align the mean and variance of the relu4_1 content features to the style features
        t = adain(self.encode(content), style_feats[-1])
        # Decode t into an image g_t that has the source image's content and the target image's style
        g_t = self.decoder(Variable(t.data, requires_grad=True))
        # g_t_feats holds g_t's features from relu1_1, relu2_1, relu3_1, relu4_1 in order
        g_t_feats = self.encode_with_intermediate(g_t)
        # Content loss between g_t's relu4_1 features and the AdaIN output t
        ## Why use the AdaIN output t rather than the content features?
        loss_c = self.calc_content_loss(g_t_feats[-1], Variable(t.data))
        # Style loss summed over relu1_1..relu4_1, aligning g_t and style at every stage
        loss_s = self.calc_style_loss(g_t_feats[0], Variable(style_feats[0].data))
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], Variable(style_feats[i].data))
        return loss_c, loss_s, g_t
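For training, the two returned losses are typically combined with a style weight before backpropagation, as in the original AdaIN setup; a minimal sketch (the weight value 10.0 and the placeholder loss values below are assumptions, not from the article):

```python
import torch

# Hypothetical fragment: loss_c and loss_s stand in for the tensors
# returned by Net.forward; the style weight is an assumed hyperparameter
loss_c = torch.tensor(0.8, requires_grad=True)
loss_s = torch.tensor(2.5, requires_grad=True)
style_weight = 10.0
loss = loss_c + style_weight * loss_s
loss.backward()
# The gradient w.r.t. loss_s is scaled by the style weight
print(loss_c.grad.item(), loss_s.grad.item())  # 1.0 10.0
```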

Network code: Segmentation Network

class FCN8s_encdec(nn.Module):
    def __init__(self, num_classes, pretrained=True, caffe=False, x_size=(512, 1024)):
        super(FCN8s_encdec, self).__init__()
        self.x_size = x_size
        # The encoder and decoder code is not given here
        self.enc = FCN8s_enc(num_classes=num_classes, pretrained=pretrained, caffe=caffe)
        self.dec = FCN8s_dec(num_classes=num_classes, pretrained=pretrained, caffe=caffe, x_size=self.x_size)

    def forward(self, x, style=None):
        if style is not None:
            # Use the encoder to extract features of the source image and the target image
            org_fea = self.enc(x)
            sty_fea = self.enc(style)

            # Align org_fea to sty_fea, so that t has org_fea's content and sty_fea's style
            t = adain(org_fea, sty_fea)
            # Make t a leaf variable whose values can be optimized
            t = Variable(t.data, requires_grad=True)
            x = self.dec(t)
            return x
        else:
            x = self.enc(x)
            x = self.dec(x)

            return x
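Since the FCN8s_enc and FCN8s_dec code is not shown in the article, here is a self-contained sketch of the same conditional forward logic with stub convolutional layers standing in for them (the stub channel sizes are assumptions, and the deprecated Variable wrapper is dropped):

```python
import torch
import torch.nn as nn

def calc_mean_std(feat, eps=1e-5):
    N, C = feat.size()[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def adain(content_feat, style_feat):
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized * style_std.expand(size) + style_mean.expand(size)

# Stub encoder/decoder in place of FCN8s_enc / FCN8s_dec,
# just to exercise the two forward paths
class StubSegNet(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.enc = nn.Conv2d(3, 8, 3, padding=1)
        self.dec = nn.Conv2d(8, num_classes, 3, padding=1)

    def forward(self, x, style=None):
        if style is not None:
            # Style branch: align encoder features before decoding
            t = adain(self.enc(x), self.enc(style))
            return self.dec(t)
        return self.dec(self.enc(x))

net = StubSegNet()
src = torch.randn(1, 3, 16, 16)
tgt = torch.randn(1, 3, 16, 16)
out_plain = net(src)              # plain segmentation path
out_styled = net(src, style=tgt)  # AdaIN-aligned path
print(out_plain.shape, out_styled.shape)
```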

The optimizer and data-loading parts are omitted. To get the source code, search GitHub for the paper's title, or send a private message saying "DCAN代碼".
