论文标题 | Generative Causal Explanations for Graph Neural Networks
论文来源 | ICML 2021
论文链接 | https://arxiv.org/abs/2104.06643
源码链接 | https://github.com/wanyu-lin/ICML2021-Gem

TL;DR

论文中提出了一种通用的因果解释方法 Gem,可以为不同图学习任务中的 GNNs 提供通用的可解释性。主要想法是将 GNNs 模型的推理决策可解释性问题转为一个因果学习任务,然后基于格兰杰因果 (Granger causality) 的目标函数来训练一个因果解释模型。Gem 不依赖于 GNNs 的内部结构和图学习任务相关的先验知识,因此泛化性较强而且可以从图结构数据的因果角度来解释 GNNs ~ 保持好奇 🧐。实验部分在生成数据集和真实数据中验证了 Gem 相对其它解释模型而言,可以提升 30% 的解释准确性而且解释速度可达 110×。

Problem Definition

由于 GNNs 模型的可解释性研究是一个新兴的研究领域,当前主流的研究工作如下 Mark 🐶

以上方法都是基于图结构或者加性特征归因 (addittive feature attribution) 的方法来解释 GNNs 的推理结果,泛化性较差而且没有从因果层面来考虑。

而这篇文章从因果模型来解释 GNNs,希望这一篇文章可以解决我多年前对 GNNs 因果推理的疑惑 🤔

首先好奇会以什么形式来解释 GNNs ⁉️

一般的图学习任务定义如下:

给定图集合G={Gi}i=1N,G=N\mathcal{G}=\{G_i\}^N_{i=1}, |\mathcal{G}|=N,每个图表示为Gi=(Vi,Ei)G_i=(V_i, E_i),其节点集合为Vi={v1i,v2i,,vVii}V_i=\{v_1^i, v_2^i,\cdots, v_{|V_i|}^i\},对应的每个节点特征维度为Xi={x1i,x2i,,xVii},xjiRdX_i=\{x_1^i, x_2^i,\cdots, x_{|V_i|}^i\}, x_j^i\in \mathbb{R}^d​。论文中考虑在 Graph-level 和 Node-level 级别的分类任务来解释 GNNs。

  • 对于 Graph-level 分类任务的数据集为D={(Gi,yi)}i=1N\mathcal{D}=\{(G_i, y_i)\}_{i=1}^N​ ,每个图GiG_i​ 其对应的标签为yiY={c1,c2,,cl}y_{i} \in \mathcal{Y}=\left\{c_{1}, c_{2}, \cdots, c_{l}\right\}​,其中ll​​ 表示类别数量。

  • 对于 Node-level 分类任务的数据集为D={(vj,yj)}j=1V\mathcal{D}=\{(v_j, y_j)\}_{j=1}^{|V|}​,每个图GG 中的节点vjVv_j\in V 对应的标签为yiYy_i\in \mathcal{Y}​。

论文中使用示例IiI_i​ 表示一个实例,在 Graph-level 中对应GiG_i​ ,在 Node-level 中对应vjv_j​​​​。所以下文谈到的实例可以为 Graph or Node。

GNNs 模型可以形式化化地表示如下:

  • Graph-level:f():GYf(\cdot): \mathcal{G}\rightarrow \mathcal{Y}

  • Node-level:f():VYf(\cdot): \mathcal{V}\rightarrow \mathcal{Y}

对应的目标函数如下,L:y×y~s\mathcal{L}: y \times \tilde{y} \rightarrow s​​​​,其中yy​​​ 表示真实分类,y~\tilde{y}​​ 表示预测输出,标量sRs\in \mathbb{R}​​​​ 表示对应的损失。

模型解释的任务是给定一个预训练模型f()f(·)​ ,得到模型f()expf(\cdot)_{exp}​​​​​ 对预训练模型进行快速精确的解释,预训练模型被称为 target GNN。

解释下 GNNs 模型的解释形式:

Intrinsically, an explanation is a subgraph that is the most relevant for a prediction —— the outcome of the target GNN.

个人理解:对于预测模型的预测输出,需要找到原图中的对应的子图来支持这一分类结果,找到子图这一过程由解释模型完成。这个子图作用在于?

Algorithm/Model

论文中提出的模型如下图所示

Gem 整体架构

主要包括两个模块:

  • Distillation process:对于图中的边进行因果贡献分数计算。
  • Graph generator:根据 distillation 得到的子图监督训练图生成模型。

因果贡献

GNNs 对于一个实例的预测结果主要在于图结构 computation graphGic=(Vic,Aic,Xic)G_{i}^{c}=\left(V_{i}^{c}, A_{i}^{c}, X_{i}^{c}\right)​​​,其中Aic{0,1}A_i^c\in \{0,1\}​​​ 表示邻接矩阵。

GNN 的目标是学习到一个分类条件分布P(YGic)\mathcal{P}(Y|G_i^c)​​。

一个计算图示例如下,对于节点ii 的 2-hop。

图示例

给定预训练的 GNNs 模型和对应的实例GcG^c ,其对应的分类结果为y~=p(YGc)\tilde{y}=p(Y|G^c),所以解释模型的任务是找到预测结果y~\tilde{y} 对应的子图GsG^s

考虑到格兰杰因果的主要思想,可以量化图中边ejGce_j\in G^c​​ 对 预训练 GNN 模型预测误差δGc\delta_{G^c}​ 的因果贡献。

Δδ,ej=δGc\{ej}δGc\Delta_{\delta, e_{j}}=\delta_{G^{c} \backslash\left\{e_{j}\right\}}-\delta_{G^{c}}

这样就可以计算删除边eje_j 对模型的因果贡献,其预测误差即为损失函数计算得到的值。计算方式如下

y~Gc=f(Gc),y~Gc\{ej}=f(Gc\{ej})δGc=L(y,y~Gc)δGc\{ej}=L(y,y~Gc\{ej})\begin{aligned} \tilde{y}_{G^{c}}&=f\left(G^{c}\right),\\ \tilde{y}_{G^{c} \backslash\left\{e_{j}\right\}}&=f\left(G^{c} \backslash\left\{e_{j}\right\}\right)\\ \delta_{G^{c}} &=\mathcal{L}\left(y, \tilde{y}_{G^{c}}\right) \\ \delta_{G^{c} \backslash\left\{e_{j}\right\}} &=\mathcal{L}\left(y, \tilde{y}_{G^{c} \backslash\left\{e_{j}\right\}}\right) \end{aligned}

利用删除边的策略来计算因果贡献这种方法论文中将之称为ground-truth distillation process.

通过计算图中所有边的因果贡献,可以直接根据贡献分数排序来选择 top-K 最相关的边作为预测解释。但是图数据中的边贡献分布并不是独立的,因此作者使用 graph rules 来提高 distillation 过程。这一步就比较玄学了,完全取决于数据特征,但是大部分应该考虑了 top-K distillation edges 应该是连通的。

以上蒸馏过程中得到的子图表示为G~s=(V~s,A~s,X~s)\tilde{G}^{s} = \left(\tilde{V}^{s}, \tilde{A}^{s}, \tilde{X}^{s}\right)​​,可以根据其它模型的输出来进行有监督的,以此来解释其它模型

因果解释模型

原则上任意的图生成模型都可以用作图因果解释模型,论文中用到的是 Graph auto-encoder。

Z=GCNs(Ac,Xc)A^c=σ(ZZT)\begin{array}{c} Z=\mathbf{G C N s}\left(A^{c}, X^{c}\right) \\ \hat{A}^{c}=\sigma\left(Z Z^{T}\right) \end{array}

其中AcA^c​ 表示 computation graph 的邻接矩阵,A~c\tilde{A}^c 表示每条边对GcG^{c}​ 预测其子图的贡献。

对于目标节点的模型解释输出为 computation graph 的一个 compact subgraph,以此解释一个节点分类为什么得到当前标签。

解释模型直接使用上一步根据因果贡献 top-K 边过滤用到的子图进行训练,损失函数为 RMSE。还使用了 node labeling technique 技术来区分不同的节点,这部分不再细述,了解其整体过程就可以了。

Experiments

实验部分使用了人工和真实数据集,解释准确率如下所示

实验结果

其运行时间对比结果如下

实验结果

以一个例子说明输出子图的效果

实验结果

Thoughts

  • 论文中提出的 Gem 模型整体思路清楚而且模型简单易懂,大佬的文章果然写得就是不一样 👍🏻
  • 整体而言是考虑了因果关系,但是 Explainer 非常依赖因果分数计算而且超参数 top-K…
  • Explainer 用了两层 GCN 是不是智能学到 2-hop 内的因果关联,如果超出这个 hop 重构的邻接矩阵会不是不太靠谱呢
  • 回到模型因果可解释性问题,没想到 GNN 可以用 compact subgraph 来解释模型结果 👍🏻


联系作者