ICML 2021 | 针对图神经网络的通用因果解释方法
论文标题 | Generative Causal Explanations for Graph Neural Networks
论文来源 | ICML 2021
论文链接 | https://arxiv.org/abs/2104.06643
源码链接 | https://github.com/wanyu-lin/ICML2021-Gem
TL;DR
论文中提出了一种通用的因果解释方法 Gem,可以为不同图学习任务中的 GNNs 提供通用的可解释性。主要想法是将 GNNs 模型的推理决策可解释性问题转为一个因果学习任务,然后基于格兰杰因果 (Granger causality) 的目标函数来训练一个因果解释模型。Gem 不依赖于 GNNs 的内部结构和图学习任务相关的先验知识,因此泛化性较强而且可以从图结构数据的因果角度来解释 GNNs ~ 保持好奇 🧐。实验部分在生成数据集和真实数据中验证了 Gem 相对其它解释模型而言,可以提升 30% 的解释准确性而且解释速度可达 110×。
Problem Definition
由于 GNNs 模型的可解释性研究是一个新兴的研究领域,当前主流的研究工作如下 Mark 🐶
GNNExplainer: [NeurIPS 2019] Generating Explanations for Graph Neural Networks
PGExplainer:[NeurIPS 2020] Parameterized Explainer for Graph Neural Network
PGM-Explainer: [NeurIPS 2020] Probabilistic Graphical Model Explanations for Graph Neural Networks
XGNN:[KDD 2020] Towards Model-Level Explanations of Graph Neural Networks
以上方法都是基于图结构或者加性特征归因 (addittive feature attribution) 的方法来解释 GNNs 的推理结果,泛化性较差而且没有从因果层面来考虑。
而这篇文章从因果模型来解释 GNNs,希望这一篇文章可以解决我多年前对 GNNs 因果推理的疑惑 🤔
首先好奇会以什么形式来解释 GNNs ⁉️
一般的图学习任务定义如下:
给定图集合,每个图表示为,其节点集合为,对应的每个节点特征维度为。论文中考虑在 Graph-level 和 Node-level 级别的分类任务来解释 GNNs。
对于 Graph-level 分类任务的数据集为 ,每个图 其对应的标签为,其中 表示类别数量。
对于 Node-level 分类任务的数据集为,每个图 中的节点 对应的标签为。
论文中使用示例 表示一个实例,在 Graph-level 中对应 ,在 Node-level 中对应。所以下文谈到的实例可以为 Graph or Node。
GNNs 模型可以形式化化地表示如下:
Graph-level:
Node-level:
对应的目标函数如下,,其中 表示真实分类, 表示预测输出,标量 表示对应的损失。
模型解释的任务是给定一个预训练模型 ,得到模型 对预训练模型进行快速精确的解释,预训练模型被称为 target GNN。
解释下 GNNs 模型的解释形式:
Intrinsically, an explanation is a subgraph that is the most relevant for a prediction —— the outcome of the target GNN.
个人理解:对于预测模型的预测输出,需要找到原图中的对应的子图来支持这一分类结果,找到子图这一过程由解释模型完成。这个子图作用在于?
Algorithm/Model
论文中提出的模型如下图所示
主要包括两个模块:
- Distillation process:对于图中的边进行因果贡献分数计算。
- Graph generator:根据 distillation 得到的子图监督训练图生成模型。
因果贡献
GNNs 对于一个实例的预测结果主要在于图结构 computation graph,其中 表示邻接矩阵。
GNN 的目标是学习到一个分类条件分布。
一个计算图示例如下,对于节点 的 2-hop。
给定预训练的 GNNs 模型和对应的实例 ,其对应的分类结果为,所以解释模型的任务是找到预测结果 对应的子图。
考虑到格兰杰因果的主要思想,可以量化图中边 对 预训练 GNN 模型预测误差 的因果贡献。
这样就可以计算删除边 对模型的因果贡献,其预测误差即为损失函数计算得到的值。计算方式如下
利用删除边的策略来计算因果贡献这种方法论文中将之称为ground-truth distillation process.
通过计算图中所有边的因果贡献,可以直接根据贡献分数排序来选择 top-K 最相关的边作为预测解释。但是图数据中的边贡献分布并不是独立的,因此作者使用 graph rules 来提高 distillation 过程。这一步就比较玄学了,完全取决于数据特征,但是大部分应该考虑了 top-K distillation edges 应该是连通的。
以上蒸馏过程中得到的子图表示为,可以根据其它模型的输出来进行有监督的,以此来解释其它模型
因果解释模型
原则上任意的图生成模型都可以用作图因果解释模型,论文中用到的是 Graph auto-encoder。
其中 表示 computation graph 的邻接矩阵, 表示每条边对 预测其子图的贡献。
对于目标节点的模型解释输出为 computation graph 的一个 compact subgraph,以此解释一个节点分类为什么得到当前标签。
解释模型直接使用上一步根据因果贡献 top-K 边过滤用到的子图进行训练,损失函数为 RMSE。还使用了 node labeling technique 技术来区分不同的节点,这部分不再细述,了解其整体过程就可以了。
Experiments
实验部分使用了人工和真实数据集,解释准确率如下所示
其运行时间对比结果如下
以一个例子说明输出子图的效果
Thoughts
- 论文中提出的 Gem 模型整体思路清楚而且模型简单易懂,大佬的文章果然写得就是不一样 👍🏻
- 整体而言是考虑了因果关系,但是 Explainer 非常依赖因果分数计算而且超参数 top-K…
- Explainer 用了两层 GCN 是不是智能学到 2-hop 内的因果关联,如果超出这个 hop 重构的邻接矩阵会不是不太靠谱呢
- 回到模型因果可解释性问题,没想到 GNN 可以用 compact subgraph 来解释模型结果 👍🏻