Tensorboard+PyTorch牛刀小試

TensorBoard是Tensorflow附帶的神經網絡可視化分析工具,其功能十分強大。作為Pytorch用戶,很幸運Tensorboard同樣支持PyTorch。這裡給出個人最近使用Tensorboard時的一些經驗。

1.14版本之後的Pytorch內已經有Tensorboard接口,但是仍然需要我們手動安裝Tensorboard才可以調用:

<code>

pip

install tensorboard/<code>

Residual Block示例

我們以殘差網絡模塊為示例演示Tensorboard可視化工具,我的主要目的是為了觀察驗證複雜神經網絡的結構,下文主要展示網絡結構的可視化功能:

<code>

class

Residual

(

nn

.

Module

):

def

__init__

(

self

, ins, outs)

:

super

(Residual,

self

).__init_

_

() inner = int(outs /

2

)

self

.convBlock = nn.Sequential( nn.BatchNorm2d(ins), nn.ReLU(inplace=True), nn.Conv2d(ins, inner,

1

), nn.BatchNorm2d(inner), nn.ReLU(inplace=True), nn.Conv2d(inner, inner,

3

,

1

,

1

), nn.BatchNorm2d(inner), nn.ReLU(inplace=True), nn.Conv2d(inner, outs,

1

) )

if

ins !=

outs:

self

.skipConv = nn.Conv2d(ins, outs,

1

)

self

.ins = ins

self

.outs = outs

def

forward

(

self

, x)

: residual = x x =

self

.convBlock(x)

if

self

.ins !=

self

.

outs:

residual =

self

.skipConv(residual) x += residual

return

x/<code>

構建殘差網絡之後,我們以輸入三通道輸出12通道的殘差結構為例,給定隨機輸入dummy_input,

<code>

import

os

from

torch.utils.tensorboard import SummaryWriter

path

=

os.getcwd()+'\\runs'

writer

=

SummaryWriter(path+'\\experiment1')

t

=

Residual(3,12)

dummy_input

=

torch.randn(10, 3, 128, 128)

dummy_input)

/<code>

利用Tensorboard進行可視化

Tensorboard需要我們從命令行另外調用,

<code>tensorboard /<code>

注意這裡的path需要和你上步中summarywriter的路徑path保持一致。

Ps:這裡的步驟是在windows平臺下完成的,記得在命令行內輸入上述命令是不要加雙引號。隨後利用自帶瀏覽器打開 (http://localhost:6006/)需要注意,如果chrome瀏覽器版本過低(<60)可能會出現錯誤。

Tensorboard一覽

Tensorboard+PyTorch牛刀小試

進入Tensorboard之後就可以看到可視化的網絡結構,雙擊可以放大子模塊,可以看到更清晰的殘差結構。

Tensorboard+PyTorch牛刀小試

求關注點贊評論呀~


分享到:


相關文章: