本文提出了一种语义图像合成方法SPADE,能够将语义分割mask图像转换为逼真的图像。作者发现,常用的归一化层倾向于“洗掉”输入语义mask图像中包含的语义信息,因此提出了空间自适应归一化(spatially-adaptive normalization, SPADE)层:它根据输入的语义mask图像,以空间自适应的方式学习归一化层的仿射变换参数,从而在归一化之后重新注入语义信息。
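\[y_{n,c,i,j} = \gamma_{c,i,j}(m) \cdot \frac{x_{n,c,i,j} - \mu_c}{\sigma_c} + \beta_{c,i,j}(m)\]
SPADE采用的归一化形式为BatchNorm,即对每一个通道$c$,在batch和空间维度上计算均值$\mu_c$与标准差$\sigma_c$并进行归一化。与普通BatchNorm不同,仿射变换参数$\gamma,\beta$不是每个通道一个标量,而是随通道和空间位置$(i,j)$变化的张量$\gamma_{c,i,j}(m),\beta_{c,i,j}(m)$,它们由输入语义mask图像$m$经过两层卷积层预测得到。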
# ConvLayer (predicts SPADE parameters)
class ConvLayer(nn.Module):
    """Two-layer convolutional network that maps a conditional input to SPADE parameters.

    Both convolutions are 3x3 with stride 1 and padding 1, so the output keeps
    the spatial size of the input; only the channel count changes
    (``input_dim -> dim -> output_dim``).
    """

    def __init__(self, input_dim, output_dim, dim=128):
        """
        Args:
            input_dim: number of channels of the conditional input.
            output_dim: number of predicted channels (for ``Model`` this is the
                total number of SPADE weight + bias channels).
            dim: number of hidden channels between the two convolutions.
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_dim, dim, 3, 1, 1),
            # Without a non-linearity the two convolutions would compose into a
            # single linear convolution; the SPADE paper uses ReLU here.
            nn.ReLU(),
            nn.Conv2d(dim, output_dim, 3, 1, 1),
        )

    def forward(self, x):
        """Return the predicted parameter map for ``x``."""
        return self.model(x)
# SPADE module
class SPADE2d(nn.Module):
    """Spatially-adaptive normalization layer.

    The input is normalized by a ``BatchNorm2d`` without learnable affine
    parameters, then modulated element-wise as ``normalized * weight + bias``.
    ``weight`` and ``bias`` are not learned by this layer: they must be assigned
    from outside before every forward call (``Model.assign_spade_params`` does this).
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        """
        Args:
            num_features: number of input channels C.
            eps: value added to the variance for numerical stability; forwarded
                to the internal ``BatchNorm2d``.
            momentum: running-statistics momentum; forwarded to the internal
                ``BatchNorm2d``.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Modulation tensors, assigned externally before each forward call.
        # They must broadcast against the normalized input; presumably shaped
        # (N, C, H, W) when sliced from a batched prediction -- TODO confirm.
        self.weight = None
        self.bias = None
        # affine=False: BatchNorm's own per-channel scale/shift is replaced by
        # the spatially-varying weight/bias above.
        self.bn = nn.BatchNorm2d(
            self.num_features,
            eps=self.eps,
            momentum=self.momentum,
            affine=False,
        )

    def forward(self, x):
        """Normalize ``x`` with BatchNorm and modulate it with the assigned weight/bias.

        Raises:
            AssertionError: if ``weight`` or ``bias`` has not been assigned.
        """
        assert (
            self.weight is not None and self.bias is not None
        ), "Please assign weight and bias before calling SPADE!"
        # Apply batch norm to the input (the original normalized an undefined name).
        out = self.bn(x)
        return out * self.weight + self.bias

    def __repr__(self):
        return self.__class__.__name__ + "(" + str(self.num_features) + ")"
# Model
class Model(nn.Module):
    """Network whose SPADE2d layers are modulated by a conditional input.

    ``self.model`` is the main network containing ``SPADE2d`` layers.
    ``self.conv`` is a ``ConvLayer`` that maps the conditional input to the
    weight/bias channels of all those layers, concatenated along dim 1.
    """

    def __init__(self, input_channel=None, model=None):
        """
        Args:
            input_channel: number of channels of the conditional input. When
                omitted, a module-level ``input_channel`` is used (the name the
                original snippet referenced).
            model: main network containing the SPADE2d layers; defaults to an
                empty ``nn.Sequential()`` placeholder.

        Raises:
            TypeError: if no ``input_channel`` is given or found at module level.
        """
        super().__init__()
        if input_channel is None:
            # Backward compatibility with code that defined a module-level name.
            input_channel = globals().get("input_channel")
        if input_channel is None:
            raise TypeError(
                "Model() needs input_channel, the channel count of cond_input"
            )
        # Main network containing the SPADE layers.
        # Note: an empty nn.Sequential is falsy, hence the explicit None check.
        self.model = model if model is not None else nn.Sequential()
        # Network predicting the SPADE parameters. It is created after
        # self.model so that get_num_spade_params() counts its SPADE2d layers.
        num_spade_params = self.get_num_spade_params()
        self.conv = ConvLayer(input_channel, num_spade_params)

    def get_num_spade_params(self):
        """Return the number of SPADE parameter channels needed by the model.

        Each SPADE2d layer needs ``num_features`` weight channels plus
        ``num_features`` bias channels.
        """
        return sum(
            2 * m.num_features
            for m in self.modules()
            if m.__class__.__name__ == "SPADE2d"
        )

    def assign_spade_params(self, spade_params):
        """Assign slices of ``spade_params`` to the SPADE layers in the model.

        Layers are visited in ``self.modules()`` order, the same order used by
        ``get_num_spade_params``. Each layer consumes the next
        ``2 * num_features`` channels of dim 1: the first half becomes its
        weight, the second half its bias.
        """
        for m in self.modules():
            if m.__class__.__name__ == "SPADE2d":
                # Extract weight and bias predictions
                m.weight = spade_params[:, : m.num_features, :, :].contiguous()
                m.bias = spade_params[:, m.num_features : 2 * m.num_features, :, :].contiguous()
                # Move pointer past the channels this layer consumed
                if spade_params.size(1) > 2 * m.num_features:
                    spade_params = spade_params[:, 2 * m.num_features :, :, :]

    def forward(self, main_input, cond_input):
        """Predict SPADE parameters from ``cond_input``, assign them, and run the main network.

        ``ConvLayer`` keeps ``cond_input``'s spatial size, so each SPADE2d layer's
        input presumably needs that same spatial size (or one that broadcasts
        with it) -- TODO confirm for the concrete main network.
        """
        # Update SPADE parameters by ConvLayer prediction based on the conditional input
        spade_params = self.conv(cond_input)
        self.assign_spade_params(spade_params)
        return self.model(main_input)