1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
| import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing): def __init__(self, in_channels, out_channels): super(GCNConv, self).__init__(aggr='add') self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
row, col = edge_index deg = degree(row, x.size(0), dtype=x.dtype) deg_inv_sqrt = deg.pow(-0.5) norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
return aggr_out
|