06.消息传递网络
消息传递图神经网络可以描述为:
xi表示第i和节点,k表示第k层图神经网络。⊕表示某种函数,如求和、均值或最大值。γ和Φ也表示某种可微分的函数,例如多层感知器(MLPs)。
这个公式的意思就是,第k层网络中的第i个节点的特征等于,第k-1层中i节点周围相连接的所有j节点和边(如果有)的特征的某种函数关系(看Φ取什么)的和、均值或最大值(看⊕取什么),再加上(看γ取什么)第k-1层中i节点本身的特征。
一、“消息传递”基类
PyG 提供了 MessagePassing
基类,它通过自动处理消息传递来帮助创建这类消息传递图神经网络。用户只需定义函数Φ ,即 message()
,以及函数 γ,即 update()
,以及要使用的聚合方案⊕,即 aggr="add"
、 aggr="mean"
或 aggr="max"
。
具体可参考https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html
-
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
: 在 PyTorch Geometric 中,torch_geometric.nn.MessagePassing
是所有图神经网络层(如 GCN、GAT、GIN)的基类,用于实现消息传递机制。定义了要使用的聚合方案("add"
、"mean"
或"max"
)和消息传递的流向("source_to_target"
或"target_to_source"
)。此外,node_dim
属性指示沿哪个轴进行传播。 -
MessagePassing.propagate(edge_index, size=None, **kwargs)
: 开始传播消息的初始调用。图神经网络消息传递的核心函数,用于在MessagePassing
子类中自动执行 消息构造(message)→ 聚合(aggregate)→ 更新(update) 三个阶段的操作。 -
MessagePassing.message (...)
:构造消息并发送到节点 i,其方式类似于针对每条边构建映射 。若 flow 为 “source_to_target”,则针对边 (j, i) ∈ ℰ ;若 flow 为 “target_to_source”,则针对边 (i, j) ∈ ℰ 。它可以接收最初传递给 propagate () 的任意参数。此外,传递给 propagate () 的张量可通过在变量名后附加 _i 或 _j ,映射到相应的节点 i 和 j,例如 x_i 和 x_j 。需要注意的是,通常我们将 i 称为聚合信息的中心节点,将 j 称为相邻节点,因为这是最常用的表示方法 。 -
MessagePassing.update(aggr_out, ...)
:类似于 γ(的操作),为每个节点i更新节点嵌入。将聚合的输出作为第一个参数,还可接收最初传递给 propagate () 的任意参数 。
二、实现 GCN 层
GCN 层在数学上定义为:
其中,邻近节点的特征首先通过权重矩阵W进行转换,然后根据其度数进行归一化,并最终求和。最后,我们将偏置向量b应用于聚合后的输出。这个公式可以分为以下步骤:
- 在邻接矩阵中添加自循环(因为i节点要聚合自身的特征)。
- 对节点特征矩阵进行线性变换(即通过一层线性神经网络进行变换)。
- 计算归一化系数(两个根号相乘的倒数)。
- 求和邻近节点特征(
"add"
聚合,也包含自身,因为添加自环边后,它自己本身也算是它的邻居)。 - 应用最终偏置向量。
步骤 1-3 通常在消息传递发生之前计算。步骤 4-5 可以使用 MessagePassing
基础类轻松处理。完整的层实现如下:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn