基于 PyG 构造消息传递网络
图上的卷积操作主要包含两部分:节点消息传递与消息聚集。假设xi(k−1)∈RF 表示k−1层节点的特征,ej,i∈RD表示节点j到节点i 的边的特征。那么消息传递的图神经网络可以表示为:
xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i))
其中□ 表示可微分的排列不变函数,e.g. sum,mean,max。λ和γ 表示可微分的函数,e.g. MLPs。
PyG 中 torch_geometric.nn.MessagePassing 提供一系列的消息传递方法来自动处理消息传播过程。
接下来以构造经典的kipf提出的GCN为例。
实现 GCN 层
GCN层的数学定义如下:
xi(k)=j∈N(i)∪{i}∑deg(i)⋅deg(j)1⋅(Θ⋅xj(k−1))
由上式可知节点特征首先经过Θ 进行特征变换,然后根据度进行归一化然后求和。计算公式可以拆分为以下几步:
- 邻接矩阵A 添加自环。
- 节点特征矩阵的线性变换。
- 计算归一化系数。
- 归一化节点特征。ϕ
- 邻居节点求和。(
add 聚集) - 得到最后的节点嵌入向量。γ
以上过程基于 PyG 的实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
| import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing): def __init__(self, in_channels, out_channels): super(GCNConv, self).__init__(aggr='add') self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
row, col = edge_index deg = degree(row, x.size(0), dtype=x.dtype) deg_inv_sqrt = deg.pow(-0.5) norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
return aggr_out
|
定义好卷积层后即可调用卷积层进行堆叠:
1 2
| conv = GCNConv(16, 32) x = conv(x, edge_index)
|
实现边卷积
这种方式个人用得较少,简单记录下。对于点云数据的卷积定义为:
xi(k)=j∈N(i)maxhΘ(xi(k−1),xj(k−1)−xi(k−1))
基于 PyG 实现方式如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| import torch from torch.nn import Sequential as Seq, Linear, ReLU from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing): def __init__(self, in_channels, out_channels): super(EdgeConv, self).__init__(aggr='max') self.mlp = Seq(Linear(2 * in_channels, out_channels), ReLU(), Linear(out_channels, out_channels))
def forward(self, x, edge_index):
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
def message(self, x_i, x_j):
tmp = torch.cat([x_i, x_j - x_i], dim=1) return self.mlp(tmp)
def update(self, aggr_out):
return aggr_out
|
1 2 3 4 5 6 7 8 9 10
| from torch_geometric.nn import knn_graph
class DynamicEdgeConv(EdgeConv): def __init__(self, in_channels, out_channels, k=6): super(DynamicEdgeConv, self).__init__(in_channels, out_channels) self.k = k
def forward(self, x, batch=None): edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow) return super(DynamicEdgeConv, self).forward(x, edge_index)
|
1 2
| conv = DynamicEdgeConv(3, 128, k=6) x = conv(pos, batch)
|
了解以上内容即可知道如何自定义 GNN 计算方式。
更多内容参考官网教程:https://pytorch-geometric.readthedocs.io/en/latest/index.html
联系作者
