TUNIT:完全无监督图像到图像翻译.
图像到图像的翻译(image-to-image translation)是指将属于一种图像域的图像转换到另一种图像域(如风格迁移)。图像到图像翻译有三种不同的监督程度:
- 图像等级(image-level)的监督:这种监督是为每张图像指定另一张图像作为标签,从而实现一对一的转换,如把鞋的轮廓图转换成真实的鞋图像;
- 集合等级(set-level)的监督:这种监督是为每一类图像集合指定一个标签,如属于某一种动物。这种监督方式需要人为对数据集进行划分,并且指定标签类型;
- 本文提出了一种完全无监督的转换方法,提供一个图像数据集,实现其中任意两张图像之间的翻译。
1. 网络结构
所提出的模型如下图所示。该模型由一个编码网络EE(论文中也称guided network)、一个生成网络GG和一个判别网络DD组成。
- 编码网络EE接收一张图像,生成其对应的伪标签(pseudo label)和风格编码(style code)。伪标签指示该图像所属的图像域类别(用作聚类),风格编码指示该图像所具有的风格(每张图像都不同,用作风格迁移)。
- 生成网络GG接收一个源域图像和一个目标域图像的风格编码,生成具有目标域风格的源域图像。
- 判别网络DD根据领域标签判定该领域中图像的真实性。值得一提的是,DD具有KK个输出头,用于分别处理KK种图像域的情况,在使用时由伪标签选定,KK的取值需要人工选择。
2. 训练过程
网络的训练过程有三个阶段。
⚪ 编码网络EE的训练
首先预训练编码器,编码器的领域标签头用Invariant Information Clustering方法进行无监督的聚类训练;编码器的风格编码头用MoCo方法进行对比学习。
编码网络EE的领域标签头训练采用无监督的聚类方法。随机选择一张图像xx,通过随机数据增强(如随机裁剪,水平翻转)得到图像x+x+,将其通过编码网络EE得到其类别伪标签p=Eclass(x)p=Eclass(x)和p+=Eclass(x+)p+=Eclass(x+),并最大化其互信息(mutual information):
LMI=I(p,p+)=H(p)−H(p|p+)LMI=I(p,p+)=H(p)−H(p|p+)对互信息的最大化可以分成两步。第一步,最大化类别伪标签pp的熵H(p)H(p),即使得所有图像的伪标签具有尽可能大的差异;第一步,最小化条件熵H(p|p+)H(p|p+),即使得数据增强的图像(其风格没有变化)与原图的伪标签尽可能接近。
此外对编码网络EE引入对比损失(contrastive loss),使其能够学习到所属图像域的风格。将图像xx及其增强图像x+x+通过编码网络EE得到其风格编码s=Estyle(x)s=Estyle(x)和s+=Estyle(x+)s+=Estyle(x+),并随机选择N+1N+1张其他图像获得其风格编码s−=Estyle(x−)s−=Estyle(x−),计算对比损失:
LEstyle=−logexp(s⋅s+/τ)∑Ni=0exp(s⋅s−i/τ)LEstyle=−logexp(s⋅s+/τ)∑Ni=0exp(s⋅s−i/τ)⚪ 判别网络DD的训练
其次训练判别器,判别器采用标准的对抗损失训练,接收图像后生成长度为领域数量的预测向量,根据编码器提供的领域标签选择对应的位置构造对抗损失。
判别网络DD的训练采用标准的对抗损失(adversarial)。给定源域图像xx及其对应的类别伪标签yy,以及目标域图像˜x~x及其对应的类别伪标签˜y~y,并计算其风格编码˜s=Estyle(˜x)~s=Estyle(~x)。据此根据源域图像xx和目标域风格编码˜s~s生成图像G(x,˜s)G(x,~s)。则对抗损失计算为:
Ladv=Ex~pdata(x)[logDy(x)]+Ex,˜x~pdata(x)[log(1−D˜y(G(x,˜s)))]Ladv=Ex~pdata(x)[logDy(x)]+Ex,~x~pdata(x)[log(1−D~y(G(x,~s)))]该对抗损失的计算使用了生成网络,因此梯度更新时也会反向传播作用于生成网络。
⚪ 生成网络G的训练
生成网络G的训练首先采用风格对比损失(style contrastive loss)。将生成图像G(x,˜s)再次送入编码网络得到对应的预测风格编码s′,使得该编码与目标域风格编码˜s足够接近,且与其余N+1张其他图像的风格编码s−足够疏远:
LGstyle=Ex,˜x~p(x)[−logexp(s′⋅˜s)∑Ni=0exp(s′⋅s−i/τ)]其次采用图像重构损失(image reconstruction loss),需要先用编码器获取输入图像的风格编码,这一步导致编码器也能获得梯度,从而一起训练:
Lrecon=Ex~p(x)[||x−G(x,s)||1]⚪ 总的损失函数
采用联合训练(joint training),同时训练编码网络E、生成网络G和判别网络D。其目标函数如下:
minG,EmaxDLadv(D,G)+λGstyleLGstyle(G,E)+λrecLrec(G,E)−λMILMI(E)+λEstyleLEstyle(E)3. 实验分析
除了联合训练,作者也进行了分离训练(separate training),即先训练编码网络,再训练生成网络和判别网络。两种训练方式的结果如下。联合训练相比于分离训练,能够学习到更好的无监督特征表示(表现为风格编码的可视化分散且成簇),并且训练精度更高(表现为FID指标更小)。
作者将不同图像的风格编码可视化后,按照其对应的类别伪标签进行分类,展示不同类别的图像如下。观察到具有相似纹理的同一类动物被划分到同一个类别中,并且同类动物对应的风格编码在空间中是接近的。
作者展示了一些实验结果,该方法能够实现无监督的图像到图像转换,并且取得较高的生成图像质量。
Related Issues not found
Please contact @0809zheng to initialize the comment