当前位置: 首页 > news >正文

TensorFlow深度学习实战——自定义图数据集

TensorFlow深度学习实战——自定义图数据集

    • 0. 前言
    • 1. 自定义图数据集
    • 2. 单图数据集
    • 3. 多图数据集
    • 相关链接

0. 前言

我们已经学习了如何构建并训练图神经网络 (Graph Neural Network, GNN) 以解决常见的图机器学习任务。实现过程中,我们选择使用预构建的 Deep Graph Library (DGL) 数据集,但针对具体问题,我们可能会需要 DGL 库中未提供的数据集。实践中,我们通常需要使用自己的数据,因此,本节将学习如何将自己的数据转换为 DGL 数据集。

1. 自定义图数据集

在现实世界中,通常需要使用自己的数据训练图神经网络 (Graph Neural Network, GNN) 模型。显然,在这种情况下,无法使用 Deep Graph Library (DGL) 提供的数据集,而必须将数据封装成自定义图数据集。自定义图数据集应继承自 DGLdgl.data.DGLDataset 类,并实现以下方法:

  • getitem(self, i) – 检索数据集中的第 i 个样本,检索到的样本包含一个 DGL 图及其标签
  • len(self) – 数据集中样本的数量
  • process(self) – 定义如何从磁盘加载和处理原始数据。

节点分类和链接预测在单个图上进行操作,而图分类则在一组图上进行操作,接下来介绍如何自定义这两种类型的图数据集。

2. 单图数据集

在本节中,我们将使用 Zachary’s Karate Club 图,该图表示在三年内空手道俱乐部成员。成员根据管理员 (Officer) 和教练 (Mr. Hi) 进行分组,管理员和教练之间产生了冲突,一半会员跟随教练,剩下一半会员跟随管理员,下图中分别标记为蓝色和红色节点,Zachary's Karate Club 图可从 NetworkX 库中加载:

import dgl
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tffrom dgl.data import DGLDatasetG = nx.karate_club_graph()nodes = [n for n in G.nodes]
label_colors = {"Mr. Hi": "red", "Officer": "blue"}
labels = [label_colors[G.nodes[n]["club"]] for n in G.nodes]nx.draw(G, node_color=labels, with_labels=True)
plt.show()

Zachary

该图包含 34 个节点,每个节点代表一个空手道俱乐部的成员,标记为 OfficerMr. Hi,具体取决于他们属于哪个组。该图包含 78 条无向、无权边,成员之间的边表示他们在俱乐部外部仍存在社交关系。为了使此数据集适合 GNN 使用,为每个节点添加一个 10 维的随机特征向量。将 Karate Club 图转换为 DGL 数据集,将其用于下游的节点分类或链路预测任务:

class KarateClubDataset(DGLDataset):def __init__(self):super().__init__(name="karate_club")def __getitem__(self, i):return self.graphdef __len__(self):return 1def process(self):G = nx.karate_club_graph()nodes = [node for node in G.nodes]edges = [edge for edge in G.edges]node_features = tf.random.uniform((len(nodes), 10), minval=0, maxval=1, dtype=tf.dtypes.float32)label2int = {"Mr. Hi": 0, "Officer": 1}node_labels = tf.convert_to_tensor([label2int[G.nodes[node]["club"]] for node in nodes])edge_features = tf.random.uniform((len(edges), 1), minval=3, maxval=10, dtype=tf.dtypes.int32)edges_src = tf.convert_to_tensor([u for u, v in edges])edges_dst = tf.convert_to_tensor([v for u, v in edges])self.graph = dgl.graph((edges_src, edges_dst), num_nodes=len(nodes))self.graph.ndata["feat"] = node_featuresself.graph.ndata["label"] = node_labelsself.graph.edata["weight"] = edge_features# assign masks indicating the split (training, validation, test)n_nodes = len(nodes)n_train = int(n_nodes * 0.6)n_val = int(n_nodes * 0.2)train_mask = tf.convert_to_tensor(np.hstack([np.ones(n_train), np.zeros(n_nodes - n_train)]), dtype=tf.bool)val_mask = tf.convert_to_tensor(np.hstack([np.zeros(n_train), np.ones(n_val), np.zeros(n_nodes - n_train - n_val)]),dtype=tf.bool)test_mask = tf.convert_to_tensor(np.hstack([np.zeros(n_train + n_val), np.ones(n_nodes - n_train - n_val)]),dtype=tf.bool)self.graph.ndata["train_mask"] = train_maskself.graph.ndata["val_mask"] = val_maskself.graph.ndata["test_mask"] = test_mask

process() 方法中。调用 NetworkX 方法获取空手道俱乐部图作为 NetworkX 图,然后将其转换为包含节点特征和标签的 DGL 图对象。空手道俱乐部图没有定义节点和边的特征,需要生成随机特征向量作为节点特征和边特征。需要注意的是,这仅用于演示目的,以说明如果图具有节点和边特征时如何进行更新,且该数据集仅包含一个图。
此外,还需要将图划分为训练、验证和测试集,用于节点分类任务。为此,我们分配掩码以指示一个节点属于划分数据集的哪一集合。将图中的节点按 60/20/20 的比例划分,并为每个划分数据集分配布尔掩码。
实例化数据集 KarateClubDataset()

dataset = KarateClubDataset()
g = dataset[0]
print(g)

输出结果如下所示。两个主要结构是 ndata_schemasedata_schemas,分别可以通过 g.ndatag.edata 访问。在 ndata_schemas 中,有指向节点特征 (feats)、节点标签 (label) 以及指示训练、验证和测试拆分的掩码 (train_maskval_masktest_mask) 的键。在 edata_schemas 下,有表示边权重的 weight 属性:

Graph(num_nodes=34,num_edges=78,ndata_schemes={'feat': Scheme(shape=(10,), dtype=tf.float32),'label': Scheme(shape=(), dtype=tf.int32),'train_mask': Scheme(shape=(), dtype=tf.bool),'val_mask': Scheme(shape=(), dtype=tf.bool),'test_mask': Scheme(shape=(), dtype=tf.bool)}edata_schemes={'weight': Scheme(shape=(1,), dtype=tf.int32)}
)

3. 多图数据集

支持图分类任务的数据集包含多个图及其相关标签,每个图一个标签。在本节中,我们将使用一个合成的分子数据集,其中的分子表示为图,模型的任务是预测分子是否有毒性(二分类预测)。
使用 NetworkX 方法 random_regular_graph() 生成具有随机节点数和节点度的合成图。对于每个图的每个节点,添加一个随机的 10 维特征向量。每个图有一个标签 (01),表示分子是否有毒性。需要注意的是,这只是对真实数据的简单模拟。实际数据中,每个图的结构和节点向量的值会对目标变量产生实际影响:

graphs = []
num_graphs = 0
for i in range(10):n = np.random.randint(3, 10)d = np.random.randint(1, 10)if ((n * d) % 2) != 0:continueif n < d:continueg = nx.random_regular_graph(d, n)graphs.append(g)num_graphs += 1if num_graphs >= 4:breakplt.figure(figsize=(10, 10))plt.subplot(2, 2, 1)
nx.draw(graphs[0])plt.subplot(2, 2, 2)
nx.draw(graphs[1])plt.subplot(2, 2, 3)
nx.draw(graphs[2])plt.subplot(2, 2, 4)
nx.draw(graphs[3])plt.plot()
plt.show()

随机构建的合成分子的样本如下:

多图数据集

将一组随机 NetworkX 图转换为 DGL 图数据集以进行图分类。生成 100 个图,并将它们以 DGL 数据集的形式存储在列表中:

from networkx.exception import NetworkXErrorclass SyntheticDataset(DGLDataset):def __init__(self):super().__init__(name="synthetic")def __getitem__(self, i):return self.graphs[i], self.labels[i]def __len__(self):return len(self.graphs)def process(self):self.graphs, self.labels = [], []num_graphs = 0while(True):d = np.random.randint(3, 10)n = np.random.randint(5, 10)if ((n * d) % 2) != 0:continueif n < d:continuetry:g = nx.random_regular_graph(d, n)except NetworkXError:continueg_edges = [edge for edge in g.edges]g_src = [u for u, v in g_edges]g_dst = [v for u, v in g_edges]g_num_nodes = len(g.nodes)label = np.random.randint(0, 2)# create graph and add to list of graphs and labelsdgl_graph = dgl.graph((g_src, g_dst), num_nodes=g_num_nodes)dgl_graph.ndata["feats"] = tf.random.uniform((g_num_nodes, 10), minval=0, maxval=1, dtype=tf.dtypes.float32)self.graphs.append(dgl_graph)self.labels.append(label)num_graphs += 1if num_graphs > 100:breakself.labels = tf.convert_to_tensor(self.labels, dtype=tf.dtypes.int64)

实例化 SyntheticDataset 类:

dataset = SyntheticDataset()
graph, label = dataset[0]    
print(graph)
print("label:", label)

输出 DGL 数据集中第一个图的相关信息。可以看到,数据集中的第一个图有 6 个节点和 15 条边,并包含一个大小为 10 的特征向量(通过 feats 键访问),标签是一个 0 维张量(即标量):

Graph(num_nodes=6, num_edges=15,ndata_schemes={'feats': Scheme(shape=(10,), dtype=tf.float32)}edata_schemes={})
label: tf.Tensor(0, shape=(), dtype=int64)

相关链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络
TensorFlow深度学习实战(9)——构建VGG模型实现图像分类
TensorFlow深度学习实战(10)——迁移学习详解
TensorFlow深度学习实战(11)——风格迁移详解
TensorFlow深度学习实战(12)——词嵌入技术详解
TensorFlow深度学习实战(13)——神经嵌入详解
TensorFlow深度学习实战(14)——循环神经网络详解
TensorFlow深度学习实战(15)——编码器-解码器架构
TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(17)——主成分分析详解
TensorFlow深度学习实战(18)——K-means 聚类详解
TensorFlow深度学习实战(19)——受限玻尔兹曼机
TensorFlow深度学习实战(20)——自组织映射详解
TensorFlow深度学习实战(21)——Transformer架构详解与实现
TensorFlow深度学习实战(22)——从零开始实现Transformer机器翻译
TensorFlow深度学习实战(23)——自编码器详解与实现
TensorFlow深度学习实战(24)——卷积自编码器详解与实现
TensorFlow深度学习实战(25)——变分自编码器详解与实现
TensorFlow深度学习实战(26)——生成对抗网络详解与实现
TensorFlow深度学习实战(27)——CycleGAN详解与实现
TensorFlow深度学习实战(28)——扩散模型(Diffusion Model)
TensorFlow深度学习实战(29)——自监督学习(Self-Supervised Learning)
TensorFlow深度学习实战(30)——强化学习(Reinforcement learning,RL)
TensorFlow深度学习实战(31)——强化学习仿真库Gymnasium
TensorFlow深度学习实战(32)——深度Q网络(Deep Q-Network,DQN)
TensorFlow深度学习实战(33)——深度确定性策略梯度
TensorFlow深度学习实战(34)——TensorFlow Probability
TensorFlow深度学习实战(35)——概率神经网络
TensorFlow深度学习实战(36)——自动机器学习(AutoML)
TensorFlow深度学习实战(37)——深度学习的数学原理
TensorFlow深度学习实战(38)——常用深度学习库
TensorFlow深度学习实战(39)——机器学习实践指南
TensorFlow深度学习实战(40)——图神经网络(GNN)

http://www.dtcms.com/a/549578.html

相关文章:

  • Flutter 3.29.0 使用RepaintBoundary或者ScreenshotController出现导出图片渲染上下颠倒问题
  • Flutter---个人信息(4)---实现修改生日日期
  • 不止于加热:管式炉在材料科学与新能源研发中的关键作用
  • 深圳网站建设方案优化深圳发布广告的平台有哪些
  • Go语言中json.RawMessage
  • Pytorch常用函数学习摘录
  • 个人什么取消网站备案铭万做的网站怎么样
  • 2025-10-30 ZYZOJ Star(斯达)模拟赛 hetao1733837的record
  • 百胜中台×OceanBase:打造品牌零售降本增效的数字核心引擎,热门服饰、美妆客户已实践
  • 深度学习调试工具链:从PyTorch Profiler到TensorBoard可视化
  • 不可变借用的规则与限制: 从只读语义到零拷贝架构的 5 000 字深潜
  • 专题三 之 【二分查找】
  • C++进阶: override和final说明符-----继承2中重写的确认官和刹车(制动器)
  • 数据科学每日总结--Day7--数据库
  • opencv 学习: 01 ubuntu20.04 下 opencv 4.12.0 源码编译
  • 满足“国六”标准的通用型故障诊断仪:Q-OBD
  • 上海专业建站公湖南网站建设设计
  • 智慧时空大数据平台:释放时空信息数据价值
  • 线程基本概念
  • MySQL MDL锁阻塞DDL 导致复制线程卡住
  • 智慧管理,赋能美容院新未来
  • Flink做checkpoint迟迟过不去的临时解决思路
  • 网站注册 优帮云wordpress首页静态化
  • [人工智能-大模型-115]:模型层 - 用通俗易懂的语言,阐述神经网络为啥需要多层
  • Actix Web 不是 Nginx:解析 Rust 应用服务器与传统 Web 服务器的本质区别
  • pdf文件上传下载记录
  • 辽阳网站设计中国建设银行的网站.
  • 2. WPF程序打包成一个单独的exe文件
  • 东软专业力考试--Java Web 开发基础
  • 8方向控制圆盘View