Self-Modulation for Generative Adversarial Networks

This paper proposes a Self-Modulation module for GANs, designed to improve the stability of the training process.

The Self-Modulation module is inspired by the conditional BatchNorm used in conditional GAN architectures: it replaces the BN layers in the generator with conditional BN.

Conditional BN normalizes the input feature $h$ along the channel dimension and then de-normalizes it with affine parameters $\gamma,\beta$ supplied by an external input:

\[h'_i = \gamma \frac{h_i-\mu(h_i)}{\sigma(h_i)} + \beta\]

In the Self-Modulation module, the affine parameters $\gamma,\beta$ are constructed from the generator's input noise $z$.
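
Concretely, each parameter can be produced by a small one-hidden-layer network applied to $z$; one possible parameterization (the weight names $U$, $V$, $b$ are illustrative, with a shared hidden layer as in the code below) is

\[\gamma(z) = V^{(\gamma)}\,\mathrm{ReLU}(U z + b), \qquad \beta(z) = V^{(\beta)}\,\mathrm{ReLU}(U z + b)\]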

The affine parameters $\gamma,\beta$ can be implemented with a small fully-connected network; the hidden dimension can be kept small (e.g., $32$), so the module adds only a negligible number of parameters.

import torch
import torch.nn as nn


######################################
#   MLP (predicts affine parameters)
######################################
class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, dim=32):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, dim),
            # nonlinearity so the two linear layers do not collapse into one
            nn.ReLU(inplace=True),
            nn.Linear(dim, output_dim),
            )

    def forward(self, x):
        return self.model(x)
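
For example (the sizes here are hypothetical), modulating a single 256-channel feature map from a 128-dimensional noise vector requires the MLP to output $2 \times 256 = 512$ values per sample:

mlp = MLP(input_dim=128, output_dim=512)
z = torch.randn(8, 128)  # a batch of 8 noise vectors
params = mlp(z)          # [8, 512]: 256 gammas followed by 256 betas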

The generator applying the Self-Modulation module is structured as follows:

######################################
#   Self-Modulation module
######################################
class SelfMod2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(SelfMod2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned from the MLP output,
        # shaped [B, C, 1, 1] so they broadcast over the feature map
        self.weight = None
        self.bias = None
        # parameter-free BN; the affine step is applied manually in forward
        self.bn = nn.BatchNorm2d(
            self.num_features, eps=self.eps,
            momentum=self.momentum, affine=False,
            )

    def forward(self, x):
        assert (
            self.weight is not None and self.bias is not None
        ), "Please assign weight and bias before calling SelfMod2d!"
        # Normalize without affine parameters, then apply the z-dependent modulation
        out = self.bn(x)
        return out * self.weight + self.bias
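
As a quick standalone check (all shapes hypothetical), the module can be exercised by assigning the affine parameters by hand before calling it:

sm = SelfMod2d(num_features=64)
x = torch.randn(4, 64, 16, 16)
sm.weight = torch.ones(4, 64, 1, 1)   # gamma
sm.bias = torch.zeros(4, 64, 1, 1)    # beta
y = sm(x)  # normalized then modulated, shape [4, 64, 16, 16]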


#################################
#            Model
#################################
class Model(nn.Module):
    def __init__(self, input_channel):
        super(Model, self).__init__()
        # Body network containing the SelfMod2d layers (fill in the generator blocks here)
        self.model = nn.Sequential()
        # Network that predicts the Self-Modulation parameters from z
        num_selfmod_params = self.get_num_selfmod_params()
        self.mlp = MLP(input_channel, num_selfmod_params)

    def get_num_selfmod_params(self):
        """Return the number of SelfMod2d parameters needed by the model"""
        num_selfmod_params = 0
        for m in self.modules():
            if m.__class__.__name__ == "SelfMod2d":
                num_selfmod_params += 2 * m.num_features
        return num_selfmod_params

    def assign_selfmod_params(self, selfmod_params):
        """Assign the selfmod_params to the SelfMod2d layers in model."""
        # selfmod_params: [B, total], the concatenated (gamma, beta) of every layer
        for m in self.modules():
            if m.__class__.__name__ == "SelfMod2d":
                # Extract weight and bias predictions and reshape them to
                # [B, C, 1, 1] so they broadcast over the spatial dimensions
                b = selfmod_params.size(0)
                gamma = selfmod_params[:, : m.num_features]
                beta = selfmod_params[:, m.num_features : 2 * m.num_features]
                m.weight = gamma.contiguous().view(b, -1, 1, 1)
                m.bias = beta.contiguous().view(b, -1, 1, 1)
                # Move pointer to the next layer's parameters
                if selfmod_params.size(1) > 2 * m.num_features:
                    selfmod_params = selfmod_params[:, 2 * m.num_features:]

    def forward(self, z):
        # Predict the SelfMod2d affine parameters from the input noise z via the MLP
        self.assign_selfmod_params(self.mlp(z))
        out = self.model(z)
        return out
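
The Model above is only a skeleton: self.model is an empty nn.Sequential. As a sketch of how the pieces fit together, here is a toy generator in the same pattern; the name ToyGenerator and all layer sizes are illustrative choices, not from the paper:

#################################
#     Toy usage example
#################################
class ToyGenerator(nn.Module):
    def __init__(self, z_dim=128):
        super(ToyGenerator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 64, kernel_size=4),                     # 1x1 -> 4x4
            SelfMod2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
            SelfMod2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Tanh(),
            )
        # total modulation parameters: 2 * num_features per SelfMod2d layer
        num_params = sum(
            2 * m.num_features for m in self.modules()
            if isinstance(m, SelfMod2d)
            )
        self.mlp = MLP(z_dim, num_params)

    def forward(self, z):
        # z: [B, z_dim]; first predict and assign every layer's gamma and beta
        params = self.mlp(z)
        for m in self.modules():
            if isinstance(m, SelfMod2d):
                b = params.size(0)
                m.weight = params[:, : m.num_features].contiguous().view(b, -1, 1, 1)
                m.bias = params[:, m.num_features : 2 * m.num_features].contiguous().view(b, -1, 1, 1)
                params = params[:, 2 * m.num_features:]
        # treat the noise vector as a 1x1 spatial map for the conv body
        return self.model(z.view(z.size(0), -1, 1, 1))


g = ToyGenerator()
img = g(torch.randn(8, 128))  # [8, 3, 8, 8]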