使用能量模型进行概念学习.
- paper:Concept Learning with Energy-Based Models
- arXiv:link
问题阐述
每一个机器学习任务都可以建模为一个能量模型(energy-based model)$E(x)$,若当前状态$x$能够使得模型满意,则能量为$0$;否则模型能量较高。
损失函数就是一种能量模型,当样本和标签对应时,损失函数为$0$;否则通过学习使得损失函数不断降低。给定一个样本,可以通过梯度下降法得到正确的标签。
模型介绍
作者提出了一种能量模型:$E(x,a,w)$,其中:
- $x$是state,代表当前样本的一种状态;
- $a$是attention map,代表对每个样本的注意力;
- $w$是concept vector,代表当前样本所表示的一种概念。
具体地,在实验中变量设置如下:
- $x$是包含二维空间中每一个点的坐标和点的颜色值的向量\(\{(x_i,y_i,r_i,g_i,b_i)\}\);
- $a$是关注哪些点的注意力分布\((a_1,...,a_i,...)\);
- $w$是一个概念编码,不同的概念如组成正方形、在一条直线上、远离某一点…
作者把能量模型建模成relation network:
\[E_θ(x,a,w) = f_θ(\sum_{i,j,t}^{} {σ(a_i)σ(a_j) \cdot g_θ(x_i^t,x_j^t,w)},w)^2\]作者提出了三种任务:
(1)concept inference
给定$x$和$a$,求$w$。即给定当前样本的一个状态和关注哪些点,求这些点代表的一个概念。即求解优化问题:
\[w(a,x) = \mathop{\arg \max}_{w} E(x,a,w)\]可以通过梯度方法实现:
\[w^k = w^{k-1} + \frac{α}{2} ▽_w E(x,a,w) + ω^k, ω^k \text{~} N(0,α)\](2)generation
给定$w$和$a$,求$x$。即给定关注哪些点和一个概念,求这些样本点满足概念的一个状态。即求解优化问题:
\[x(a,w) = \mathop{\arg \max}_{x} E(x,a,w)\]可以通过梯度方法实现:
\[x^k = x^{k-1}+\frac{α}{2} ▽_xE(x,a,w)+ω^k, ω^k \text{~}N(0,α)\](3)identification
给定$x$和$w$,求$a$。即给定样本点的一个状态和其对应的概念,求哪些样本点满足概念。即求解优化问题:
\[a(w,x) = \mathop{\arg \max}_{a} E(x,a,w)\]可以通过梯度方法实现:
\[a^k = a^{k-1}+\frac{α}{2} ▽_aE(x,a,w)+ω^k, ω^k \text{~}N(0,α)\]模型训练
作者在训练模型时设计了一种few-shot prediction task,即先提供样本的一个状态$x^0$和应该关注的样本点$a$,模型从中学习到概念$w$之后,对于给定的新状态$x^1$求其应关注的样本点$a$(identification任务)或对于给定的应该关注的样本点$a$求其对应的样本状态$x^1$(generation任务):