DCGAN: Building GANs with Deep Convolutional Neural Networks.
The authors design DCGAN (Deep Convolutional GAN), a generative adversarial network built with convolutional neural networks, and apply it to image generation tasks.
To stabilize the training of DCGAN, the authors propose the following design guidelines:
- Remove the pooling layers from the network; use strided convolutions for downsampling in the discriminator and transposed convolutions for upsampling in the generator (see the shape check after this list);
- Use batch norm in both the discriminator and the generator;
- Remove all fully connected layers from the network;
- Use a Tanh activation for the generator's output layer and ReLU for all of its other layers;
- Use LeakyReLU activations in all layers of the discriminator.
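A quick way to see why these two layers replace pooling and interpolation is to check their output shapes. The snippet below is a minimal standalone sketch (not part of the DCGAN code); the tensor size and channel counts are chosen only for illustration.

import torch
import torch.nn as nn

x = torch.randn(1, 64, 16, 16)                                         # a 16x16 feature map
down = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)          # strided convolution
up = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)    # transposed convolution

print(down(x).shape)  # torch.Size([1, 128, 8, 8])   -> downsampled by 2, replacing pooling
print(up(x).shape)    # torch.Size([1, 32, 32, 32])  -> upsampled by 2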
The DCGAN architecture is shown in the figure. In practice, the following implementation details are worth noting:
- The convolution and transposed-convolution kernels are $4\times 4$ or $5\times 5$;
- The stride of both the convolutions and the transposed convolutions is usually $2$;
- In the discriminator, batch norm is usually omitted after the first convolution; the remaining layers follow the "Conv2D + BN + LeakyReLU" pattern until the feature map is downsampled to $4\times 4$;
- In the generator, the noise vector is passed through a fully connected layer and reshaped to $4\times 4$, followed by "DeConv2D + BN + ReLU" blocks;
- The generator output uses a Tanh activation, matching input images scaled to the range $[-1,1]$ (see the normalization sketch below).
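Because the generator ends with Tanh, the real training images must be scaled to the same $[-1,1]$ range. A minimal sketch of this preprocessing, assuming torchvision transforms and 3-channel $64\times 64$ images (both assumptions, not part of the original code):

import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(64),                                    # e.g. opt.img_size = 64
    transforms.ToTensor(),                                    # scales pixels to [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),   # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])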
A complete PyTorch implementation of DCGAN can be found in PyTorch-GAN; the discriminator and generator architectures are described below.
⚪ The DCGAN Discriminator
import numpy as np
import torch.nn as nn

# `opt` is the argparse options namespace used throughout PyTorch-GAN
# (opt.channels, opt.img_size, opt.latent_dim).

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            # "Conv2D + BN + LeakyReLU" block; the stride-2 convolution halves the feature map
            block = [nn.Conv2d(in_filters, out_filters, 4, 2, 1)]
            if bn:
                # note: the second positional argument of BatchNorm2d is eps
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),  # no batch norm after the first convolution
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of the downsampled feature map
        ds_size = int(np.ceil(opt.img_size / (2 ** 4)))
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
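A quick sanity check of the shapes. The `opt` namespace below is a hypothetical stand-in for the command-line options referenced by the classes in this post; it assumes 3-channel $64\times 64$ images.

import torch
from types import SimpleNamespace

# Hypothetical stand-in for the argparse options (opt.channels, opt.img_size, opt.latent_dim)
opt = SimpleNamespace(channels=3, img_size=64, latent_dim=100)

D = Discriminator()
imgs = torch.randn(16, opt.channels, opt.img_size, opt.img_size)  # a batch of 16 random inputs
print(D(imgs).shape)  # torch.Size([16, 1]): one validity score per image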
⚪ The DCGAN Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Project the noise vector to a 128-channel feature map of size init_size x init_size
        # (4x4 when opt.img_size = 64)
        self.init_size = int(opt.img_size / (2 ** 4))
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        def generator_block(in_filters, out_filters):
            # "BN + ReLU + DeConv2D" block; the stride-2 transposed convolution doubles the feature map
            block = [nn.BatchNorm2d(in_filters)]
            block.append(nn.ReLU(inplace=True))
            block.append(nn.ConvTranspose2d(in_filters, out_filters, 4, stride=2, padding=1))
            return block

        self.model = nn.Sequential(
            *generator_block(128, 64),
            *generator_block(64, 32),
            *generator_block(32, 16),
            *generator_block(16, opt.channels),
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.model(out)
        return img
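To show how the two networks fit together, here is a minimal sketch of the adversarial training loop in the style of PyTorch-GAN. The `dataloader` (assumed to yield image batches already scaled to $[-1,1]$) and `opt` (the options namespace from the earlier sketch) are assumptions, and the epoch count is illustrative; the Adam settings (lr=0.0002, beta1=0.5) follow the DCGAN paper.

import torch
import torch.nn as nn

generator = Generator()
discriminator = Discriminator()
adversarial_loss = nn.BCELoss()

# Adam with lr=0.0002 and beta1=0.5, as recommended in the DCGAN paper
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

for epoch in range(200):                      # number of epochs: illustrative only
    for imgs, _ in dataloader:                # dataloader: assumed to yield batches scaled to [-1, 1]
        valid = torch.ones(imgs.size(0), 1)   # target label for real images
        fake = torch.zeros(imgs.size(0), 1)   # target label for generated images

        # Train the generator: make the discriminator classify generated images as real
        optimizer_G.zero_grad()
        z = torch.randn(imgs.size(0), opt.latent_dim)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train the discriminator: real images -> 1, generated images -> 0
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()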