通俗易懂解释经典模型 GraphSAGE
很久之前在博客中记录了 GraphSAGE 的理论 NIPS 2017 | GraphSAGE:大规模图上的归纳表示学习模型,此时再回首阅读这篇经典的论文发现描述得仍不够通俗易懂,大部分是对论文的翻译工作,对此领域了解较少的同学仍会感到疑惑。因此本文以更加简单清晰的图示方法解释 GraphSAGE 并介绍下用于 Graph-level 和 Node-level 的 inductive learning 任务,最后进行简单总结。
背景
经典的谱图卷积 GCN 和 Node2vec 系列难以解决新节点的表示学习问题而且图是固定的,仅适用于 transductive learning 任务而且不符合现实很多复杂网络的场景,因此提出了 inductive learning 框架 GraphSAGE (SAmple and aggreGatE) 对未知节点生成节点 embedding。GraphSAGE 的主要思想是:学习特征映射函数而不是直接为每个节点学习 embedding,学习映射函数的好处是可以通过一个节点的局部邻居采样并聚合节点特征来生成节点表示。
具体理论和 motivation 不再细述,可以参考我博客中对该论文的解释 👉🏻 NIPS 2017 | GraphSAGE:大规模图上的归纳表示学习模型 ,👇🏻 下面以图示的方法来解释 GraphSAGE 的想法及其解决问题的思路。
特征聚合过程
GraphSAGE 主要思想是给定节点然后根据其邻居节点聚合特征,以一个简单输入图说明其过程。
以聚合节点 A 特征为例,上图右侧图片中 Black Box 即表示 A、B、C、D 的聚合函数(Aggregator function),聚合函数可以有很多种类型例如 mean、sum 等。上图仅表示根据 1-hop 节点邻居来聚合特征,那么节点如何融合 k-hop 的节点特征呢,下面直接以示例说明其过程。
给定初始化节点特征如下
例如使用 mean 聚合函数,第一步:根据节点及其邻居聚合特征迭代计算过程如下
第一步迭代后节点特征如下
重复上一步,第二步迭代计算得到的结果如下所示
分析以上迭代聚合过程。以 表示节点 A 的初始 embedding 值 0.1, 表示经过一次迭代后的值 0.25,其它节点类似表示。
那么以上计算过程第一步可以表示为
第二步迭代
都以节点初始值表示那么就是
从上面的式子即可看出可以通过 k 次聚合即可融合 k-hop 邻居特征,但实验中表明 2-hop 即可得到非常好的效果。👍🏻
上面以简单示例说明了节点的 k-hop 聚合过程,下面转到 GraphSAGE 基础的理论。
以函数 表示聚合函数,以 步聚合操作泛化形式如下
以 表示节点 的邻居节点,那么,上述表达式可以表示为
为了与 GraphSAGE 中符号统一,以 表示;聚合函数 可以包括多种类型,每层对应的聚合函数表示为;设定总迭代聚合次数为 且图表示为。
此处附上基于 GNN 的推荐系统 PinSAGE 中的聚合示意图,以此更清晰描述其过程。
GraphSAGE
根据上述示例的聚合过程,数学表示如下
其中 表示经过 次聚合后的节点 embedding。
如果仅用以上简单的聚合过程那么肯定会导致 over-smoothing 的问题,即节点特征经过多次融合后难以区分,例如节点 E,F 邻居相同。
那么 GraphSAGE 是如何解决这个问题的呢?GraphSAGE 先聚合邻居节点的特征然后再与该节点的特征进行 concate,从而保留了节点原始特征进行节点区分
接下来添加些非线性函数转换 获得更强的特征表达:
可以看到特征聚合过程都不存在任何训练参数, 单纯的为模型增加些可训练参数。
此外,在每次迭代后需要给每个节点进行 L2 标准化,就得到了整体的 GraphSAGE 聚合算法,如下所示
上述算法描述了前向传播过程,需要定义目标函数进行训练:
对于无监督学习,目标函数如下
其中 表示负样本数量, 表示负采样概率分布;从上式可以看到损失函数令相邻节点具有相似表示,而不相邻节点表示不同。
对于监督学习可以根据任务来设定损失函数,例如节点分类使用交叉熵损失函数。
聚合函数
在图中顶点的邻居是无序的,所以希望构造出的聚合函数是对称的(即也就是对它输入的各种排列,函数的输出结果不变),同时具有较高的表达能力。 聚合函数的对称性(symmetry property)确保了神经网络模型可以被训练且可以应用于任意顺序的顶点邻居特征集合上。
论文中列举了三种聚合函数:
Mean aggregator
将目标顶点和邻居顶点的第 层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第 层表示向量。
LSTM aggregator
文中提出了一个基于 LSTM 的复杂的聚合器。和均值聚合器相比,LSTMs 有更强的表达能力。但 LSTMs 不是对称的,因为它们以一个序列的方式处理输入。因此,需要先对邻居节点随机顺序,然后将邻居序列的 embedding 作为 LSTM 的输入。
Pooling aggregator
Pooling 聚合器既是对称的,又是可训练的。Pooling aggregator 先对目标顶点的邻居顶点的 embedding 向量进行一次非线性变换,之后进行一次 pooling 操作,将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第 层表示向量。
以上三种聚合器各有优势,但是对于不同的任务还可以采用 Sum aggregator,Max aggregator 等。
Inductive capability
文章开头就强调了 GraphSAGE 的 Inductive learning 优势:可以使用图中节点子集训练模型然后用于其它节点,那么他是怎么做到的呢?参数共享
当训练好两层 Aggregator layers 后得到权重 和 时,可以通过参数共享用于其它图结构。👍🏻
尤其是对于 PPI 网络,可以在根据组织结构图 训练好模型参数后,可以在新组织结构 中应用来生成节点的 embedding,如下图所示
除了能对 new graph 生节点嵌入,对于 now node 同样可以生成 embedding,参数共享的思路是相同的。但是新增加节点后尽可能重新训练对旧节点的 embedding 进行更新,如果是大规模的图新加几个节点对结果影响不大。
总结
通过以上介绍,总结下 GraphSAGE 的优点和缺点:
优点
通过邻居采样的方式解决了GCN内存爆炸的问题,适用于大规模图的表示学习;
将 transductive 转化为 inductive learning,而且支持增量特征;
引入邻居采样,可有效防止训练过拟合,增强泛化能力;
可以根据不同领域的图场景来自定义图聚合方式;
缺点
无法处理加权图,仅可以邻居节点等权聚合;
邻居采样引入随机过程,推理阶段同一节点 embedding 特征不稳定,且邻居采样会导致反向传播时梯度不稳定;
邻居采样数目限制会导致部分节点的重要局部信息丢失;
GCN 网络层太多会引起 over-smoothing 问题,当然后续非常多的工作来改进这个问题。