论文标题 | Temporal Graph Networks for Deep Learning on Dynamic Graphs
论文来源 | ICML 2020
论文链接 | https://arxiv.org/abs/2006.10637
源码链接 | https://github.com/twitter-research/tgn

TL;DR

目前大多数图神经网络都不能对动态图(即节点特征和边特征随着时间而变化的图)进行表征,论文中结合记忆模块和图卷积操作提出了一种通用高效的动态图模型,并且证明前期研究中很多的动态图模型仅是论文中提出框架的一种特例。在实验部分验证了 TGNs 在直推式学习和归纳式学习任务中达到了 state-of-the-art 性能。

Algorithm/Model

Background

📢 静态图卷积不再赘述,如果缺少背景知识的可以参考博主其它文章。

主流的动态图模型可以分为两种:

  • 离散时间的动态图:根据固定时间间隔采集的图快照;
  • 连续时间的动态图:随着时间事件演变的图,事件包括边增加或者消失,节点增加或者消失,节点和边特征变化等。

这两种类型的动态图包括了大多数场景,之前我的理解大多局限在离散时间的动态图,因为连续时间的动态图大部分在社交网络的场景中,这种类型的任务处理的较少。这篇文章提出的框架是处理连续时间的动态图。

文中将随着时间变化的事件时序图建模为G={x(t1),x(t2),}\mathcal{G}=\left\{x\left(t_{1}\right), x\left(t_{2}\right), \ldots\right\} ,其中事件x(t)x(t) 分为两种类型:

  • 节点级别的vi(t)v_i(t) :其中ii表示节点序号,vv 表示节点特征。如果节点存在那么更新对应节点特征,如果节点不存在那么增加节点及其特征。
  • 边级别的事件eij(t)e_{ij}(t):表示出现一条时间相关的有向边,由于边类型可能不止一种因此图实际上是一个多重图。

论文中对删除节点或边的情况单独作为一项进行考虑,比较复杂因此作为附录另加说明。

Temporal Graph Network

TGN 模型的任务是根据随时间变化的连续事件来生成每个事件tt 的图节点表示Z(t)=(z1(t),,zn(t)(t))\mathbf{Z}(t)=\left(\mathbf{z}_{1}(t), \ldots, \mathbf{z}_{n(t)}(t)\right)

论文中首先使用 Memory Module 来保留节点长期的特征,类似于 LSTM 的思路。当一个新的事件来临时模块内节点特征更新的方法如下:

Memory Module

依赖于以下几个模块进行消息传递的节点嵌入。

Message Function

对于tt 时刻节点iijj 间边级别的交互事件,节点信息更新传递函数如下:

mi(t)=msgs(si(t),sj(t),Δt,eij(t)),mj(t)=msgd(sj(t),si(t),Δt,eij(t))\mathbf{m}_{i}(t)=\mathrm{msg}_{\mathrm{s}}\left(\mathbf{s}_{i}\left(t^{-}\right), \mathbf{s}_{j}\left(t^{-}\right), \Delta t, \mathbf{e}_{i j}(t)\right), \quad \mathbf{m}_{j}(t)=\mathrm{msg}_{\mathrm{d}}\left(\mathbf{s}_{j}\left(t^{-}\right), \mathbf{s}_{i}\left(t^{-}\right), \Delta t, \mathbf{e}_{i j}(t)\right)

对于tt 时刻节点ii 的节点级别事件,节点信息更新函数如下:

mi(t)=msgn(si(t),t,vi(t))\mathbf{m}_{i}(t)=\operatorname{msg}_{n}\left(\mathbf{s}_{i}\left(t^{-}\right), t, \mathbf{v}_{i}(t)\right)

其中si(t)\mathbf{s}_{i}\left(t^{-}\right) 表示tt 时刻前的节点特征,msgn\operatorname{msg}_{n} 表示可学习的信息传递函数例如 MLPs 等。

Message Aggregator

由于批事件处理可能会导致同一节点在时刻tt 同时需要更新节点,因此论文中使用一种聚合方式来聚合每个节点的特征。对于时间范围t1,,tbtt_1,\cdots,t_b \leq t 的节点信息mi(t1),mi(tb)\mathbf{m}_{i}(t_1),\cdots \mathbf{m}_{i}(t_b),节点iitt 时刻的节点特征更新为:

mi(t)=agg(mi(t1),,mi(tb))\overline{\mathbf{m}}_{i}(t)=\operatorname{agg}\left(\mathbf{m}_{i}\left(t_{1}\right), \ldots, \mathbf{m}_{i}\left(t_{b}\right)\right)

聚合函数aggagg 有很多选择,例如 RNNs 或者注意力机制等等。论文中用到的是两种:most recent message 和 mean message。

Memory Updater

节点特征需要根据事件进行更新,边交互事件需要更新关联的节点对,节点事件需要更新对应节点,形式化表达如下:

si(t)=mem(mi(t),si(t))\mathbf{s}_{i}(t)=\operatorname{mem}\left(\overline{\mathbf{m}}_{i}(t), \mathbf{s}_{i}\left(t^{-}\right)\right)

其中memmem 更新函数是一个可学习的内存更新函数,例如 LSTM 或者 GRU 等。

Embedding

根据以上模块的节点状态更新,需要对节点ii 在时刻tt 的特征进行编码得到嵌入表示zi(t)\mathbf{z}_i(t)。形式化表达如下:

zi(t)=emb(i,t)=jηik([0,t])h(si(t),sj(t),eij,vi(t),vj(t))\mathbf{z}_{i}(t)=\operatorname{emb}(i, t)=\sum_{j \in \eta_{i}^{k}([0, t])} h\left(\mathbf{s}_{i}(t), \mathbf{s}_{j}(t), \mathbf{e}_{i j}, \mathbf{v}_{i}(t), \mathbf{v}_{j}(t)\right)

其中hh 是可学习的函数,对于不同的场景有不同的计算形式:

  • Identity(id)emb(i,t)=si(t)\operatorname{emb}(i, t) = \mathbf{s}_{i}(t),直接以节点特征作为 node embedding。
  • Time projection (time)emb(i,t)=(1+Δtw)si(t)\operatorname{emb}(i, t) = (1+\Delta t \mathbf{w})\cdot \mathbf{s}_i(t)Δt\Delta t 表示上次更新的时间间隔。
  • Temporal Graph Attention (attn):图注意力层,详细计算方式参考原文。
  • Temporal Graph Sum (sum):图节点特征快速融合。

模块训练流程

Experiments

论文中提出的模型在链路预测中的效果如下:

实验效果

动态节点分类的结果如下:

实验效果

从实验结果来看,论文中的方法远优于其它 baselines,论文开源的代码非常详细,感兴趣的同学可以自己动手实践下。

Thoughts

目前暂未接触动态图的预测任务,但是论文提出的方法创新性蛮好的,兼容了不同属性的图结构。

Contact