论文标题 | DAG-GNN:DAG Structure Learning with Graph Neural Networks
论文来源 | ICML 2019
文章链接:https://arxiv.org/abs/1904.10098
源码链接:https://github.com/fishmoon1234/DAG-GNN

TL;DR

论文中提出一种新的DAG编码架构 DAG-GNN,其实模型的本质就是一个图变分自编码器,模型的优点是既能处理连续型变量又能处理离散型变量;在人工数据集和真实数据集中验证了模型结果可以达到全局最优 🤔;

Model / Algorithm

论文中的整体模型架构如下:

Linear Structural Equation Model

论文中首先通过生成模型来泛化线性结构等价模型;假设ARm×mA \in \mathbb{R}^{m \times m} 表示DAG的加权邻接矩阵,XRm×dX \in \mathbb{R}^{m \times d} 表示每个节点的特征,那么线性模型的的编码方式为:

X=ATX+Z(1)X=A^{T} X+Z \quad\quad\quad(1)

其中ZRm×dZ \in \mathbb{R}^{m \times d} 表示噪声矩阵;如果图中节点是以拓扑序排列的,那么矩阵AA 是一个严格的上三角矩阵,因此DAG中的 ancestral sampling 等价于三角等式的解:

X=(IAT)1Z(2)X=\left(I-A^{T}\right)^{-1} Z \quad\quad\quad(2)

Proposed Graph Neural Network Model

上述等式 (2) 可以写为X=fA(Z)X=f_A(Z),可以表示为数据节点特征ZZ 并得到embeddingXX。传统的GCN 架构计算公式如下:

X=A^ReLU(A^ZW1)W2X=\widehat{A} \cdot \operatorname{ReLU}\left(\widehat{A} Z W^{1}\right) \cdot W^{2}

由于公式 (2) 的特殊结构,因此提出新的图神经网络架构,注意这是解码器的结构

X=f2((IAT)1f1(Z))(3)X=f_{2}\left(\left(I-A^{T}\right)^{-1} f_{1}(Z)\right)\quad\quad\quad(3)

其中f1,f2f_1, f_2 表示Z,XZ, X 的非线性的转换函数;

Model Learning with Variational Autoencoder

对于给定的分布ZZ 和样本X1,,XnX^1, \cdots, X^n,生成模型的目标是最大化对数函数:

1nk=1nlogp(Xk)=1nk=1nlogp(XkZ)p(Z)dZ\frac{1}{n} \sum_{k=1}^{n} \log p\left(X^{k}\right)=\frac{1}{n} \sum_{k=1}^{n} \log \int p\left(X^{k} \mid Z\right) p(Z) d Z

由于上式难以解决因此使用变分贝叶斯;

使用变分后验概率q(ZX)q(Z|X) 来近似实际后验概率q(ZX)q(Z|X)。网络优化的结果是 ELBO(the evidence lower bound)

LELBO=1nk=1nLELBOkL_{\mathrm{ELBO}}=\frac{1}{n} \sum_{k=1}^{n} L_{\mathrm{ELBO}}^{k}

其中

LELBOkDKL(q(ZXk)p(Z))+Eq(ZXk)[logp(XkZ)]\begin{array}{r} L_{\mathrm{ELBO}}^{k} \equiv-D_{\mathrm{KL}}\left(q\left(Z \mid X^{k}\right) \| p(Z)\right) \\ \quad+\mathrm{E}_{q\left(Z \mid X^{k}\right)}\left[\log p\left(X^{k} \mid Z\right)\right] \end{array}

基于 (3)式的解码器结构,对应的编码器结构为

Z=f4((IAT)f3(X))(5)Z=f_{4}\left(\left(I-A^{T}\right) f_{3}(X)\right) \quad\quad\quad(5)

其中f4,f3f_4, f_3 表示f2,f1f_2,f_1 的逆函数。

Loss Function

对于编码器,使用MLP表示f3f_3和恒等映射表示f4f_4,变分后验概率q(ZX)q(Z|X) 是一个因子高斯分布均值MZRm×dM_Z\in \mathbb{R}^{m\times d} 标准差SZRm×dS_Z\in \mathbb{R}^{m\times d},可以通过编码器来进行计算:

[MZlogSZ]=(IAT)MLP(X,W1,W2)(6)\left[M_{Z} \mid \log S_{Z}\right]=\left(I-A^{T}\right) \operatorname{MLP}\left(X, W^{1}, W^{2}\right)\quad\quad\quad(6)

其中MLP(X,W1,W2):=ReLU(XW1)W2\operatorname{MLP}\left(X, W^{1}, W^{2}\right):=\operatorname{ReLU}\left(X W^{1}\right) W^{2}

对于生成模型,使用恒等映射表示f1f_1 MLP来表示f2f_2,得到的似然p(XZ)p(X | Z) 符合高斯分布均值为MXRm×dM_X\in \mathbb{R}^{m\times d} 标准差为SXRm×dS_X\in \mathbb{R}^{m\times d},解码器的计算公式如下:

[MXlogSX]=MLP((IAT)1Z,W3,W4)(7)\left[M_{X} \mid \log S_{X}\right]=\operatorname{MLP}\left(\left(I-A^{T}\right)^{-1} Z, W^{3}, W^{4}\right)\quad\quad\quad(7)

基于公式(6)(7),式(4)中的KL散度项为:

DKL(q(ZX)p(Z))=12i=1mj=1d(SZ)ij2+(MZ)ij22log(SZ)ij1\begin{array}{l} D_{\mathrm{KL}}(q(Z \mid X) \| p(Z))= \\\\ \quad \frac{1}{2} \sum_{i=1}^{m} \sum_{j=1}^{d}\left(S_{Z}\right)_{i j}^{2}+\left(M_{Z}\right)_{i j}^{2}-2 \log \left(S_{Z}\right)_{i j}-1 \end{array}

重构准确率项为:

Eq(ZX)[logp(XZ)]1Ll=1Li=1mj=1d(Xij(MX(l))ij)22(SX(l))ij2log(SX(l))ijc\begin{array}{c} \mathrm{E}_{q(Z \mid X)}[\log p(X \mid Z)] \approx \\\\ \frac{1}{L} \sum_{l=1}^{L} \sum_{i=1}^{m} \sum_{j=1}^{d}-\frac{\left(X_{i j}-\left(M_{X}^{(l)}\right)_{i j}\right)^{2}}{2\left(S_{X}^{(l)}\right)_{i j}^{2}}-\log \left(S_{X}^{(l)}\right)_{i j}-c \end{array}

对于不同类型变量的处理论文中使用了不同的结构,详细参考原文推导过程。

Experiments

人工数据集

联系作者