基于 Graph 的 Embedding 方法概述
Graph Embedding
基于内容的 Embedding 方法(如 word2vec、BERT 等)都是针对“序列”样本(如句子、用户行为序列)设计的,但在互联网场景下,数据对象之间更多呈现出图结构,如下图所示 (1) 有用户行为数据生成的物品关系图;(2) 有属性和实体组成的只是图谱。
对于图结构数据,基于内容的 embedding 方法不太好直接处理了。因此,为了解决图结构数据的问题,Graph Embedding 开始得到大家的重视,并在各个领域进行尝试;
Graph Embedding 是一种将图结构数据映射为低微稠密向量的过程,从而捕捉到图的拓扑结构、顶点与顶点的关系、以及其他的信息。目前,Graph Embedding 方法大致可以分为两大类:
- 浅层图模型;
- 深度图模型。
浅层图模型
浅层图模型主要是采用random-walk + skip-gram
模式的 embedding 方法。主要是通过在图中采用随机游走策略来生成多条节点列表,然后将每个列表相当于含有多个单词(图中的节点)的句子,再用 skip-gram 模型来训练每个节点的向量。这些方法主要包括DeepWalk
、Node2vec
、Metapath2vec
等。
DeepWalk
DeepWalk 是第一个将 NLP 中的思想用在 Graph Embedding 上的算法,输入是一张图,输出是网络中节点的向量表示,使得图中两个点共有的邻居节点(或者高阶邻近点)越多,则对应的两个向量之间的距离就越近。
DeepWalk 得本质可以认为是:random walk + skip-gram。在 DeepWalk 算法中,需要形式化定义的是 random walk 的跳转概率,即到达节点后,下一步遍历其邻居节点的概率:
其中, 表示节点的所有出边连接的节点集合,表示由节点 连接至节点 的边的权重。由此可见,原始 DeepWalk 算法的跳转概率是跳转边的权重占所有相关出边权重之和的比例。算法具体步骤如下图所示:
DeepWalk 算法原理简单,在网络标注顶点很少的情况也能得到比较好的效果,且具有较好的可扩展性,能够适应网络的变化。但由于 DeepWalk 采用的游走策略过于简单(BFS),无法有效表征图的节点的结构信息。
Node2vec
为了克服 DeepWalk 模型的 random walk 策略相对简单的问题,斯坦福大学的研究人员在 2016 年提出了 Node2vec 模型。该模型通过调整 random walk 权重的方法使得节点的 embedding 向量更倾向于体现网络的同质性或结构性。
同质性:指得是距离相近的节点的 embedding 向量应近似,如下图中,与节点 相连的节点的 embedding 向量应相似。为了使 embedding 向量能够表达网络的同质性,需要让随机游走更倾向于 DFS,因为 DFS 更有可能通过多次跳转,到达远方的节点上,使游走序列集中在一个较大的集合内部,使得在一个集合内部的节点具有更高的相似性,从而表达图的同质性。
结构性:结构相似的节点的 embedding 向量应近似,如下图中,与节点 结构相似的节点 的 embedding 向量应相似。为了表达结构性,需要随机游走更倾向于 BFS,因为 BFS 会更多的在当前节点的邻域中游走,相当于对当前节点的网络结构进行扫描,从而使得 embedding 向量能刻画节点邻域的结构信息。
在 Node2vec 中,同样是通过控制节点间的跳转概率来控制 BFS 和 DFS 倾向性的。如下图所示,当算法先由节点 跳转到节点,准备从节点 跳转至下一个节点时,各节点概率定义如下:
其中,是节点和边的权重, 定义如下:
表示节点 与 的最短路径,如 与 的最短路径为 1。作者引入了两个参数和来控制游走算法的 BFS 和 DFS 倾向性:
- return parameter:值越小,随机游走回到节点的概率越大,最终算法更注重表达网络的结构性
- In-out parameter:值越小,随机游走到远方节点的概率越大,算法更注重表达网络的同质性
当 时,Node2vec 退化成了 DeepWalk 算法。
下图是作者通过调整 和,使 embedding 向量更倾向于表达同质性和结构性的可视化结果:
从图中可以看出,同质性倾向使相邻的节点相似性更高,而结构性相似使得结构相似的节点具有更高的相似性。
Node2vec 的算法步骤如下:
相较于 DeepWalk,Node2vec 通过设计biased-random walk
策略,能对图中节点的结构相似性和同质性进行权衡,使模型更加灵活。但与 DeepWalk 一样,Node2vec 无法指定游走路径,且仅适用于解决只包含一种类型节点的同构网络,无法有效表示包含多种类型节点和边类型的复杂网络。
Metapath2vec
为了解决 Node2vec 和 DeepWalk 无法指定游走路径、处理异构网络的问题,Yuxiao Dong 等人在 2017 年提出了 Metapath2vec 方法,用于对异构信息网络(Heterogeneous Information Network, HIN)的节点进行 embedding。
Metapath2vec 总体思想跟 Node2vec 和 DeepWalk 相似,主要是在随机游走上使用基于 meta-path 的 random walk 来构建节点序列,然后用 Skip-gram 模型来完成顶点的 Embedding。
异构网络(Heterogeneous Network)的定义如下:
异构网络 其中节点和边的映射函数为。即,存在多种类型节点或边的网络为异构网络。
虽然节点类型不同,但是不同类型的节点会映射到同一个特征空间。由于异构性的存在,传统的基于同构网络的节点向量化方法很难有效地直接应用在异构网络上。
为了解决这个问题,作者提出了 meta-path-based random walk:通过不同meta-path scheme
来捕获不同类型节点之间语义和结构关系。meta-path scheme 定义如下:
其中 表示不同类型节点 和 之间的关系。节点的跳转概率为:
其中,,表示节点的 类型的邻居节点集合。meta-path 的定义一般是对称的,比如 user-item-tag-item-user
。最后采用 skip-gram 来训练节点的 embedding 向量:
其中: 表示节点的上下文中,类型为 的节点,
通过分析 metapath2vec 目标函数可以发现,该算法仅在游走是考虑了节点的异构性,但在 skip-gram 训练时却忽略了节点的类型。为此,作者进一步提出了 metapath2vec++算法,在 skip-gram 模型训练时将同类型的节点进行 softmax 归一化:
metaptah2vec 和 metapath2vec++的 skip-gram 模型结构如下图所示:
metapath2vec++具体步骤如下图所示:
深度图模型
上一节讲的浅层图模型方法在世纪应用中是先根据图的结构学习每个节点的 embedding 向量,然后再讲得到的 embedding 向量应用于下游任务重。然而,embedding 向量和下游任务是分开学习的,也就是说学得的 embedding 向量针对下游任务来说不一定是最优的。为了解决这个 embedding 向量与下游任务的 gap,研究人员尝试讲深度图模型是指将图与深度模型结合,实现 end-to-end 训练模型,从而在图中提取拓扑图的空间特征。主要分为四大类:
- Graph Convolution Networks (GCN)
- Graph Attention Networks (GAT)
- Graph AutoEncoder (GAE)
- Graph Generative Networks (GGN)
本节主要简单介绍 GCN 中的两个经典算法:1)基于谱的 GCN (GCN);2)基于空间的 GCN (GraphSAGE)。
提取拓扑图的空间特征的方法主要分为两大类:1)基于空间域或顶点域 spatial domain(vertex domain) 的;2)基于频域或谱域 spectral domain 的。通俗点解释,空域可以类比到直接在图片的像素点上进行卷积,而频域可以类比到对图片进行傅里叶变换后,再进行卷积。
- 基于 spatial domain:基于空域卷积的方法直接将卷积操作定义在每个结点的连接关系上,跟传统的卷积神经网络中的卷积更相似一些。主要有两个问题:1)按照什么条件去找中心节点的邻居,也就是如何确定 receptive field;2)按照什么方式处理包含不同数目邻居的特征。
- 基于 spectral domain:借助卷积定理可以通过定义频谱域上的内积操作来得到空间域图上的卷积操作。
GCN
理论参考以下文章:
GraphSAGE
GraphSAGE(Graph SAmple and aggreGatE)是基于空间域方法,其思想与基于频谱域方法相反,是直接在图上定义卷积操作,对空间上相邻的节点上进行运算。其计算流程主要分为三步:
- 对图中每个节点领据节点进行采样
- 根据聚合函数聚合邻居节点信息(特征)
- 得到图中各节点的 embedding 向量,供下游任务使用
GraphSAGE 生成 Embedding 向量过程如下:
其中 表示每个节点能够聚合的邻居节点的跳数(例如 时,每个顶点可以最多根据其 2 跳邻居节点的信息来表示自身的 embedding 向量)。算法直观上是在每次迭代中,节点聚合邻居信息。随着不断迭代,节点得到图中来自越来越远的节点信息。
邻居节点采样:在每个 epoch 中,均匀地选取固定大小的邻居数目,每次迭代选取不同的均匀样本。
GraphSAGE 的损失函数如下:
其中,和表示节点和的 embedding 向量,是固定长度的邻居节点, 是 sigmoid 函数,和分别表示负样本分布和数目。
对于聚合函数的,由于在图中节点的邻居是无序的,聚合函数应是对称的(改变输入节点的顺序,函数的输出结果不变),同时又具有较强的表示能力。主要有如下三大类的聚合函数:
- Mean aggretator:将目标节点和其邻居节点的第 k-1 层向量拼接起来,然后对计算向量的 element-wise 均值,最后通过对均值向量做非线性变换得到目标节点邻居信息表示:
- Pooling aggregator:先对目标节点的邻居节点向量做非线性变换并采用 pooling 操作(maxpooling 或 meanpooling)得到目标节点的邻居信息表示:
- LSTM aggretator:使用 LSTM 来 encode 邻居的特征,为了忽略邻居之间的顺序,需要将邻居节点顺序打乱之后输入到 LSTM 中。LSTM 相比简单的求平均和 Pooling 操作具有更强的表达能力。
后续。…
总结
在实际过程中,不同的向量化方法得到的 embedding 结果也会有较大差异,需要根据具体业务需求来选择相应的算法。如要挖掘用户与用户的同质性,可以尝试采用 Node2vec;此外,如果需要结合物品或 Item 的 side-info,可以考虑 GraphSAGE 算法来对图中节点进行 embedding。