生成对抗网络的自调制.
本文针对GAN提出了一种自调制(Self-Modulation)模块,用于增强训练过程中的稳定性。
自调制模块受条件GAN结构中条件BatchNorm的启发,把网络中的BN替换为条件BN。
条件BN是指对输入特征$h$沿通道维度进行归一化后,由外部输入决定仿射参数$\gamma,\beta$,并进行反归一化:
\[h'_i = \gamma \frac{h_i-\mu(h_i)}{\sigma(h_i)} - \beta\]在自调制模块中,仿射参数$\gamma,\beta$是由生成器的输入噪声$z$构造的。
仿射参数$\gamma,\beta$可以通过全连接网络实现,并且中间层的维度可以取得更小一些,比如$32$,不会明显增加参数量。
######################################
# MLP (predicts Affine parameters)
######################################
class MLP(nn.Module):
def __init__(self, input_dim, output_dim, dim=32):
super(MLP, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, dim),
nn.Linear(dim, output_dim),
)
def forward(self, x):
return self.model(x)
应用自调制模块的网络生成器结构如下:
######################################
# Self-Modulation module
######################################
class SelfMod2d(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(SelfMod2d, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
# weight and bias are dynamically assigned
self.weight = None # [1, c]
self.bias = None # [1, c]
self.bn = nn.BatchNorm2d(
self.num_features, eps=1e-5,
momentum=0.1, affine=False,
)
def forward(self, x):
assert (
self.weight is not None and self.bias is not None
), "Please assign weight and bias before calling SPADE!"
# Apply batch norm
out = self.bn(out)
return out*self.weight + self.bias
#################################
# Model
#################################
class Model(nn.Module):
def __init__(self, input_channel):
super(Model, self).__init__()
# 定义包含Self-Modulation的主体网络
self.model = nn.Sequential()
# 定义生成Self-Modulation参数的网络
num_selfmod_params = self.get_num_selfmod_params()
self.mlp = MLP(input_channel, num_selfmod_params)
def get_num_selfmod_params(self):
"""Return the number of SelfMod2d parameters needed by the model"""
num_selfmod_params = 0
for m in self.modules():
if m.__class__.__name__ == "SelfMod2d":
num_selfmod_params += 2 * m.num_features
return num_selfmod_params
def assign_selfmod_params(self, selfmod_params):
"""Assign the selfmod_params to the SelfMod2d layers in model"""
for m in self.modules():
if m.__class__.__name__ == "SelfMod2d":
# Extract weight and bias predictions
m.weight = selfmod_params[:, : m.num_features, :, :].contiguous()
m.bias = selfmod_params[:, m.num_features : 2 * m.num_features, :, :].contiguous()
# Move pointer
if selfmod_params.size(1) > 2*m.num_features:
selfmod_params = selfmod_params[:, 2*m.num_features:, :, :]
def forward(self, z):
# Update SelfMod2d parameters by ConvLayer prediction based off conditional input
self.assign_selfmod_params(self.mlp(z))
out = self.model(z)
return out