当前位置: 首页 > news >正文

图神经网络入门:手写一个 VanillaGNN-从邻接矩阵理解图神经网络的消息传递

🔹 第一部分:导入库

import torch
import pandas as pd
from torch_geometric.datasets import Planetoid
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj
  • torch:PyTorch 核心,用于张量和自动微分。
  • pandas:用于数据处理(虽然这段代码中未实际使用)。
  • Planetoid:加载 Cora/CiteSeer/PubMed 等标准图数据集。
  • Linear:全连接层。
  • F:包含 relulog_softmax 等函数。
  • to_dense_adj:将稀疏的 edge_index 转换为稠密邻接矩阵(dense adjacency matrix)。

🔹 第二部分:加载 Cora 数据并构建邻接矩阵

dataset = Planetoid(root="D:\\py机器学习\\data", name="Cora")
data = dataset[0]adjacency = to_dense_adj(data.edge_index)[0]
adjacency += torch.eye(len(adjacency))
  • 加载 Cora 数据集data 包含:

    • x: 节点特征(2708 × 1433)
    • y: 节点标签(2708)
    • edge_index: 边的稀疏表示(2 × E)
    • train_mask / val_mask / test_mask
  • to_dense_adj(data.edge_index)

    • 将稀疏边列表转为 N×N 稠密邻接矩阵(N=2708)
    • 返回形状为 [1, N, N] 的张量,所以用 [0] 取出第一个图(Cora 只有一个图)
  • adjacency += torch.eye(len(adjacency))

    • 添加自环(self-loops):让每个节点在聚合时包含自身信息
    • 这是图神经网络中的常见操作(如 GCN 的预处理步骤)

📌 注意:这里得到的是稠密矩阵,对于大图会非常耗内存!实际 GNN 通常用稀疏矩阵运算(如 torch.sparse.mm),但本代码为了教学简化使用了稠密形式(后续会转为稀疏)。


🔹 第三部分:定义准确率函数(与之前相同)

def accuracy(y_pred, y_true):return torch.sum(y_pred == y_true) / len(y_true)
  • 计算分类准确率,逻辑不变。

🔹 第四部分:定义图卷积层(VanillaGNNLayer)

class VanillaGNNLayer(torch.nn.Module):def __init__(self, dim_in, dim_out):super().__init__()self.linear = Linear(dim_in, dim_out, bias=False)def forward(self, x, adjacency):x = self.linear(x)x = torch.sparse.mm(adjacency, x)return x
✅ 关键点解析:
  1. 线性变换先于聚合

    • 先对每个节点特征做 W·x_iself.linear(x)
    • 再用邻接矩阵聚合邻居:A·(XW)
  2. 使用 torch.sparse.mm

    • 虽然 adjacency 是从稠密矩阵构造的,但 to_dense_adj 返回的是稠密张量
    • ❗ 但这里代码有潜在问题adjacency 是稠密的,而 torch.sparse.mm 要求第一个参数是稀疏张量
    • 实际上,这段代码会报错,除非将 adjacency 转为稀疏格式。

🔧 修正建议(但原代码可能在某些版本下侥幸运行):

adjacency = to_dense_adj(data.edge_index)[0]
adjacency += torch.eye(adjacency.size(0))
adjacency = adjacency.to_sparse()  # ← 必须加这行!
  1. 无偏置(bias=False)

    • 因为后续会加激活函数,且图卷积中常省略偏置以简化。
  2. 这是“消息传递”的简化版

    • 每个节点的新表示 = 所有邻居(含自己)的线性变换后的特征之和

🔹 第五部分:打印邻接矩阵(调试用)

print(adjacency)
  • 会输出一个 2708×2708 的大矩阵(或稀疏表示),主要用于调试,实际训练中不需要。

🔹 第六部分:定义完整 GNN 模型(VanillaGNN)

class VanillaGNN(torch.nn.Module):def __init__(self, dim_in, dim_h, dim_out):super().__init__()self.gnn1 = VanillaGNNLayer(dim_in, dim_h)self.gnn2 = VanillaGNNLayer(dim_h, dim_out)def forward(self, x, adjacency):h = self.gnn1(x, adjacency)h = torch.relu(h)h = self.gnn2(h, adjacency)return F.log_softmax(h, dim=1)
  • 两层图卷积:
    1. 第一层:1433 → 16 维
    2. ReLU 激活
    3. 第二层:16 → 7 维(类别数)
  • 输出使用 log_softmax,配合 NLLLoss

与 MLP 的本质区别
每一层都通过 adjacency 聚合邻居信息,利用了图结构


🔹 第七部分:训练方法(fit)

    def fit(self, data, epochs):criterion = torch.nn.NLLLoss()optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)self.train()for epoch in range(epochs+1):optimizer.zero_grad()out = self(data.x, adjacency)  # ← 关键:传入 adjacency!loss = criterion(out[data.train_mask], data.y[data.train_mask])acc = accuracy(out[data.train_mask].argmax(dim=1), data.y[data.train_mask])loss.backward()optimizer.step()if epoch % 20 == 0:val_loss = criterion(out[data.val_mask], data.y[data.val_mask])val_acc = accuracy(out[data.val_mask].argmax(dim=1), data.y[data.val_mask])print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | ...')
  • 与 MLP 的训练流程几乎相同,唯一区别是前向传播时传入了 adjacency
  • 说明:模型现在能利用图结构进行学习

🔹 第八部分:测试方法(test)

    @torch.no_grad()def test(self, data):self.eval()out = self(data.x, adjacency)acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])return acc
  • 标准测试流程,使用 test_mask 评估

🔹 第九部分:实例化并运行

gnn = VanillaGNN(dataset.num_features, 16, dataset.num_classes)
print(gnn)gnn.fit(data, epochs=100)
acc = gnn.test(data)
print(f'\nGNN test accuracy: {acc*100:.2f}%')
  • 创建模型:输入1433 → 隐藏16 → 输出7
  • 训练100轮
  • 测试并输出准确率(通常比 MLP 高,如 75%+)

🌟 总结:这个 GNN 的核心思想

步骤操作作用
1X → XW对每个节点特征做线性变换
2A·(XW)聚合所有邻居(含自己)的变换后特征
3ReLU引入非线性
4重复 1-3多层后,每个节点能“看到”更远的邻居

这其实就是 GCN(Graph Convolutional Network)的简化版
标准 GCN 还会对邻接矩阵做归一化(如 $\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}$),但本模型省略了这一步,所以叫 “Vanilla”(朴素)GNN。


输出:

tensor([[1., 0., 0.,  ..., 0., 0., 0.],[0., 1., 1.,  ..., 0., 0., 0.],[0., 1., 1.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 1., 0., 0.],[0., 0., 0.,  ..., 0., 1., 1.],[0., 0., 0.,  ..., 0., 1., 1.]])
VanillaGNN((gnn1): VanillaGNNLayer((linear): Linear(in_features=1433, out_features=16, bias=False))(gnn2): VanillaGNNLayer((linear): Linear(in_features=16, out_features=7, bias=False))
)
Epoch   0 | Train Loss: 2.240 | Train Acc: 19.29% | Val Loss: 2.28 | Val Acc: 14.20%
Epoch  20 | Train Loss: 0.117 | Train Acc: 100.00% | Val Loss: 1.43 | Val Acc: 74.20%
Epoch  40 | Train Loss: 0.009 | Train Acc: 100.00% | Val Loss: 2.29 | Val Acc: 73.40%
Epoch  60 | Train Loss: 0.002 | Train Acc: 100.00% | Val Loss: 2.51 | Val Acc: 74.60%
Epoch  80 | Train Loss: 0.002 | Train Acc: 100.00% | Val Loss: 2.47 | Val Acc: 74.80%
Epoch 100 | Train Loss: 0.001 | Train Acc: 100.00% | Val Loss: 2.42 | Val Acc: 75.40%GNN test accuracy: 76.10%

✅ 与之前 MLP 的对比

MLP基线模型:

https://blog.csdn.net/sweet_ran/article/details/154017794?fromshare=blogdetail&sharetype=blogdetail&sharerId=154017794&sharerefer=PC&sharesource=sweet_ran&sharefrom=from_linkhttps://blog.csdn.net/sweet_ran/article/details/154017794?fromshare=blogdetail&sharetype=blogdetail&sharerId=154017794&sharerefer=PC&sharesource=sweet_ran&sharefrom=from_link

模型是否用图结构典型测试准确率(Cora)特点
MLP❌ 否~50–60%仅用节点特征
VanillaGNN✅ 是~70–80%利用邻居信息,性能显著提升

http://www.dtcms.com/a/540692.html

相关文章:

  • 网站模版带后台酒类招商网站大全
  • 营销型网站创建网页制作三剑客通常指
  • 【笔试真题】- 电信-2025.10.11
  • 云渲染与传统渲染:核心差异与适用场景分析
  • 什么是流程监控?如何构建跨系统BPM的实时监控体系?
  • 直通滤波....
  • eclipse做网站代码惠州市
  • 零基础新手小白快速了解掌握服务集群与自动化运维(十五)Redis模块-Redis主从复制
  • 视频网站自己怎么做的正规的大宗商品交易平台
  • vue3 实现贪吃蛇手机版01
  • 胶州网站建设dch100室内装修设计师工资一般多少钱
  • 计算机视觉、医学图像处理、深度学习、多模态融合方向分析
  • 小白入门:基于k8s搭建训练集群,实战CIFAR-10图像分类
  • 关系型数据库大王Mysql——DML语句操作示例
  • VNC安装
  • 网站建设论文 php苏州关键词排名提升
  • 【MySQL】用户管理详解
  • 怎么制作手机网站金坛区建设工程质量监督网站
  • 企业网站的布局类型怎样免费建设免费网站
  • Unity UGC IDE实现深度解析(一):节点图的核心架构设计
  • h5游戏免费下载:搭汉堡
  • 中外商贸网站建设网站怎样做权重
  • 做雇主品牌的网站logo设计网页
  • RocketMQ核心技术精讲-----详解消息发送样例
  • 解锁 PySpark SQL 的强大功能:有关 App Store 数据的端到端教程
  • MousePlus(鼠标增强工具) 中文绿色版
  • 源码学习:MyBatis源码深度解析与实战
  • RAG项目中知识库的检索优化
  • Java IO 流之转换流:InputStreamReader/OutputStreamWriter(字节与字符的桥梁)
  • 熊掌号做网站推广的注意事项品牌网页