如何捕獲一隻彩色卓別林?黑白照片AI上色教程很友好|哈佛大觸

方栗子 編譯自 GitHub

如何捕獲一隻彩色卓別林?黑白照片AI上色教程很友好|哈佛大觸

老照片的手動著色魔法

媽媽小時候已經有彩色照片了,不過那些照片,還是照相館的人類手動上色的。

幾十年之後,人們已經開始培育深度神經網絡,來給老照片和老電影上色了。

來自哈佛大學的Luke Melas-Kyriazi (我叫他盧克吧) ,用自己訓練的神經網絡,把卓別林變成了彩色的卓別林,清新自然。

作為一隻哈佛學霸,盧克還為鑽研機器學習的小夥伴們寫了一個基於PyTorch的教程。

雖然教程裡的模型比給卓別林用的模型要簡約一些,但效果也是不錯了。

問題是什麼?

盧克說,給黑白照片上色這個問題的難點在於,它是多模態

的——與一幅灰度圖像對應的合理彩色圖像,並不唯一。

如何捕獲一隻彩色卓別林?黑白照片AI上色教程很友好|哈佛大觸

這並不是正確示範

傳統模型需要輸入許多額外信息,來輔助上色。

而深度神經網絡,除了灰度圖像之外,不需要任何額外輸入,就可以完成上色。

在彩色圖像裡,每個像素包含三個值,即亮度飽和度以及色調

而灰度圖像,並無飽和度色調可言,只有亮度一個值。

如何捕獲一隻彩色卓別林?黑白照片AI上色教程很友好|哈佛大觸

所以,模型要用一組數據,生成另外兩足數據。換句話說,以灰度圖像為起點,推斷出對應的彩色圖像。

為了簡單,這裡只做了256 x 256像素的圖像上色。輸出的數據量則是256 x 256 x 2。

關於顏色表示,盧克用的是LAB色彩空間,它跟RGB系統包含的信息是一樣的。

但對程序猿來說,前者比較方便把亮度和其他兩項分離開來。

如何捕獲一隻彩色卓別林?黑白照片AI上色教程很友好|哈佛大觸

數據也不難獲得,盧克用了MIT Places數據集,中的一部分。內容就是校園裡的一些地標和風景。然後轉換成黑白圖像,就可以了。以下為數據搬運代碼——

1# Download and unzip (2.2GB)2!wget http://data.csail.mit.edu/places/places205/testSetPlaces205_resize.tar.gz3!tar -xzf testSetPlaces205_resize.tar.gz
1# Move data into training and validation directories2import os3os.makedirs('images/train/class/', exist_ok=True) # 40,000 images4os.makedirs('images/val/class/', exist_ok=True) # 1,000 images5for i, file in enumerate(os.listdir('testSet_resize')):6 if i < 1000: # first 1000 will be val7 os.rename('testSet_resize/' + file, 'images/val/class/' + file)8 else: # others will be val9 os.rename('testSet_resize/' + file, 'images/train/class/' + file)
1# Make sure the images are there2from IPython.display import Image, display3display(Image(filename='images/val/class/84b3ccd8209a4db1835988d28adfed4c.jpg'))

好用的工具有哪些?

搭建模型和訓練模型是在PyTorch裡完成的。

還用了torchvishion,這是一套在PyTorch上處理圖像和視頻的工具。

另外,scikit-learn能完成圖片在RGB和LAB色彩空間之間的轉換。

1# Download and import libraries2!pip install torch torchvision matplotlib numpy scikit-image pillow==4.1.1
1# For plotting2import numpy as np3import matplotlib.pyplot as plt4%matplotlib inline5# For conversion6from skimage.color import lab2rgb, rgb2lab, rgb2gray7from skimage import io8# For everything9import torch10import torch.nn as nn11import torch.nn.functional as F12# For our model13import torchvision.models as models14from torchvision import datasets, transforms15# For utilities16import os, shutil, time
1# Check if GPU is available2use_gpu = torch.cuda.is_available()

模型長什麼樣?

神經網絡裡面,第一部分是幾層用來提取圖像特徵

;第二部分是一些反捲積層 (Deconvolutional Layers) ,用來給那些特徵增加分辨率。

具體來說,第一部分用的是ResNet-18,這是一個圖像分類網絡,有18層,以及一些殘差連接 (Residual Connections) 。

給第一層做些修改,它就可以接受灰度圖像輸入了。然後把第6層之後的都去掉。

然後,用代碼來定義一下這個模型。

從神經網絡的第二部分 (就是那些上採樣層) 開始。

 1class ColorizationNet(nn.Module):2 def __init__(self, input_size=128):3 super(ColorizationNet, self).__init__()4 MIDLEVEL_FEATURE_SIZE = 12856 ## First half: ResNet7 resnet = models.resnet18(num_classes=365)8 # Change first conv layer to accept single-channel (grayscale) input9 resnet.conv1.weight = nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))10 # Extract midlevel features from ResNet-gray11 self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6])1213 ## Second half: Upsampling14 self.upsample = nn.Sequential( 15 nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),16 nn.BatchNorm2d(128),17 nn.ReLU(),18 nn.Upsample(scale_factor=2),19 nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),20 nn.BatchNorm2d(64),21 nn.ReLU(),22 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),23 nn.BatchNorm2d(64),24 nn.ReLU(),25 nn.Upsample(scale_factor=2),26 nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),27 nn.BatchNorm2d(32),28 nn.ReLU(),29 nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),30 nn.Upsample(scale_factor=2)31 )3233 def forward(self, input):3435 # Pass input through ResNet-gray to extract features36 midlevel_features = self.midlevel_resnet(input)3738 # Upsample to get colors39 output = self.upsample(midlevel_features)40 return output

下一步,創建模型吧。

1model = ColorizationNet()

它是怎麼訓練的?

預測每個像素的色值,用的是迴歸 (Regression) 的方法。

損失函數 (Loss Function)

所以,用了一個均方誤差 (MSE) 損失函數——讓預測的色值與參考標準 (Ground Truth) 之間的距離平方最小化。

1criterion = nn.MSELoss()

優化損失函數

這裡是用Adam Optimizer優化的。

1optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)

加載數據

用torchtext加載數據。首先定義一個專屬的數據加載器 (DataLoader) ,來完成RGB到LAB空間的轉換。

 1class GrayscaleImageFolder(datasets.ImageFolder):2 '''Custom images folder, which converts images to grayscale before loading'''3 def __getitem__(self, index):4 path, target = self.imgs[index]5 img = self.loader(path)6 if self.transform is not None:7 img_original = self.transform(img)8 img_original = np.asarray(img_original)9 img_lab = rgb2lab(img_original)10 img_lab = (img_lab + 128) / 25511 img_ab = img_lab[:, :, 1:3]12 img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()13 img_original = rgb2gray(img_original)14 img_original = torch.from_numpy(img_original).unsqueeze(0).float()15 if self.target_transform is not None:16 target = self.target_transform(target)17 return img_original, img_ab, target

再來,就是定義訓練數據驗證數據的轉換。

1# Training2train_transforms = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])3train_imagefolder = GrayscaleImageFolder('images/train', train_transforms)4train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=64, shuffle=True)56# Validation7val_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])8val_imagefolder = GrayscaleImageFolder('images/val' , val_transforms)9val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=64, shuffle=False)

輔助函數 (Helper Function)

訓練開始之前,要把輔助函數寫好,來追蹤訓練損失,並把圖像轉回RGB形式。

 1class AverageMeter(object):2 '''A handy class from the PyTorch ImageNet tutorial'''3 def __init__(self):4 self.reset()5 def reset(self):6 self.val, self.avg, self.sum, self.count = 0, 0, 0, 07 def update(self, val, n=1):8 self.val = val9 self.sum += val * n10 self.count += n11 self.avg = self.sum / self.count1213def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):14 '''Show/save rgb image from grayscale and ab channels15 Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''16 plt.clf() # clear matplotlib17 color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels18 color_image = color_image.transpose((1, 2, 0)) # rescale for matplotlib19 color_image[:, :, 0:1] = color_image[:, :, 0:1] * 10020 color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128 21 color_image = lab2rgb(color_image.astype(np.float64))22 grayscale_input = grayscale_input.squeeze().numpy()23 if save_path is not None and save_name is not None:24 plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')25 plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

驗證

不用反向傳播 (Back Propagation),直接用torch.no_grad() 跑模型。

 1def validate(val_loader, model, criterion, save_images, epoch):2 model.eval()34 # Prepare value counters and timers5 batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()67 end = time.time()8 already_saved_images = False9 for i, (input_gray, input_ab, target) in enumerate(val_loader):10 data_time.update(time.time() - end)1112 # Use GPU13 if use_gpu: input_gray, input_ab, target = input_gray.cuda(), input_ab.cuda(), target.cuda()1415 # Run model and record loss16 output_ab = model(input_gray) # throw away class predictions17 loss = criterion(output_ab, input_ab)18 losses.update(loss.item(), input_gray.size(0))1920 # Save images to file21 if save_images and not already_saved_images:22 already_saved_images = True23 for j in range(min(len(output_ab), 10)): # save at most 5 images24 save_path = {'grayscale': 'outputs/gray/', 'colorized': 'outputs/color/'}25 save_name = 'img-{}-epoch-{}.jpg'.format(i * val_loader.batch_size + j, epoch)26 to_rgb(input_gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)2728 # Record time to do forward passes and save images29 batch_time.update(time.time() - end)30 end = time.time()3132 # Print model accuracy -- in the code below, val refers to both value and validation33 if i % 25 == 0:34 print('Validate: [{0}/{1}]\t'35 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'36 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(37 i, len(val_loader), batch_time=batch_time, loss=losses))3839 print('Finished validation.')40 return losses.avg

訓練

用loss.backward(),用上反向傳播。寫一下訓練數據跑一遍 (one epoch) 用的函數。

 1def train(train_loader, model, criterion, optimizer, epoch):2 print('Starting training epoch {}'.format(epoch))3 model.train()45 # Prepare value counters and timers6 batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()78 end = time.time()9 for i, (input_gray, input_ab, target) in enumerate(train_loader):1011 # Use GPU if available12 if use_gpu: input_gray, input_ab, target = input_gray.cuda(), input_ab.cuda(), target.cuda()1314 # Record time to load data (above)15 data_time.update(time.time() - end)1617 # Run forward pass18 output_ab = model(input_gray)19 loss = criterion(output_ab, input_ab)20 losses.update(loss.item(), input_gray.size(0))2122 # Compute gradient and optimize23 optimizer.zero_grad()24 loss.backward()25 optimizer.step()2627 # Record time to do forward and backward passes28 batch_time.update(time.time() - end)29 end = time.time()3031 # Print model accuracy -- in the code below, val refers to value, not validation32 if i % 25 == 0:33 print('Epoch: [{0}][{1}/{2}]\t'34 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'35 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'36 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(37 epoch, i, len(train_loader), batch_time=batch_time,38 data_time=data_time, loss=losses))3940 print('Finished training epoch {}'.format(epoch))

然後,定義一個訓練迴路 (Training Loop) ,跑一百遍訓練數據。從Epoch 0開始訓練。

1# Move model and loss function to GPU2if use_gpu:3 criterion = criterion.cuda()4 model = model.cuda()
1# Make folders and set parameters2os.makedirs('outputs/color', exist_ok=True)3os.makedirs('outputs/gray', exist_ok=True)4os.makedirs('checkpoints', exist_ok=True)5save_images = True6best_losses = 1e107epochs = 100
 1# Train model2for epoch in range(epochs):3 # Train for one epoch, then validate4 train(train_loader, model, criterion, optimizer, epoch)5 with torch.no_grad():6 losses = validate(val_loader, model, criterion, save_images, epoch)7 # Save checkpoint and replace old best model if current model is better8 if losses < best_losses:9 best_losses = losses10 torch.save(model.state_dict(), 'checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epoch+1,losses))

訓練結果什麼樣?

是時候看看修煉成果了,所以,複製一下這段代碼。

 1# Show images2import matplotlib.image as mpimg3image_pairs = [('outputs/color/img-2-epoch-0.jpg', 'outputs/gray/img-2-epoch-0.jpg'),4 ('outputs/color/img-7-epoch-0.jpg', 'outputs/gray/img-7-epoch-0.jpg')]5for c, g in image_pairs:6 color = mpimg.imread(c)7 gray = mpimg.imread(g)8 f, axarr = plt.subplots(1, 2)9 f.set_size_inches(15, 15)10 axarr[0].imshow(gray, cmap='gray')11 axarr[1].imshow(color)12 axarr[0].axis('off'), axarr[1].axis('off')13 plt.show()

效果還是很自然的,雖然生成的彩色圖像不是那麼明麗。

盧克說,問題是多模態的,所以損失函數還是值得推敲。

比如,一條灰色裙子可以是藍色也可以是紅色。如果模型選擇的顏色和參考標準不同,就會受到嚴厲的懲罰。

這樣一來,模型就會選擇哪些不會被判為大錯特錯的顏色,而不太選擇非常顯眼明亮的顏色。

沒時間怎麼辦?

盧克還把一隻訓練好的AI放了出來,不想從零開始訓練的小夥伴們,也可以直接感受他的訓練成果,只要用以下代碼下載就好了。

1# Download pretrained model2!wget https://www.dropbox.com/s/kz76e7gv2ivmu8p/model-epoch-93.pth3#https://www.dropbox.com/s/9j9rvaw2fo1osyj/model-epoch-67.pth
1# Load model2pretrained = torch.load('model-epoch-93.pth', map_location=lambda storage, loc: storage)3model.load_state_dict(pretrained)
1# Validate2save_images = True3with torch.no_grad():4 validate(val_loader, model, criterion, save_images, 0)

彩色老電影?

如果想要更加有聲有色的結局,就不能繼續偷懶了。盧克希望大家沿著他精心鋪就的路,走到更遠的地方。

要替換當前的損失函數,可以參考Zhang et al. (2017):

https://richzhang.github.io/ideepcolor/

無監督學習的上色大法,可以參考Larsson et al. (2017):

http://people.cs.uchicago.edu/~larsson/color-proxy/

另外,可以做個手機應用,就像谷歌在I/O大會上發佈的著色軟件那樣。

黑白電影,也可以自己去嘗試,一幀一幀地上色。

這裡有卓別林用到的完整代碼

https://github.com/lukemelas/Automatic-Image-Colorization/

如何捕獲一隻彩色卓別林?黑白照片AI上色教程很友好|哈佛大觸

誠摯招聘

վ'ᴗ' ի 追蹤AI技術和產品新動態


分享到:


相關文章: