機器學習:使用TorchFusion和PyTorch簡單介紹條件GAN

人類非常善於識別事物,也善於創造新事物。很長時間以來,我們一直在教授計算機,以模仿人類識別事物的能力,但人工智能系統很長時間都沒有具備創造新事物的能力。直到2014年,伊恩•古德費勒(Ian Goodfellow)才發明了生成對抗網絡。在這篇文章中,我們將介紹生成對抗網絡的基本概述,並使用它們生成特定數字的圖像。

生成對抗網絡概述

假設你是一位藝術家,試圖畫出一幅非常逼真的奧巴馬畫像,這會讓法官誤以為這是一幅真實的畫。當你第一次這麼做的時候,裁判很容易發現你的照片是假的,然後你反覆嘗試,直到裁判被愚弄,以為照片是真的。生成對抗網絡就是這樣工作的,它由兩個模型組成,

一種繪製圖像的生成器和一種試圖區分真實圖像和由生成器繪製圖像的判別器。

從某種意義上講,兩者都是相互競爭的,生成器被訓練用來愚弄判別器,而判別器被訓練用來正確區分哪些圖像是真實的,哪些是生成的。最後,生成器將變得如此完美,以至於鑑別器將無法區分真實圖像和生成圖像。

下面是GAN生成器創建的示例。

機器學習:使用TorchFusion和PyTorch簡單介紹條件GAN

GAN有兩個通用類,即隨機生成任何類圖像的無條件GAN和生成特定類的條件GAN。在本教程中,我們將使用條件gans,因為它們允許我們指定我們想要生成的內容。

Python實現

Tranining GAN通常很複雜,但是由於基於PyTorch的研究框架Torchfusion,這個過程將非常簡單且非常簡單。

通過PyPi安裝Torchfusion

pip3 install torchfusion

安裝PyTorch

如果您還沒有安裝PyTorch,請訪問pytorch.org以獲取PyTorch的最新安裝二進制文件。

接下來,導入幾個Python需要的庫

from torchfusion.gan.learners import *

from torchfusion.gan.applications import StandardGenerator,StandardProjectionDiscriminator

from torch.optim import Adam

from torchfusion.datasets import mnist_loader

import torch.cuda as cuda

import torch.nn as nn

import torch

from torch.distributions import Normal

定義生成器網絡和判別器,Python如下:

G = StandardGenerator(output_size=(1,32,32),latent_size=128)

D = StandardProjectionDiscriminator(input_size=(1,32,32),apply_sigmoid=False)

if cuda.is_available():

G = nn.DataParallel(G.cuda())

D = nn.DataParallel(D.cuda())

在上文中,我們將要生成的圖像的分辨率指定為1 x 32 x 32。

為Generator和Discriminator模型設置優化器

g_optim = Adam(G.parameters(),lr=0.0002,betas=(0.5,0.999))

d_optim = Adam(D.parameters(),lr=0.0002,betas=(0.5,0.999))

現在我們需要加載一個數據集,我們將嘗試從中提取樣本。在這種情況下,我們將使用MNIST。

dataset = mnist_loader(size=32,batch_size=64)

下面在我們創造一個Learner,torchfusion融合有各種各樣的Learner,他們高度專業化,有不同的目的。

learner = RAvgStandardGanLearner(G,D)

現在,我們可以調用訓練函數來訓練這兩個模型

if __name__ == "__main__":

learner.train(dataset,gen_optimizer=g_optim,disc_optimizer=d_optim,save_outputs_interval=500,model_dir="./fashion-gan",latent_size=128,num_epochs=50,batch_log=False)

通過將save_outputs_interval指定為500,learner將在每500次batch迭代後顯示樣本生成的輸出。

這是完整的Python代碼

from torchfusion.gan.learners import *

from torchfusion.gan.applications import StandardGenerator,StandardProjectionDiscriminator

from torch.optim import Adam

from torchfusion.datasets import mnist_loader

import torch.cuda as cuda

import torch.nn as nn

import torch

from torch.distributions import Normal

G = StandardGenerator(output_size=(1,32,32),latent_size=128,num_classes=10)

D = StandardProjectionDiscriminator(input_size=(1,32,32),apply_sigmoid=False,num_classes=10)

if cuda.is_available():

G = nn.DataParallel(G.cuda())

D = nn.DataParallel(D.cuda())

g_optim = Adam(G.parameters(),lr=0.0002,betas=(0.5,0.999))

d_optim = Adam(D.parameters(),lr=0.0002,betas=(0.5,0.999))

dataset = mnist_loader(size=32,batch_size=64)

learner = RAvgStandardGanLearner(G,D)

if __name__ == "__main__":

learner.train(dataset,num_classes=10,gen_optimizer=g_optim,disc_optimizer=d_optim,save_outputs_interval=500,model_dir="./MNIST-gan",latent_size=128,num_epochs=50,batch_log=False)

經過20個訓練epochs後,生成下圖:

機器學習:使用TorchFusion和PyTorch簡單介紹條件GAN

現在到最激動人心的部分,使用訓練有素的模型,您可以輕鬆生成特定數字的新圖像。

在下面的代碼中,我們生成了一個數字6的新圖像,您可以指定0到9之間的任何數字,Python代碼如下:

from torchfusion.gan.learners import *

from torchfusion.gan.applications import StandardGenerator

import torch.cuda as cuda

import torch.nn as nn

from torchvision.utils import save_image, make_grid

import torch

from torch.distributions import Normal

import numpy as np

import matplotlib.pyplot as plt

G = StandardGenerator(output_size=(1,32,32),latent_size=128,num_classes=10)

if cuda.is_available():

G = nn.DataParallel(G.cuda())

learner = RAvgHingeGanLearner(G,None)

learner.load_generator("gen_model_17.pth")

if __name__ == "__main__":

"Define an instance of the normal distribution"

dist = Normal(0,1)

#Get a sample latent vector from the distribution

latent_vector = dist.sample((1,128))

#Define the class of the image you want to generate

label = torch.LongTensor(1).fill_(6)

#Run inference

image = learner.predict([latent_vector,label])

images = make_grid(image.cpu().data, normalize=True)

images = np.transpose(images.numpy(), (1, 2, 0))

plt.axis("off")

plt.imshow(images)

plt.show()

結果:

機器學習:使用TorchFusion和PyTorch簡單介紹條件GAN

Generative Adversarial Networks是一個令人興奮的研究領域,通過優化GAN算法的優化實現,Torchfusion使其變得非常簡單。


分享到:


相關文章: