TUNIT:完全无监督图像到图像翻译.
图像到图像的翻译(image-to-image translation)是指将属于一种图像域的图像转换到另一种图像域(如风格迁移)。图像到图像翻译有三种不同的监督程度:
- 图像等级(image-level)的监督:这种监督是为每张图像指定另一张图像作为标签,从而实现一对一的转换,如把鞋的轮廓图转换成真实的鞋图像;
- 集合等级(set-level)的监督:这种监督是为每一类图像集合指定一个标签,如属于某一种动物。这种监督方式需要人为对数据集进行划分,并且指定标签类型;
- 本文提出了一种完全无监督的转换方法,提供一个图像数据集,实现其中任意两张图像之间的翻译。
1. 网络结构
所提出的模型如下图所示。该模型由一个编码网络$E$(论文中也称guided network)、一个生成网络$G$和一个判别网络$D$组成。
- 编码网络$E$接收一张图像,生成其对应的伪标签(pseudo label)和风格编码(style code)。伪标签指示该图像所属的图像域类别(用作聚类),风格编码指示该图像所具有的风格(每张图像都不同,用作风格迁移)。
- 生成网络$G$接收一个源域图像和一个目标域图像的风格编码,生成具有目标域风格的源域图像。
- 判别网络$D$根据领域标签判定该领域中图像的真实性。值得一提的是,$D$具有$K$个输出头,用于分别处理$K$种图像域的情况,在使用时由伪标签选定,$K$的取值需要人工选择。
2. 训练过程
网络的训练过程有三个阶段。
⚪ 编码网络$E$的训练
首先预训练编码器,编码器的领域标签头用Invariant Information Clustering方法进行无监督的聚类训练;编码器的风格编码头用MoCo方法进行对比学习。
编码网络$E$的领域标签头训练采用无监督的聚类方法。随机选择一张图像$x$,通过随机数据增强(如随机裁剪,水平翻转)得到图像$x^+$,将其通过编码网络$E$得到其类别伪标签$p=E_{\text{class}}(x)$和$p^+=E_{\text{class}}(x^+)$,并最大化其互信息(mutual information):
\[\mathcal{L}_{\text{MI}} = I(p,p^+) = H(p) - H(p|p^+)\]对互信息的最大化可以分成两步。第一步,最大化类别伪标签$p$的熵$H(p)$,即使得所有图像的伪标签具有尽可能大的差异;第一步,最小化条件熵$H(p|p^+)$,即使得数据增强的图像(其风格没有变化)与原图的伪标签尽可能接近。
此外对编码网络$E$引入对比损失(contrastive loss),使其能够学习到所属图像域的风格。将图像$x$及其增强图像$x^+$通过编码网络$E$得到其风格编码$s=E_{\text{style}}(x)$和$s^+=E_{\text{style}}(x^+)$,并随机选择$N+1$张其他图像获得其风格编码$s^-=E_{\text{style}}(x^-)$,计算对比损失:
\[\mathcal{L}_{\text{style}}^E = -log \frac{exp(s \cdot s^+/ \tau)}{\sum_{i=0}^{N}exp(s \cdot s_i^-/ \tau)}\]⚪ 判别网络$D$的训练
其次训练判别器,判别器采用标准的对抗损失训练,接收图像后生成长度为领域数量的预测向量,根据编码器提供的领域标签选择对应的位置构造对抗损失。
判别网络$D$的训练采用标准的对抗损失(adversarial)。给定源域图像$x$及其对应的类别伪标签$y$,以及目标域图像$\tilde{x}$及其对应的类别伪标签$\tilde{y}$,并计算其风格编码$\tilde{s}=E_{\text{style}}(\tilde{x})$。据此根据源域图像$x$和目标域风格编码$\tilde{s}$生成图像$G(x,\tilde{s})$。则对抗损失计算为:
\[\mathcal{L}_{\text{adv}} = \Bbb{E}_{x \text{~} p_{data}(x)}[logD_y(x)]+\Bbb{E}_{x,\tilde{x} \text{~} p_{data}(x)}[log(1-D_{\tilde{y}}(G(x,\tilde{s})))]\]该对抗损失的计算使用了生成网络,因此梯度更新时也会反向传播作用于生成网络。
⚪ 生成网络$G$的训练
生成网络$G$的训练首先采用风格对比损失(style contrastive loss)。将生成图像$G(x,\tilde{s})$再次送入编码网络得到对应的预测风格编码$s’$,使得该编码与目标域风格编码$\tilde{s}$足够接近,且与其余$N+1$张其他图像的风格编码$s^-$足够疏远:
\[\mathcal{L}_{\text{style}}^G = \Bbb{E}_{x,\tilde{x} \text{~} p(x)} [-\log \frac{\exp(s' \cdot \tilde{s})}{\sum_{i=0}^N \exp(s' \cdot s_i^- / \tau)}]\]其次采用图像重构损失(image reconstruction loss),需要先用编码器获取输入图像的风格编码,这一步导致编码器也能获得梯度,从而一起训练:
\[\begin{aligned} \mathcal{L}_{\text{recon}} &= \Bbb{E}_{x \text{~} p(x)}[||x-G(x,s)||_1 ] \end{aligned}\]⚪ 总的损失函数
采用联合训练(joint training),同时训练编码网络$E$、生成网络$G$和判别网络$D$。其目标函数如下:
\[\begin{aligned} \mathop{ \min}_{G,E} \mathop{\max}_{D} &\mathcal{L}_{\text{adv}}(D,G) + \lambda_{\text{style}}^G\mathcal{L}_{\text{style}}^G(G,E) + \lambda_{\text{rec}}\mathcal{L}_{\text{rec}}(G,E) \\ & - \lambda_{MI}\mathcal{L}_{MI}(E)+ \lambda_{\text{style}}^E\mathcal{L}_{\text{style}}^E(E) \end{aligned}\]3. 实验分析
除了联合训练,作者也进行了分离训练(separate training),即先训练编码网络,再训练生成网络和判别网络。两种训练方式的结果如下。联合训练相比于分离训练,能够学习到更好的无监督特征表示(表现为风格编码的可视化分散且成簇),并且训练精度更高(表现为FID指标更小)。
作者将不同图像的风格编码可视化后,按照其对应的类别伪标签进行分类,展示不同类别的图像如下。观察到具有相似纹理的同一类动物被划分到同一个类别中,并且同类动物对应的风格编码在空间中是接近的。
作者展示了一些实验结果,该方法能够实现无监督的图像到图像转换,并且取得较高的生成图像质量。