当前位置: 首页 > news >正文

「日拱一码」027 深度学习库——PyTorch Geometric(PyG)

目录

数据处理与转换

数据表示

数据加载

数据转换

特征归一化

添加自环

随机扰动

组合转换

图神经网络层

图卷积层(GCNConv)

图注意力层(GATConv)

池化

全局池化(Global Pooling)

全局平均池化

全局最大池化

全局求和池化

基于注意力的池化(Attention-based Pooling)

基于图的池化(Graph-based Pooling)

层次化池化(Hierarchical Pooling)

采样

子图采样(Subgraph Sampling)

邻域采样(Neighbor Sampling)

模型训练与评估

训练过程

测试过程

异构图处理

异构图定义

异构图卷积

图生成模型

Deep Graph Infomax (DGI)

Graph Autoencoder (GAE)

Variational Graph Autoencoder (VGAE)


PyTorch Geometric(PyG)是PyTorch的一个扩展库,专注于图神经网络(GNN)的实现。它提供了丰富的图数据处理工具、图神经网络层和模型。以下是对PyG库中常用方法的介绍

数据处理与转换

数据表示

PyG使用 torch_geometric.data.Data 类来表示图数据,包含节点特征 x 、边索引 edge_index 、边特征 edge_attr 等

## 数据处理与转换
# 1. 数据表示
import torch
from torch_geometric.data import Data# 创建一个简单的图
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float)  # 节点特征
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)  # 边索引
edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float)  # 边特征data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
print(data)  # Data(x=[3, 2], edge_index=[2, 4], edge_attr=[4])

数据加载

PyG提供了 torch_geometric.data.DataLoader 类,用于批量加载图数据

# 2. 数据加载
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader# 加载Cora数据集
dataset = Planetoid(root='./data', name='Cora')
loader = DataLoader(dataset, batch_size=32, shuffle=True)print(f"节点数: {data.num_nodes}")  # 3
print(f"边数: {data.num_edges}")  # 4
print(f"特征维度: {data.num_node_features}")  # 2
print(f"类别数: {dataset.num_classes}")  # 7for batch in loader:print(batch)# DataBatch(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708],#           batch=[2708], ptr=[2])

数据转换

  • 特征归一化

NormalizeFeatures  是一个常用的转换方法,用于将节点特征归一化到单位范数(如 0, 1 或 -1, 1)

# 3. 数据转换
# 3.1 特征归一化
from torch_geometric.transforms import NormalizeFeaturesdataset = Planetoid(root='./data', name='Cora', transform=NormalizeFeatures())# 查看归一化后的特征
data = dataset[0]
print(data.x)
# tensor([[0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         ...,
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.]])
  • 添加自环

AddSelfLoops  是一个转换方法,用于为图中的每个节点添加自环(即每个节点连接到自己)

# 3.2 添加自环
from torch_geometric.transforms import AddSelfLoopsdataset = Planetoid(root='./data', name='Cora', transform=AddSelfLoops())# 查看添加自环后的边索引
data = dataset[0]
print(data.edge_index)
# tensor([[   0,    0,    0,  ..., 2705, 2706, 2707],
#         [ 633, 1862, 2582,  ..., 2705, 2706, 2707]])
  • 随机扰动

RandomNodeSplit  是一个转换方法,用于随机划分训练集、验证集和测试集

# 3.3 随机扰动
from torch_geometric.transforms import RandomNodeSplitdataset = Planetoid(root='./data', name='Cora', transform=RandomNodeSplit(num_splits=10))# 查看划分后的掩码
data = dataset[0]
print(data.train_mask)
# tensor([[False,  True,  True,  ..., False, False,  True],
#         [False, False,  True,  ...,  True, False, False],
#         [False,  True, False,  ..., False, False, False],
#         ...,
#         [ True,  True,  True,  ..., False, False, False],
#         [ True,  True,  True,  ..., False, False,  True],
#         [ True,  True,  True,  ...,  True, False,  True]])
print(data.val_mask)
# tensor([[False, False, False,  ..., False,  True, False],
#         [False, False, False,  ..., False, False, False],
#         [False, False, False,  ..., False, False,  True],
#         ...,
#         [False, False, False,  ...,  True,  True, False],
#         [False, False, False,  ..., False,  True, False],
#         [False, False, False,  ..., False, False, False]])
print(data.test_mask)
# tensor([[ True, False, False,  ...,  True, False, False],
#         [ True,  True, False,  ..., False,  True,  True],
#         [ True, False,  True,  ...,  True,  True, False],
#         ...,
#         [False, False, False,  ..., False, False,  True],
#         [False, False, False,  ...,  True, False, False],
#         [False, False, False,  ..., False,  True, False]])
  • 组合转换

可以将多个转换方法组合在一起,形成一个复合转换

# 3.4 组合转换
from torch_geometric.transforms import Compose, NormalizeFeatures, AddSelfLoops# 定义一个复合转换
transform = Compose([NormalizeFeatures(), AddSelfLoops()])# 创建一个数据集,并应用复合转换
dataset = Planetoid(root='./data', name='Cora', transform=transform)# 查看转换后的数据
data = dataset[0]
print(data.x)
# tensor([[0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         ...,
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.]])
print(data.edge_index)
# tensor([[   0,    0,    0,  ..., 2705, 2706, 2707],
#         [ 633, 1862, 2582,  ..., 2705, 2706, 2707]])

图神经网络层

图卷积层(GCNConv)

GCNConv是图卷积网络(GCN)的基本层

## 图神经网络层
# 1. 图卷积层 GCNConv
import torch
from torch_geometric.nn import GCNConvclass GCN(torch.nn.Module):def __init__(self, in_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, 16)self.conv2 = GCNConv(16, out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = torch.relu(x)x = self.conv2(x, edge_index)return xmodel = GCN(in_channels=dataset.num_features, out_channels=dataset.num_classes)
print(model)
# GCN(
#   (conv1): GCNConv(1433, 16)
#   (conv2): GCNConv(16, 7)
# )

图注意力层(GATConv)

GATConv是图注意力网络(GAT)的基本层

# 2. 图注意力层 GATConv
from torch_geometric.nn import GATConvclass GAT(torch.nn.Module):def __init__(self, in_channels, out_channels):super(GAT, self).__init__()self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=True, dropout=0.6)def forward(self, x, edge_index):x = torch.dropout(x, p=0.6, training=self.training)x = self.conv1(x, edge_index)x = torch.relu(x)x = torch.dropout(x, p=0.6, training=self.training)x = self.conv2(x, edge_index)return xmodel = GAT(in_channels=dataset.num_features, out_channels=dataset.num_classes)
print(model)
# GAT(
#   (conv1): GATConv(1433, 8, heads=8)
#   (conv2): GATConv(64, 7, heads=1)
# )

池化

全局池化(Global Pooling)

全局池化将整个图的所有节点聚合为一个全局表示

  • 全局平均池化
## 池化
# 1. 全局池化
# 1.1 全局平均池化
from torch_geometric.nn import global_mean_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加载数据集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 获取一个批次的数据
for batch in loader:x = batch.xbatch_index = batch.batchglobal_mean = global_mean_pool(x, batch_index)print("Global Mean Pooling Result:", global_mean)break
# tensor([[0.7647, 0.0588, 0.1176, 0.0000, 0.0588, 0.0000, 0.0000],
#         [0.7500, 0.1250, 0.1250, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6250, 0.1250, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5217, 0.1739, 0.3043, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5455, 0.2727, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6400, 0.1200, 0.2400, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6364, 0.0909, 0.2727, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7857, 0.0714, 0.1429, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8000, 0.0667, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7647, 0.0588, 0.1176, 0.0000, 0.0000, 0.0000, 0.0588],
#         [0.5000, 0.1667, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7692, 0.0769, 0.1538, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7826, 0.0435, 0.1739, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8000, 0.0500, 0.1500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6667, 0.0833, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8125, 0.0625, 0.1250, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8235, 0.0588, 0.1176, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6000, 0.1000, 0.2000, 0.0000, 0.0000, 0.1000, 0.0000],
#         [0.4615, 0.1538, 0.3077, 0.0000, 0.0000, 0.0769, 0.0000],
#         [0.7647, 0.0588, 0.1765, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6000, 0.0500, 0.2000, 0.0000, 0.0000, 0.1500, 0.0000],
#         [0.6000, 0.1000, 0.2000, 0.1000, 0.0000, 0.0000, 0.0000],
#         [0.7647, 0.0588, 0.1176, 0.0000, 0.0000, 0.0588, 0.0000],
#         [0.8000, 0.0667, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.4615, 0.1538, 0.3077, 0.0769, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8421, 0.0526, 0.1053, 0.0000, 0.0000, 0.0000, 0.0000]])
  • 全局最大池化
# 1.2 全局最大池化
from torch_geometric.nn import global_max_pool# 获取一个批次的数据
for batch in loader:x = batch.xbatch_index = batch.batchglobal_max = global_max_pool(x, batch_index)print("Global Max Pooling Result:", global_max)break
  • 全局求和池化
# 3. 全局求和池化
from torch_geometric.nn import global_add_pool# 获取一个批次的数据
for batch in loader:x = batch.xbatch_index = batch.batchglobal_sum = global_add_pool(x, batch_index)print("Global Sum Pooling Result:", global_sum)break

基于注意力的池化(Attention-based Pooling)

基于注意力的池化方法通过学习节点的重要性权重来进行池化。一个常见的例子是 Set2Set 池化

# 2. 基于注意力的池化——Set2Set
from torch_geometric.nn import Set2Set
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加载数据集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定义 Set2Set 池化
set2set = Set2Set(in_channels=dataset.num_node_features, processing_steps=3)# 获取一个批次的数据
for batch in loader:x = batch.xbatch_index = batch.batchglobal_set2set = set2set(x, batch_index)print("Set2Set Pooling Result:", global_set2set)break
# Set2Set Pooling Result: tensor([[ 0.1719,  0.0986,  0.1594, -0.0438,  0.1743,  0.1663, -0.0578,  0.8464,
#           0.0492,  0.1045,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1730,  0.0987,  0.1601, -0.0420,  0.1730,  0.1658, -0.0549,  0.8733,
#           0.0405,  0.0862,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1686,  0.0919,  0.1603, -0.0525,  0.1807,  0.1707, -0.0683,  0.7540,
#           0.0466,  0.1994,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1601,  0.1165,  0.1425, -0.0525,  0.1782,  0.1602, -0.0836,  0.6232,
#           0.2237,  0.1531,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1725,  0.0987,  0.1598, -0.0428,  0.1736,  0.1660, -0.0562,  0.8611,
#           0.0444,  0.0945,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1673,  0.1060,  0.1527, -0.0473,  0.1761,  0.1642, -0.0679,  0.7570,
#           0.1187,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1579,  0.0996,  0.1486, -0.0658,  0.1874,  0.1695, -0.0954,  0.5284,
#           0.1662,  0.3054,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1584,  0.0969,  0.1503, -0.0665,  0.1881,  0.1709, -0.0949,  0.5327,
#           0.1503,  0.3170,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1634,  0.0976,  0.1537, -0.0581,  0.1835,  0.1695, -0.0809,  0.6464,
#           0.1135,  0.2401,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1704,  0.0952,  0.1599, -0.0479,  0.1774,  0.1684, -0.0626,  0.8042,
#           0.0466,  0.1492,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1634,  0.1081,  0.1488, -0.0522,  0.1789,  0.1640, -0.0776,  0.6743,
#           0.1595,  0.1661,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1564,  0.1193,  0.1384, -0.0562,  0.1800,  0.1590, -0.0922,  0.5527,
#           0.2663,  0.1810,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1704,  0.0952,  0.1599, -0.0479,  0.1774,  0.1684, -0.0626,  0.8042,
#           0.0466,  0.1492,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1730,  0.0987,  0.1601, -0.0420,  0.1730,  0.1658, -0.0549,  0.8733,
#           0.0405,  0.0862,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1604,  0.0972,  0.1517, -0.0631,  0.1863,  0.1704, -0.0893,  0.5779,
#           0.1356,  0.2864,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1711,  0.0985,  0.1589, -0.0451,  0.1752,  0.1666, -0.0599,  0.8281,
#           0.0550,  0.1169,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1584,  0.1178,  0.1406, -0.0542,  0.1790,  0.1597, -0.0875,  0.5910,
#           0.2432,  0.1659,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1673,  0.1060,  0.1527, -0.0473,  0.1761,  0.1642, -0.0679,  0.7570,
#           0.1187,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1602,  0.1097,  0.1457, -0.0562,  0.1811,  0.1638, -0.0856,  0.6077,
#           0.1926,  0.1997,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1696,  0.1047,  0.1549, -0.0444,  0.1743,  0.1641, -0.0623,  0.8062,
#           0.0945,  0.0993,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1604,  0.0972,  0.1517, -0.0631,  0.1863,  0.1704, -0.0893,  0.5779,
#           0.1356,  0.2864,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1638,  0.0919,  0.1567, -0.0605,  0.1855,  0.1723, -0.0815,  0.6416,
#           0.0853,  0.2731,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1732,  0.1049,  0.1525, -0.0508,  0.1755,  0.1624, -0.0665,  0.7700,
#           0.0553,  0.1160,  0.0000,  0.0000,  0.0586,  0.0000],
#         [ 0.1711,  0.0985,  0.1589, -0.0451,  0.1752,  0.1666, -0.0599,  0.8281,
#           0.0550,  0.1169,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1573,  0.0968,  0.1494, -0.0685,  0.1891,  0.1712, -0.0982,  0.5063,
#           0.1589,  0.3349,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1729,  0.1053,  0.1365, -0.0582,  0.1904,  0.1594, -0.0881,  0.5637,
#           0.0878,  0.1812,  0.0746,  0.0000,  0.0927,  0.0000],
#         [ 0.1586,  0.1026,  0.1477, -0.0628,  0.1855,  0.1678, -0.0924,  0.5526,
#           0.1742,  0.2733,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1646,  0.1075,  0.1500, -0.0506,  0.1781,  0.1641, -0.0746,  0.6999,
#           0.1469,  0.1533,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1695,  0.0983,  0.1579, -0.0477,  0.1770,  0.1672, -0.0641,  0.7909,
#           0.0670,  0.1421,  0.0000,  0.0000,  0.0000,  0.0000]],
#        grad_fn=<CatBackward0>)

基于图的池化(Graph-based Pooling)

基于图的池化方法通过图的结构信息来进行池化。常见的方法包括 TopKPooling,通过选择重要性最高的节点来进行池化

# 3. 基于图的池化——TopKPooling
from torch_geometric.nn import TopKPooling
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加载数据集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定义 TopKPooling
pool = TopKPooling(in_channels=dataset.num_node_features, ratio=0.5) # 获取一个批次的数据
for batch in loader:x = batch.xedge_index = batch.edge_indexbatch_index = batch.batchx, edge_index, _, batch_index, _, _ = pool(x, edge_index, batch=batch_index)print("TopKPooling Result:", x)break
# tensor([[-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         ...,
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000]],
#        grad_fn=<MulBackward0>)

层次化池化(Hierarchical Pooling)

层次化池化通过多层池化操作生成图的层次化表示。一个常见的例子是 EdgePooling,通过边的合并操作来进行池化

# 4. 层次化池化——EdgePooling
from torch_geometric.nn import EdgePooling
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加载数据集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定义 EdgePooling
pool = EdgePooling(in_channels=dataset.num_node_features)  # 获取一个批次的数据
for batch in loader:x = batch.xedge_index = batch.edge_indexbatch_index = batch.batchx, edge_index, batch_index, _ = pool(x, edge_index, batch=batch_index)print("EdgePooling Result:", x)break# tensor([[0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         ...,
#         [0.0000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
#         [1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000]],
#        grad_fn=<MulBackward0>)

采样

子图采样(Subgraph Sampling)

子图采样是从原始图中提取一个子图,通常用于减少计算复杂度和增强模型的泛化能力

## 采样
# 1. 子图采样
import torch
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import k_hop_subgraph# 加载数据集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]# 选择一个起始节点
start_node = 0
num_hops = 2  # 采样半径# 提取子图
sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(start_node, num_hops, data.edge_index)# 创建子图
sub_data = Data(x=data.x[sub_nodes], edge_index=sub_edge_index, y=data.y[sub_nodes])print("Original Graph Nodes:", data.num_nodes)  # 2708
print("Subgraph Nodes:", sub_data.num_nodes)  # 8
print("Subgraph Edges:", sub_data.edge_index.shape[1])  # 20

邻域采样(Neighbor Sampling)

邻域采样通过选择节点的邻居来生成子图,适用于大规模图数据

# 2. 邻域采样
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader# 加载数据集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]# 定义 NeighborSampler
loader = NeighborLoader(data,num_neighbors=[10, 10],  # 每层采样的邻居数量batch_size=1024,shuffle=True,
)# 遍历数据加载器
for batch in loader:print(batch)break

模型训练与评估

训练过程

## 模型训练与评估
# 1. 训练过程
import torch.nn.functional as Foptimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss

测试过程

# 2. 测试过程
@torch.no_grad()
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())acc = correct / int(data.test_mask.sum())return accfor epoch in range(200):loss = train()acc = test()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

异构图处理

异构图定义

## 异构图处理
# 1. 异构图定义
from torch_geometric.data import HeteroData
import torchdata = HeteroData()
# 添加两种类型节点
data['user'].x = torch.randn(4, 16)  # 4个用户
data['movie'].x = torch.randn(5, 32)  # 5部电影
# 添加边
data['user', 'rates', 'movie'].edge_index = torch.tensor([[0, 0, 1, 2, 3], [0, 2, 3, 1, 4]]  # user->movie评分关系
)

异构图卷积

# 2. 异构图卷积
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv
from torch_geometric.transforms import NormalizeFeaturesclass HeteroGNN(torch.nn.Module):def __init__(self, in_channels, out_channels, hidden_channels):super().__init__()self.conv1 = HeteroConv({('user', 'rates', 'movie'): SAGEConv((in_channels['user'], in_channels['movie']), hidden_channels),('movie', 'rev_rates', 'user'): GCNConv(in_channels['movie'], hidden_channels, add_self_loops=False)  # 禁用自环}, aggr='sum')self.conv2 = HeteroConv({('user', 'rates', 'movie'): SAGEConv((hidden_channels, hidden_channels), out_channels),('movie', 'rev_rates', 'user'): GCNConv(hidden_channels, out_channels, add_self_loops=False)  # 禁用自环}, aggr='sum')def forward(self, x_dict, edge_index_dict):x_dict = self.conv1(x_dict, edge_index_dict)x_dict = {key: torch.relu(x) for key, x in x_dict.items()}x_dict = self.conv2(x_dict, edge_index_dict)return x_dict# 定义输入和输出通道数
in_channels = {'user': 16, 'movie': 32}
out_channels = 7  # 假设输出通道数为7
hidden_channels = 64  # 假设隐藏层通道数为64# 实例化模型
model = HeteroGNN(in_channels, out_channels, hidden_channels)
print(model)
# HeteroGNN(
#   (conv1): HeteroConv(num_relations=2)
#   (conv2): HeteroConv(num_relations=2)
# )

图生成模型

Deep Graph Infomax (DGI)

DGI 是一种无监督图表示学习方法,通过最大化局部和全局图表示之间的一致性来学习节点嵌入

## 图生成模型
# 1. Deep Graph Infomax (DGI)
from torch_geometric.nn import DeepGraphInfomax
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.nn as nn
import torch.nn.functional as Fclass Encoder(nn.Module):def __init__(self, in_channels, hidden_channels):super(Encoder, self).__init__()self.conv = GCNConv(in_channels, hidden_channels)self.prelu = nn.PReLU(hidden_channels)def forward(self, x, edge_index):x = self.conv(x, edge_index)x = self.prelu(x)return xdef corruption(x, edge_index):return x[torch.randperm(x.size(0))], edge_indexdataset = Planetoid(root='./data', name='Cora')
data = dataset[0]encoder = Encoder(dataset.num_features, hidden_channels=512)
model = DeepGraphInfomax(hidden_channels=512, encoder=encoder,summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),corruption=corruption
)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)def train():model.train()optimizer.zero_grad()pos_z, neg_z, summary = model(data.x, data.edge_index)loss = model.loss(pos_z, neg_z, summary)loss.backward()optimizer.step()return lossfor epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')

Graph Autoencoder (GAE)

GAE 是一种基于图神经网络的自编码器,用于图生成任务。它通过学习节点嵌入来重建图的邻接矩阵

# 2. Graph Autoencoder(GAE)
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GAE
import torch.nn.functional as Fclass Encoder(nn.Module):def __init__(self, in_channels, out_channels):super(Encoder, self).__init__()self.conv1 = GCNConv(in_channels, 2 * out_channels)self.conv2 = GCNConv(2 * out_channels, out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = F.relu(x)return self.conv2(x, edge_index)dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]encoder = Encoder(dataset.num_features, out_channels=16)
model = GAE(encoder)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)def train():model.train()optimizer.zero_grad()z = model.encode(data.x, data.edge_index)loss = model.recon_loss(z, data.edge_index)loss.backward()optimizer.step()return lossfor epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')

Variational Graph Autoencoder (VGAE)

VGAE 是 GAE 的变体,通过引入变分推断来学习节点嵌入的分布

# 3. Variational Graph Autoencoder(VGAE)
from torch_geometric.nn import VGAE
from torch_geometric.datasets import Planetoid# 定义数据集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]class Encoder(nn.Module):def __init__(self, in_channels, out_channels):super(Encoder, self).__init__()self.conv1 = GCNConv(in_channels, 2 * out_channels)self.conv2 = GCNConv(2 * out_channels, 2 * out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)mu = x[:, :x.size(1) // 2]logstd = x[:, x.size(1) // 2:]return mu, logstd# 定义 Encoder
encoder = Encoder(dataset.num_features, out_channels=16)# 定义 VGAE 模型
model = VGAE(encoder)# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 训练函数
def train():model.train()optimizer.zero_grad()z = model.encode(data.x, data.edge_index)loss = model.recon_loss(z, data.edge_index)kl_loss = model.kl_loss()loss += kl_lossloss.backward()optimizer.step()return loss# 训练模型
for epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')
    http://www.dtcms.com/a/279327.html

    相关文章:

  • MCP基础知识二(实战通信方式之Streamable HTTP)
  • 【CTF学习】PWN基础工具的使用(binwalk、foremost、Wireshark、WinHex)
  • ewdyfdfytty
  • LangChain教程——文本嵌入模型
  • 20250714让荣品RD-RK3588开发板在Android13下长按关机
  • Debezium日常分享系列之:提升Debezium性能
  • 制造业实战:数字化集采如何保障千种备件“不断供、不积压”?
  • 16.避免使用裸 except
  • MFC扩展库BCGControlBar Pro v36.2新版亮点:可视化设计器升级
  • 计算机毕业设计Java轩辕购物商城管理系统 基于 SpringBoot 的轩辕电商商城管理系统 Java 轩辕购物平台管理系统设计与实现
  • 面向对象的设计模式
  • 【数据结构】树(堆)·上
  • js的局部变量和全局变量
  • 测试驱动开发(TDD)实战:在 Spring 框架实现中践行 “红 - 绿 - 重构“ 循环
  • Bash vs PowerShell | 从 CMD 到跨平台工具:Bash 与 PowerShell 的全方位对比
  • vue3 服务端渲染时请求接口没有等到数据,但是客户端渲染是请求接口又可以得到数据
  • 7.14 map | 内存 | 二维dp | 二维前缀和
  • python+Request提取cookie
  • 电脑升级Experience
  • python transformers笔记(Trainer类)
  • 代码随想录算法训练营第三十五天|416. 分割等和子集
  • LLM表征工程还有哪些值得做的地方
  • 内部文件审计:企业文件服务器审计对网络安全提升有哪些帮助?
  • 防火墙技术概述
  • Qt轮廓分析设计+算法+避坑
  • Redis技术笔记-主从复制、哨兵与持久化实战指南
  • 第五章 uniapp实现兼容多端的树状族谱关系图,剩余组件
  • 学习C++、QT---25(QT中实现QCombobox库的介绍和用QCombobox设置编码和使用编码的讲解)
  • SQL ORM映射框架深度剖析:从原理到实战优化
  • 【Unity】MiniGame编辑器小游戏(十三)最强射手【Shooter】(下)