TL;DR
对于变量间因果图结构学习的问题,论文中基于图自编码器和梯度优化的方法来学习观测数据中的因果结构,主要可以解决非线性结构等价问题并且将其应用到向量值形式的变量因果结构预测中。实验部分在人工生成数据中验证了提出的 GAE 模型优于当前其它基于梯度优化的模型 NOTEARS 和 DAG-GNN 等,尤其是在规模较大的因果图预测问题中。此外还测试了模型效率问题,在训练过程中随着图规模增大可以达到线性时间。
Problem Description
对于如何学习变量间因果图结构的问题,当前主流方法主要可以划分为三种类型:
相比于 Constraint- /Score-based 系列的方法,Gradient-based 的方法准确性和计算效率更高,目测解释性差点但奈何图深度学习🔥啊。 因此主要介绍基于梯度优化的因果图学习发展背景。
令 DAGG 表示因果图,包含节点{X1,X2,⋯,Xd} 其中Xi∈Rl,主要考虑的问题是加性噪声模型 (ANM) 下的因果结构学习方法,假设数据生成如下所示:
Xi:=fi(Xpa(i))+Zi,i=1,2,…,d
其中Xpa(i) 表示G 中存在有向边指向变量Xi 的节点集合,fi:R∣Xpa(i)∣×l→Rl 为向量映射函数,Zi∈Rl 表示加性噪声并假设是独立同分布的。集合表示为X:=[X1,X2,⋯,Xd] 和Z:=[Z1,Z2,⋯,Zd]。
NOTEARS 首先将 score-based 系列的组合优化问题转化为的线性结构等价模型(SEM)。对于上述数据生成模型改写为
Xi=AiTX+Zi,i=1,2,…,d
假设Xi,Zi∈R 并且Ai∈Rd 表示系数向量A=[A1,A2,⋯,Ad]∈Rd×d 表示线性 SEM 的加权邻接矩阵。为了保证因果图G 是有向无环的,需要对A 进行限制,
tr(eA⊙A)−d=0
NOTEARS 将最小平方损失函数作为优化目标函数,如下所示
Amin subject to 2n1j=1∑n∥∥∥∥X(j)−ATX(j)∥∥∥∥F2+λ∥A∥1tr(eA⊙A)−d=0
其中X(j) 表示X 的第j 个观测值。从这定义就可以看出 NOTEARS 只能处理单值变量间的因果,而且是线性结构等价模型。这也是 GAE 所优化改进的场景。
DAG-GNN 为了将上述模型适用到非线性场景,提出了的生成式模型如下
X=g2((I−AT)−1g1(Z))
其中g1,g2 表示非线性函数,DAG-GNN 用到的是 MLP + GNN 方法;Z 作为隐变量而且维度可以小于变量数量d。模型设计细节可以参考我另一篇文章 DAG-GNN:基于图神经网络的有向无环图结构表示学习
以上简单介绍了因果图挖掘的背景知识,也是个初学者的综述性介绍,下面进入这篇文章的正文。
Algorithm/Model
这篇论文主要是基于 NOTEARS 方法进行改进,从而使其适用更多场景。主要模型架构如下所示
改进主要包括两部分:非线性因果学习和向量形式变量可用而不仅是标量数据。
对于 NOTEARS 优化的目标函数可以改写为
A,Θmin subject to 2n1j=1∑n∥∥∥∥X(j)−f(X(j),A)∥∥∥∥F2+λ∥A∥1tr(eA⊙A)−d=0
其中f(X(j),A) 表示数据生成模型,对于 NOTEARS 就是线性 SEM 即f(X(j),A)=ATX(j)
为了将f 扩展为非线性的,可以自定义一个非线性的关系映射,例如文章用到
f(X(j),A)=ATg1(X(j))
其中g1:Rl→Rl 为非线性函数可选为 MLP,和 DAG-GNN 想法类似。为了增强非线性就再加一层 MLPg2:Rl→Rl
f(X(j),A)=g2(ATg1(X(j)))
上面的公式一看,不就和 GAE 的计算形式差不多,重写一遍
H(j)H(j)′f(X(j),A)=g1(X(j))=ATH(j)=g2(H(j)′)
如果g1 和g2 分别表示 variable-wise 编码器和解码器,上面的计算形式和优化目标函数不就是基于重构误差训练的 GAE,GAE 不就可以处理 vector-valued 的变量了么 👏
论文中选择两个 variable-wised MLPsg1:Rl→Rl′ 和g2:Rl→Rl′ 其中l′ 表示隐藏层维度。最终的优化函数即为
A,Θ1,Θ2min subject to 2n1j=1∑n∥∥∥∥X(j)−X^(j)∥∥∥∥F2+λ∥A∥1tr(eA⊙A)−d=0
和 DAG-GNN 主要的不同点在于:论文中用的 GAE 是以Xpa(i) 作为输入,而 DAG-GNN 是生成式模型以噪声数据Z 作为输入。
个人感觉就是 GAE 和 VGAE 的区别,GAE PK VGAE 怎么说好呢?作者只能在实验中证明了 GAE 比 VGAE 效果好而且快。🤔
对于上述优化函数,作者采用了增广的拉格朗日乘子法进行求解,其形式如下
Lρ(A,Θ1,Θ2,α)=2n1j=1∑n∥∥∥∥X(j)−X^(j)∥∥∥∥F2+λ∥A∥1+αh(A)+2ρ∣h(A)∣2
其中h(A):=tr(eA⊙A)−d ,α 表示拉格朗日乘子,ρ>0 表示惩罚因子,因此对应的梯度更新规则如下
Ak+1,Θ1k+1,Θ2k+1αk+1ρk+1=A,Θ1,Θ2argminLρk(A,Θ1,Θ2,αk)=αk+ρkh(Ak+1)={βρk,ρk, if ∣∣∣h(Ak+1)∣∣∣≥γ∣∣∣h(Ak)∣∣∣ otherwise
其中 $\beta >1 $ 而且γ<1 表示可调的超参数。
Experiments
实验部分在人工数据集中对比 baselines NOTEARS 和 DAG-GNN。数据包括两种 Scalar-based 和 Vector-Valued。
采用的指标包括两种结构化汉明距离(SHD)和正阳率 (TPR)。实验结果如下所示
整体而言性能提升蛮大的,而且模型训练实验也比较短。
Thoughts
介绍因果图挖掘的初衷是:了解如何用深度学习的方法挖掘变量间的因果图,而不要局限在 PC 或者 kNN 的思路上。
尤其是之前介绍的多变量时间序列关联挖掘,时间序列变量间存在因果关系但之前使用的因果挖掘方法却过于简单粗暴。
图作为一种广义的数据结构,任何存在某种关联的数据都可以使用当前🔥的图机器学习进行建模。目前很多任务都是已知图结构,对于未知图结构的实体就需要使用模型学习到其中关系或者因果,这往往是一个更有挑战的任务。