GAN-BERT:使用GAN进行半监督的BERT训练.

BERT是一种常用的预训练语言模型,其结构为Transformer的编码器。尽管BERT在预训练时是无监督的,当其应用到下游任务(如文本分类)时需要在人工标注的数据集上进行微调。当下游任务的训练数据集样本有限时,微调BERT往往不能充分发挥其性能。在本文中,作者提出了一种半监督的BERT微调训练方法,通过引入对抗生成网络,使得在数据集中仅有少量标注的情况下也能泛化到这些任务中。

标准的BERT结构可以进一步分成一个特征提取部分和一个分类部分。特征提取部分对每个输入token生成其对应的输出特征,分类部分根据这些特征执行下游任务,如文本分类中使用多层感知机进一步对头部的CLASS token进行分类。作者提出的GAN-BERT如下图所示。将BERT的分类部分看作判别器$\mathcal{D}$,额外引入一个生成器$\mathcal{G}$用于生成假的中间特征。判别器判断其输入特征是真(来源于BERT)还是假(来源于生成器),并进一步对真的特征预测其所属类别。

记$p_d$为真实的样本分布,其中的标注数据标签为$1$~$k$,并额外引入一个标签$k+1$表示是否是真实数据。对于判别器$\mathcal{D}$,在有标签的监督数据上计算监督分类损失:

\[\mathcal{L}_{\mathcal{D}_{sup.}} = -\Bbb{E}_{x,y\text{~}p_d}log[p_m(\hat{y}=y|x,y \in (1,...,k))]\]

另外在所有真实数据上计算非监督对抗损失:

\[\mathcal{L}_{\mathcal{D}_{unsup.}} = -\Bbb{E}_{x\text{~}p_d}log[1-p_m(\hat{y}=y|x,y =k+1)]-\Bbb{E}_{x\text{~}\mathcal{G}}log[p_m(\hat{y}=y|x,y =k+1)]\]

判别器$\mathcal{D}$的总损失为上述两个损失之和:

\[\mathcal{L}_{\mathcal{D}} = \mathcal{L}_{\mathcal{D}_{sup.}} + \mathcal{L}_{\mathcal{D}_{unsup.}}\]

对于生成器$\mathcal{G}$,一方面希望其生成的特征足够接近真实特征:

\[\mathcal{L}_{\mathcal{G}_{\text{feature matching}}} = ||\Bbb{E}_{x\text{~}\mathcal{p_d}}f(x)-\Bbb{E}_{x\text{~}\mathcal{G}}f(x)||_2^2\]

另一方面,希望其生成的特征能够骗过判别器:

\[\mathcal{L}_{\mathcal{G}_{unsup.}} = -\Bbb{E}_{x\text{~}\mathcal{G}}log[1-p_m(\hat{y}=y|x,y =k+1)]\]

生成器$\mathcal{G}$的总损失为上述两个损失之和:

\[\mathcal{L}_{\mathcal{G}} = \mathcal{L}_{\mathcal{G}_{\text{feature matching}}} + \mathcal{L}_{\mathcal{G}_{unsup.}}\]

作者在多个数据集中进行实验。实验表明,在仅有极少数标注的数据集上(如$1\%$),BERT的性能很差,而GAN-BERT仍然具有比较好的性能。