C&S:标签传播思想与浅层模型相结合性能即可超过图神经网络
论文标题 | Combining Label Propagation and Simple Models Out-performs Graph Neural Networks
论文来源 | ICLR 2021
论文链接 | https://arxiv.org/abs/2010.13993
源码链接 | https://github.com/CUAI/CorrectAndSmooth
TL;DR
图神经网络 (GNNs) 在图表示学习领域盛极一时,但是对为什么 GNNs 有效或者其对不同任务性能提升的必然性知之甚少。这篇文章通过大量的直推式实验(节点分类任务)证明,通过浅层模型和两个基于标签传播的后处理步骤即可达到当前 GNNs 模型的性能 🤭,后处理步骤包括两步 (i) 误差修正 (error correlation):利用训练数据中的残差来纠正测试数据中的误差 (ii) 预测修正 (prediction correlation):平滑测试数据中的预测结果,作者将其提出的整个框架总称为 C&S(correct and Smooth),并且在实验中证明 C&S 不仅准确性超过当前主流的 GNNs 模型,而且参数量和运行时间远远低于复杂的 GNNs 结构。
Algorithm/Model
针对图中节点分类任务,论文中提出的简单处理框架如下图所示,
主要包含三个部分:
- 基础预测模型:仅依赖节点特征并且忽略图的结构,例如 MLP 或者线性模型;
- 修正步骤:将训练数据中的不确定性传播到整个图以此来修正基础预测结果;
- 平滑步骤:对节点预测结果进行平滑。
其中修正步骤和平滑步骤是基于半监督学习的标签传播思想进行改进的,整个框架没有利用图结构来学习模型参数因此参数量非常少而且不需要大量的时间来训练模型,但实验效果却非常好,所以愈发感觉神经网络是一门玄学了啊!
下面详细介绍下 C&S 模型的细节。
C&S Model
假设给定无向图 ,其中节点数量,节点特征,邻接矩阵,度矩阵, 归一化邻接矩阵。对于节点预测任务,节点集 划分为不带标签节点集 和带标签节点集,将标签表示为独热编码矩阵,带标签的节点划分为训练集和验证集。直推式的节点分类任务就是在给定的 下对于集合 $ U$ 中的节点预测标签。
Base Predictor
首先使用一个不依赖于图结构的基础预测模型 基于节点特征来预测节点分类,优化目标函数是
其中, 是损失函数,论文中使用的 为线性模型或者浅层 MLP, 是交叉熵损失函数, 中的样本用于调参。通过 可以得到每个节点的基础预测结果,每一行表示 softmax 后节点的分布概率。其实这一步也可以使用 GNNs 来作为 base predictor,但是为了更好地比对效果作者只用了浅层模型。
Error Correlation
这一步骤的主要任务是对 base predictor 中的结果提高准确性,主要思想是希望 base predictor 中的误差是通过图中的边进行传播的,即节点 和邻居节点有相似的误差。论文中通过通过残差传播来实现这种不确定性。
首先定义一个误差矩阵,误差值是训练集中的残差和零:
对于训练集中已知标签的节点只有基础预测结果完全正确才为 0,因此文中通过标签传播技术来平滑误差
上式中第一项是为了平滑整个图中的误差,等于,其中是 的第 列。第二项是为了使解接近初始值。上式的解为,其中 并且,迭代求解上式直至收敛得到。
上式是通过迭代来修正和平滑误差,因此对 base predictor 的结果修正为
由于上述的误差传播求解方法仅适用于回归问题而对于节点分类任务不适用,即 的尺寸不同,因此论文中提出两种变形方法来满足条件。
Autoscale
由于训练集只有已知标签的节点预测误差是已知的,因此需要通过平均误差来近似未知标签的误差。
对于,定义,对于未知标签节点
Scaled Fixed Diffusion (FDiff-scale)
通过固定已知误差矩阵来进行迭代直至收敛得到。
其中 。
Smooth Predictions
通过上述两步我们得到了,为了得到最终的结果需要对修正的预测进行平滑,主要思想就是利用标签传播方法对结果优化。根据标签初始化最终结果:
然后迭代收敛得到最终的预测概率矩阵 :
其中,最后节点分类结果为。
Experiments
实验中的都是直推式的节点分类任务,使用的数据集如下
整体性能优于当前 baselines,只使用训练集数据标签的结果如下:
同时使用训练集和测试集的节点标签结果如下:
文中还考虑了不同 base predictor 对 C&S 效果的提升,
论文中还有时间对比实验和可视化效果,感兴趣的可以参考原文。
Thoughts
论文中的思想很 naive 但性能又非常好。在每篇论文都想着怎么把模型搞复杂发论文的情况下,作者却可以用简单的模型来达到 SOTA,这看样子就是大佬的厉害之处 👍
但论文方法的缺点也蛮明显的:
- 适用于半监督学习;
- 直推式学习,仅考虑了节点的分类任务,链路预测和图级别的任务都可能不适用;
相信后续一定会有人对其进行优化!
突然冒出个想法如果仅用节点特征和标签传播算法来做无监督或者半监督学习,会在 OGB 中达到什么效果呢?🤔