How to read this modified UNet?
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
class Model_Down(nn.Module):
"""
Convolutional (Downsampling) Blocks.
nd = Number of Filters
kd = Kernel size
"""
def __init__(self,in_channels, nd = 128, kd = 3, padding = 1, stride = 2):
super(Model_Down,self).__init__()
self.padder = nn.ReflectionPad2d(padding)
self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = nd, kernel_size = kd, stride = stride)
self.bn1 = nn.BatchNorm2d(nd)
self.conv2 = nn.Conv2d(in_channels = nd, out_channels = nd, kernel_size = kd, stride = 1)
self.bn2 = nn.BatchNorm2d(nd)
self.relu = nn.LeakyReLU()
def forward(self, x):
x = self.padder(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.padder(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
return x
class Model_Skip(nn.Module):
"""
Skip Connections
ns = Number of filters
ks = Kernel size
"""
def __init__(self,in_channels = 128, ns = 4, ks = 1, padding = 0, stride = 1):
super(Model_Skip, self).__init__()
self.conv = nn.Conv2d(in_channels = in_channels, out_channels = ns, kernel_size = ks, stride = stride, padding = padding)
self.bn = nn.BatchNorm2d(ns)
self.relu = nn.LeakyReLU()
def forward(self,x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Model_Up(nn.Module):
"""
Convolutional (Upsampling) Blocks.
nu = Number of Filters
ku = Kernel size
"""
def __init__(self, in_channels = 132, nu = 128, ku = 3, padding = 1):
super(Model_Up, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.padder = nn.ReflectionPad2d(padding)
self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = nu, kernel_size = ku, stride = 1, padding = 0)
self.bn2 = nn.BatchNorm2d(nu)
self.conv2 = nn.Conv2d(in_channels = nu, out_channels = nu, kernel_size = 1, stride = 1, padding = 0) #According to supmat.pdf ku = 1 for second layer
self.bn3 = nn.BatchNorm2d(nu)
self.relu = nn.LeakyReLU()
def forward(self,x):
x = self.bn1(x)
x = self.padder(x)
x = self.conv1(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn3(x)
x = self.relu(x)
x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
return x
class Model(nn.Module):
def __init__(self, length = 5, in_channels = 32, out_channels = 3, nu = [128,128,128,128,128] , nd =
[128,128,128,128,128], ns = [4,4,4,4,4], ku = [3,3,3,3,3], kd = [3,3,3,3,3], ks = [1,1,1,1,1]):
super(Model,self).__init__()
assert length == len(nu), 'Hyperparameters do not match network depth.'
self.length = length
self.downs = nn.ModuleList([Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i]) if i != 0 else
Model_Down(in_channels = in_channels, nd = nd[i], kd = kd[i]) for i in range(self.length)])
self.skips = nn.ModuleList([Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i]) for i in range(self.length)])
self.ups = nn.ModuleList([Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i]) if i != self.length-1 else
Model_Up(in_channels = ns[i], nu = nu[i], ku = ku[i]) for i in range(self.length-1,-1,-1)]) #Elements ordered backwards
self.conv_out = nn.Conv2d(nu[0],out_channels,1,padding = 0)
self.sigm = nn.Sigmoid()
def forward(self,x):
s = [] #Skip Activations
#Downpass
for i in range(self.length):
x = self.downs[i].forward(x)
s.append(self.skips[i].forward(x))
#Uppass
for i in range(self.length):
if (i == 0):
x = self.ups[i].forward(s[-1])
else:
x = self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
x = self.sigm(self.conv_out(x)) #Squash to RGB ([0,1]) format
return x
This code is a modified UNet that I am working with. I am having a hard time reading and understanding the code, in particular how the skip connections are wired into the upsampling path. Can anyone explain it, or write it in a simpler, easier-to-understand way without nn.ModuleList?
Could someone also show this network as a diagram?
Here is the github link to the repo I took this code from and am trying to understand.
Here is a functional equivalent of the main Model's forward(x) method. It is much more verbose, but it "unrolls" the flow of operations, which makes it easier to follow.
I assume the list arguments always have length 5 (i.e. i runs over [0, 4] inclusive), so that I can unroll them correctly (and that the default parameter set is used).
def unet_function(x, in_channels = 32, out_channels = 3, nu = [128,128,128,128,128],
nd = [128,128,128,128,128], ns = [4,4,4,4,4], ku = [3,3,3,3,3],
kd = [3,3,3,3,3], ks = [1,1,1,1,1]):
################################
# DOWN PASS ####################
################################
#########
# i = 0 #
#########
# First Down
# Model_Down(in_channels = in_channels, nd = nd[i], kd = kd[i])
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=in_channels, out_channels=nd[0], kernel_size=kd[0], stride=2)(x)
x = nn.BatchNorm2d(nd[0])(x)
x = nn.LeakyReLU()(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[0], out_channels=nd[0], kernel_size=kd[0], stride=1)(x)
x = nn.BatchNorm2d(nd[0])(x)
x = nn.LeakyReLU()(x)
# First skip
# Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
s0 = nn.Conv2d(in_channels=nd[0], out_channels=ns[0], kernel_size=ks[0])(x)
s0 = nn.BatchNorm2d(ns[0])(s0)
s0 = nn.LeakyReLU()(s0)
#########
# i = 1 #
#########
# Second Down
# Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i])
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[0], out_channels=nd[1], kernel_size=kd[1], stride=2)(x)
x = nn.BatchNorm2d(nd[1])(x)
x = nn.LeakyReLU()(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[1], out_channels=nd[1], kernel_size=kd[1], stride=1)(x)
x = nn.BatchNorm2d(nd[1])(x)
x = nn.LeakyReLU()(x)
# Second skip
# Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
s1 = nn.Conv2d(in_channels=nd[1], out_channels=ns[1], kernel_size=ks[1])(x)
s1 = nn.BatchNorm2d(ns[1])(s1)
s1 = nn.LeakyReLU()(s1)
#########
# i = 2 #
#########
# Third Down
# Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i])
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[1], out_channels=nd[2], kernel_size=kd[2], stride=2)(x)
x = nn.BatchNorm2d(nd[2])(x)
x = nn.LeakyReLU()(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[2], out_channels=nd[2], kernel_size=kd[2], stride=1)(x)
x = nn.BatchNorm2d(nd[2])(x)
x = nn.LeakyReLU()(x)
# Third skip
# Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
s2 = nn.Conv2d(in_channels=nd[2], out_channels=ns[2], kernel_size=ks[2])(x)
s2 = nn.BatchNorm2d(ns[2])(s2)
s2 = nn.LeakyReLU()(s2)
#########
# i = 3 #
#########
# Fourth Down
# Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i])
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[2], out_channels=nd[3], kernel_size=kd[3], stride=2)(x)
x = nn.BatchNorm2d(nd[3])(x)
x = nn.LeakyReLU()(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[3], out_channels=nd[3], kernel_size=kd[3], stride=1)(x)
x = nn.BatchNorm2d(nd[3])(x)
x = nn.LeakyReLU()(x)
# Fourth skip
# Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
s3 = nn.Conv2d(in_channels=nd[3], out_channels=ns[3], kernel_size=ks[3])(x)
s3 = nn.BatchNorm2d(ns[3])(s3)
s3 = nn.LeakyReLU()(s3)
#########
# i = 4 #
#########
# Fifth Down
# Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i])
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[3], out_channels=nd[4], kernel_size=kd[4], stride=2)(x)
x = nn.BatchNorm2d(nd[4])(x)
x = nn.LeakyReLU()(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=nd[4], out_channels=nd[4], kernel_size=kd[4], stride=1)(x)
x = nn.BatchNorm2d(nd[4])(x)
x = nn.LeakyReLU()(x)
# Fifth skip
# Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
x = nn.Conv2d(in_channels=nd[4], out_channels=ns[4], kernel_size=ks[4])(x)
x = nn.BatchNorm2d(ns[4])(x)
x = nn.LeakyReLU()(x)
################################
# UP PASS ######################
################################
#########
# i = 4 #
#########
# First Up
# Model_Up(in_channels = ns[i], nu = nu[i], ku = ku[i])
x = nn.BatchNorm2d(ns[4])(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=ns[4], out_channels=nu[4], kernel_size=ku[4], stride=1, padding=0)(x)
x = nn.BatchNorm2d(nu[4])(x)
x = nn.LeakyReLU()(x)
x = nn.Conv2d(in_channels = nu[4], out_channels=nu[4], kernel_size = 1, stride = 1, padding = 0)(x)
x = nn.BatchNorm2d(nu[4])(x)
x = nn.LeakyReLU()(x)
x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
#########
# i = 3 #
#########
# Second Up
# self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
x = torch.cat([x,s3], axis=1) # IMPORTANT HERE
# Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i])
x = nn.BatchNorm2d(ns[3]+nu[4])(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=ns[3]+nu[4], out_channels=nu[3], kernel_size=ku[3], stride=1, padding=0)(x)
x = nn.BatchNorm2d(nu[3])(x)
x = nn.LeakyReLU()(x)
x = nn.Conv2d(in_channels=nu[3], out_channels=nu[3], kernel_size=1, stride=1, padding=0)(x)
x = nn.BatchNorm2d(nu[3])(x)
x = nn.LeakyReLU()(x)
x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
#########
# i = 2 #
#########
# Third Up
# self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
x = torch.cat([x,s2], axis=1) # IMPORTANT HERE
# Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i])
x = nn.BatchNorm2d(ns[2]+nu[3])(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=ns[2]+nu[3], out_channels=nu[2], kernel_size=ku[2], stride=1, padding=0)(x)
x = nn.BatchNorm2d(nu[2])(x)
x = nn.LeakyReLU()(x)
x = nn.Conv2d(in_channels=nu[2], out_channels=nu[2], kernel_size=1, stride=1, padding=0)(x)
x = nn.BatchNorm2d(nu[2])(x)
x = nn.LeakyReLU()(x)
x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
#########
# i = 1 #
#########
# Fourth Up
# self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
x = torch.cat([x,s1], axis=1) # IMPORTANT HERE
# Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i])
x = nn.BatchNorm2d(ns[1]+nu[2])(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=ns[1]+nu[2], out_channels=nu[1], kernel_size=ku[1], stride=1, padding=0)(x)
x = nn.BatchNorm2d(nu[1])(x)
x = nn.LeakyReLU()(x)
x = nn.Conv2d(in_channels=nu[1], out_channels=nu[1], kernel_size=1, stride=1, padding=0)(x)
x = nn.BatchNorm2d(nu[1])(x)
x = nn.LeakyReLU()(x)
x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
#########
# i = 0 #
#########
# Fifth Up
# self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
x = torch.cat([x,s0], axis=1) # IMPORTANT HERE
# Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i])
x = nn.BatchNorm2d(ns[0]+nu[1])(x)
x = nn.ReflectionPad2d(padding=1)(x)
x = nn.Conv2d(in_channels=ns[0]+nu[1], out_channels=nu[0], kernel_size=ku[0], stride=1, padding=0)(x)
x = nn.BatchNorm2d(nu[0])(x)
x = nn.LeakyReLU()(x)
x = nn.Conv2d(in_channels = nu[0], out_channels=nu[0], kernel_size = 1, stride = 1, padding = 0)(x)
x = nn.BatchNorm2d(nu[0])(x)
x = nn.LeakyReLU()(x)
x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
################################
# OUT ##########################
################################
x = nn.Conv2d(in_channels=nu[0], out_channels=out_channels, kernel_size=1, padding=0)(x)
return nn.Sigmoid()(x) #Squash to RGB ([0,1]) format
The two most important parts are the skips, where the tensor x is processed in a parallel branch of the code without disturbing the main x "pathway", and the points where the tensors produced by those skip branches are fed back into the "main pathway", starting from the last one (the torch.cat calls marked IMPORTANT HERE). I kept these tensors as separate variables s0 to s3, which makes this more obvious.
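To make the channel bookkeeping concrete, here is a minimal sketch (assuming the default hyperparameters, i.e. ns = 4 skip channels and nu = 128 decoder filters; the tensor sizes are only illustrative) of what each torch.cat line marked IMPORTANT HERE does:
import torch

up_out = torch.randn(1, 128, 64, 64)  # x after a Model_Up block (nu = 128 channels)
skip = torch.randn(1, 4, 64, 64)      # a stored skip activation, e.g. s2 (ns = 4 channels)

# Stack along the channel dimension before entering the next Model_Up block.
merged = torch.cat([up_out, skip], dim=1)
print(merged.shape)  # torch.Size([1, 132, 64, 64]) -> ns + nu = 132, Model_Up's default in_channels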
From the diagram it is clear that the earlier sections feed the later ones: s0 is the longest grey arrow, and it joins the "main pathway" just before the last group of convolutional layers. (It is not exactly the same U-Net.)
This also shows why we never need to store an s4: it feeds straight into the next layer, so there is no need to keep it as a separate variable.
The Module version does store it, but only because it is convenient to keep all the skip outputs in a single list that is read back in reverse order at the end. The other obvious reason for storing them in a list is that it lets you have an arbitrary number of Up and Down sections simply by changing the parameters.
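If you really want to avoid nn.ModuleList, here is a minimal sketch of my own (not from the linked repo) that writes the default depth of 5 out as plain named attributes, reusing the Model_Down / Model_Skip / Model_Up blocks from the question; the class name ModelExplicit and the hard-coded channel numbers are illustrative assumptions based on the default parameters:
import torch
import torch.nn as nn

class ModelExplicit(nn.Module):
    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        # Encoder: five downsampling blocks (all 128 filters by default).
        self.down0 = Model_Down(in_channels=in_channels)
        self.down1 = Model_Down(in_channels=128)
        self.down2 = Model_Down(in_channels=128)
        self.down3 = Model_Down(in_channels=128)
        self.down4 = Model_Down(in_channels=128)
        # Skip projections: 128 -> 4 channels each.
        self.skip0 = Model_Skip()
        self.skip1 = Model_Skip()
        self.skip2 = Model_Skip()
        self.skip3 = Model_Skip()
        self.skip4 = Model_Skip()
        # Decoder: the deepest Up block sees only the 4 skip channels,
        # the others see 4 skip channels + 128 upsampled decoder channels.
        self.up4 = Model_Up(in_channels=4)
        self.up3 = Model_Up(in_channels=4 + 128)
        self.up2 = Model_Up(in_channels=4 + 128)
        self.up1 = Model_Up(in_channels=4 + 128)
        self.up0 = Model_Up(in_channels=4 + 128)
        self.conv_out = nn.Conv2d(128, out_channels, 1, padding=0)
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        # Down pass: keep only the small skip projections, not the full feature maps.
        x = self.down0(x)
        s0 = self.skip0(x)
        x = self.down1(x)
        s1 = self.skip1(x)
        x = self.down2(x)
        s2 = self.skip2(x)
        x = self.down3(x)
        s3 = self.skip3(x)
        x = self.down4(x)
        # The deepest skip feeds straight into the first Up block, so it is never stored.
        x = self.up4(self.skip4(x))
        # Up pass: fold the stored skips back in, deepest first.
        x = self.up3(torch.cat([x, s3], dim=1))
        x = self.up2(torch.cat([x, s2], dim=1))
        x = self.up1(torch.cat([x, s1], dim=1))
        x = self.up0(torch.cat([x, s0], dim=1))
        return self.sigm(self.conv_out(x))

With the defaults the input needs 32 channels and a spatial size divisible by 32 (five stride-2 halvings); for example, ModelExplicit()(torch.randn(1, 32, 512, 512)) returns a tensor of shape (1, 3, 512, 512), just like the original Model.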