通过空间自适应归一化进行语义图像合成.
本文提出了一种语义图像合成方法SPADE,能够将语义分割mask图像转换为真实图像。作者发现通常的归一化层倾向于“洗掉”输入语义mask图像中包含的信息,因此提出了空间自适应归一化(spatially-adaptive normalization, SPADE)层,它通过空间自适应地学习和转换输入语义mask图像的信息。
SPADE是一种条件归一化方法。条件归一化是指对特征进行归一化后,通过外部数据的统计信息进行反归一化;一般操作如下:首先将该层特征归一化为零均值和单位方差,然后通过使用从外部数据学习的参数进行仿射变换对归一化特征进行反归一化。
\[x = \gamma \cdot \frac{x - \mu(x)}{\sigma(x)}+\beta\]SPADE采用的归一化形式为BatchNorm,即沿着特征的每一个通道维度进行归一化。仿射变换参数$\gamma,\beta$不是标量,而是与空间位置有关的向量$\gamma_{c,x,y},\beta_{c,x,y}$,并由输入语义mask图像通过两层卷积层构造。
作者认为SPADE能够更好地保留输入语义mask图像的语义信息。考虑到mask图像由几个均匀区域组成。如果直接对其使用InstanceNorm,则会丢失语义信息。而SPADE的分割mask是通过仿射变换提供的,没有对其进行归一化,只对前一层特征进行归一化。因此,SPADE可以更好地保留语义信息。
SPADE生成器将随机噪声作为输入,通过带有SPADE的残差块生成图像。由于每个残差块包含上采样层,因此对语义mask进行下采样以匹配空间分辨率。通过多尺度判别器构造对抗损失,形式为hinge损失;此外还使用了重构损失。
除了使用随机向量作为生成器的输入,也可以通过一个编码器将真实图像转换为随机向量,将其作为生成器的输入,共同形成一个VAE结构,其中编码器试图捕获图像的风格,而生成器通过SPADE将编码风格和分割mask信息结合起来生成图像。此时损失函数还包括VAE引入的KL散度项。
向网络中加入SPADE层的参考代码实现如下:
######################################
# ConvLayer (predicts SPADE parameters)
######################################
class ConvLayer(nn.Module):
def __init__(self, input_dim, output_dim, dim=128):
super(ConvLayer, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(input_dim, dim, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(dim, output_dim, 3, 1, 1),
)
def forward(self, x):
return self.model(x)
######################################
# SPADE module
######################################
class SPADE2d(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(SPADE2d, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
# weight and bias are dynamically assigned
self.weight = None # [1, c, h, w]
self.bias = None # [1, c, h, w]
self.bn = nn.BatchNorm2d(
self.num_features, eps=1e-5,
momentum=0.1, affine=False,
)
def forward(self, x):
assert (
self.weight is not None and self.bias is not None
), "Please assign weight and bias before calling SPADE!"
# Apply batch norm
out = self.bn(out)
return out*self.weight + self.bias
def __repr__(self):
return self.__class__.__name__ + "(" + str(self.num_features) + ")"
#################################
# Model
#################################
class Model(nn.Module):
def __init__(self, ):
super(Model, self).__init__()
# 定义包含SPADE的主体网络
self.model = nn.Sequential()
# 定义生成SPADE参数的网络
num_spade_params = self.get_num_spade_params()
self.conv = ConvLayer(input_channel, num_spade_params)
def get_num_spade_params(self):
"""Return the number of SPADE parameters needed by the model"""
num_spade_params = 0
for m in self.modules():
if m.__class__.__name__ == "SPADE2d":
num_spade_params += 2 * m.num_features
return num_spade_params
def assign_spade_params(self, spade_params):
"""Assign the spade_params to the SPADE layers in model"""
for m in self.modules():
if m.__class__.__name__ == "SPADE2d":
# Extract weight and bias predictions
m.weight = spade_params[:, : m.num_features, :, :].contiguous()
m.bias = spade_params[:, m.num_features : 2 * m.num_features, :, :].contiguous()
# Move pointer
if spade_params.size(1) > 2*m.num_features:
spade_params = spade_params[:, 2*m.num_features:, :, :]
def forward(self, main_input, cond_input):
# Update SPADE parameters by ConvLayer prediction based off conditional input
self.assign_spade_params(self.conv(cond_input))
out = self.model(main_input)
return out