当前位置: 首页 > news >正文

TensorFlow深度学习实战——节点分类

TensorFlow深度学习实战——节点分类

    • 0. 前言
    • 1. 数据分析
    • 2. 构建节点分类模型
    • 3. 模型训练与评估
    • 相关链接

0. 前言

节点分类是图数据领域的一个常见任务。在这一任务中,模型的训练目标是预测节点的类别。非图分类方法仅使用节点特征向量实现节点分类,早期的图神经网络 (Graph Neural Network, GNN) 方法(如 DeepWalknode2vec )仅使用邻接矩阵(连接信息)实现节点分类,而 GNN 能够同时利用节点特征向量和连接信息进行节点分类。

1. 数据分析

本质上,节点分类的思路是对图中的所有节点应用一个或多个图卷积,将节点的特征向量投影到相应的输出类别向量中,以预测节点的类别。本节,将使用 CORA 数据集训练节点分类模型,CORA 数据集是一个包含 2,708 篇科学论文的集合,每篇论文可以分类为七个类别之一。这些论文以及它们之间的引用关系构成了一个包含 5,429 个链接的引文网络,每篇论文由一个大小为 1,433 的词向量描述。

(1) 首先,导入所需库:

import dgl
import dgl.data
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from dgl.nn.tensorflow import GraphConv

(2) 加载 CORA 数据集:

dataset = dgl.data.CoraGraphDataset()

(3) 第一次调用时,它会记录下载和提取到本地文件的过程。完成后,它会输出一些有关 CORA 数据集的统计信息。可以看到,图中有 2,708 个节点和 10,566 条边。每个节点都有一个大小为 1,433 的特征向量,节点被分类为七个类别之一,此外,有 140 个训练样本、500 个验证样本和 1,000 个测试样本:

  NumNodes: 2708NumEdges: 10556NumFeats: 1433NumClasses: 7NumTrainingSamples: 140NumValidationSamples: 500NumTestSamples: 1000
Done saving data into cached files.

CORA 数据集是一个单一的引文图,可以通过 len(dataset) 来验证,将返回 1。这意味着模型将处理 dataset[0] 提供的图,节点特征作为键值对包含在字典 dataset[0].ndata 中,边特征则在 dataset[0].edata 中。ndata 包含键 train_maskval_masktest_mask,这些是布尔掩码,表示哪些节点属于训练、验证和测试集,还有一个 feat 键,包含图中每个节点的特征向量。

2. 构建节点分类模型

构建一个包含两个 GraphConv 层的 NodeClassifier 网络。每一层将通过聚合邻居信息计算新的节点表示。GraphConv 层是 tf.keras.layers.Layer 对象,因此可以进行堆叠。第一个 GraphConv 层将输入特征(大小为 1,433 )投影到大小为 16 的隐藏特征向量上,第二个 GraphConv 层将隐藏特征向量投影到大小为 2 的输出类别向量,从中获取类别:

"""Defining a Graph Convolutional Network (GCN)"""
class NodeClassifier(tf.keras.Model):def __init__(self, g, in_feats, h_feats, num_classes):super(NodeClassifier, self).__init__()self.g = gself.conv1 = GraphConv(in_feats, h_feats, activation=tf.nn.relu)self.conv2 = GraphConv(h_feats, num_classes)def call(self, in_feat):h = self.conv1(self.g, in_feat)h = self.conv2(self.g, h)return hg = dataset[0]
model = NodeClassifier(g, g.ndata["feat"].shape[1], 16, dataset.num_classes)

需要注意的是,GraphConv 只是构建 NodeClassifier 模型的一种图神经网络层,DGL 提供了多种图卷积层,可以用来替换 GraphConv

3. 模型训练与评估

(1)CORA 数据集上训练模型。使用 AdamW 优化器,AdamW 优化器是 Adam 优化器的变体,能够得到更好的模型泛化能力,学习率为 1e-2,权重衰减为 5e-4,训练 200epoch。同时检测是否有可用的 GPU,如果有,将图数据转移到 GPU 上。如果检测到 GPUTensorFlow 会自动将模型转移到 GPU 上:

"""Training the GCN"""
device = "/cpu:0"
gpus = tf.config.list_physical_devices("GPU")
if len(gpus) > 0:device = gpus[0]
g = g.to(device)

(2) 定义 do_eval() 方法,根据特征计算模型在(由布尔掩码拆分的)测试数据集上的准确率:

def do_eval(model, features, labels, mask):logits = model(features, training=False)logits = logits[mask]labels = labels[mask]preds = tf.math.argmax(logits, axis=1)acc = tf.reduce_mean(tf.cast(preds == labels, dtype=tf.float32))return acc.numpy().item()

(3) 最后,定义训练循环:

NUM_HIDDEN = 16
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 5e-4
NUM_EPOCHS = 200with tf.device(device):feats = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]val_mask = g.ndata["val_mask"]test_mask = g.ndata["test_mask"]in_feats = feats.shape[1]n_classes = dataset.num_classesn_edges = dataset[0].number_of_edges()model = NodeClassifier(g, in_feats, NUM_HIDDEN, n_classes)loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)optimizer = tf.keras.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)best_val_acc, best_test_acc = 0, 0history = []for epoch in range(NUM_EPOCHS):with tf.GradientTape() as tape:logits = model(feats)loss = loss_fcn(labels[train_mask], logits[train_mask])grads = tape.gradient(loss, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))val_acc = do_eval(model, feats, labels, val_mask)history.append((epoch + 1, loss.numpy().item(), val_acc))if epoch % 10 == 0:print("Epoch {:3d} | train loss: {:.3f} | val acc: {:.3f}".format(epoch, loss.numpy().item(), val_acc))epochs = [epoch for epoch, _, _ in history]
losses = [loss for _, loss, _ in history]
val_accs = [val_acc for _, _, val_acc in history]plt.subplot(2, 1, 1)
plt.plot(epochs, losses)
plt.xlabel("epochs")
plt.ylabel("train loss")plt.subplot(2, 1, 2)
plt.plot(epochs, val_accs)
plt.xlabel("epochs")
plt.ylabel("val acc")plt.tight_layout()
plt.show()

运行代码,训练运行过程输出如下,可以看到训练损失从 1.9 降低到 0.02,验证准确率从 0.13 提高到 0.78

训练过程

(4) 评估训练好的节点分类器在测试数据集上的表现:

test_acc = do_eval(model, feats, labels, test_mask)
print("Test acc: {:.3f}".format(test_acc))

打印出模型在测试数据集上的准确率如下:

Test acc: 0.779

相关链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络
TensorFlow深度学习实战(9)——构建VGG模型实现图像分类
TensorFlow深度学习实战(10)——迁移学习详解
TensorFlow深度学习实战(11)——风格迁移详解
TensorFlow深度学习实战(12)——词嵌入技术详解
TensorFlow深度学习实战(13)——神经嵌入详解
TensorFlow深度学习实战(14)——循环神经网络详解
TensorFlow深度学习实战(15)——编码器-解码器架构
TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(17)——主成分分析详解
TensorFlow深度学习实战(18)——K-means 聚类详解
TensorFlow深度学习实战(19)——受限玻尔兹曼机
TensorFlow深度学习实战(20)——自组织映射详解
TensorFlow深度学习实战(21)——Transformer架构详解与实现
TensorFlow深度学习实战(22)——从零开始实现Transformer机器翻译
TensorFlow深度学习实战(23)——自编码器详解与实现
TensorFlow深度学习实战(24)——卷积自编码器详解与实现
TensorFlow深度学习实战(25)——变分自编码器详解与实现
TensorFlow深度学习实战(26)——生成对抗网络详解与实现
TensorFlow深度学习实战(27)——CycleGAN详解与实现
TensorFlow深度学习实战(28)——扩散模型(Diffusion Model)
TensorFlow深度学习实战(29)——自监督学习(Self-Supervised Learning)
TensorFlow深度学习实战(30)——强化学习(Reinforcement learning,RL)
TensorFlow深度学习实战(31)——强化学习仿真库Gymnasium
TensorFlow深度学习实战(32)——深度Q网络(Deep Q-Network,DQN)
TensorFlow深度学习实战(33)——深度确定性策略梯度
TensorFlow深度学习实战(34)——TensorFlow Probability
TensorFlow深度学习实战(35)——概率神经网络
TensorFlow深度学习实战(36)——自动机器学习(AutoML)
TensorFlow深度学习实战(37)——深度学习的数学原理
TensorFlow深度学习实战(38)——常用深度学习库
TensorFlow深度学习实战(39)——机器学习实践指南
TensorFlow深度学习实战(40)——图神经网络(GNN)

http://www.dtcms.com/a/495201.html

相关文章:

  • scipy的统计学库(4):用rv_histogram类实现随机抽样
  • Element Plus el-table 默认勾选行的方法
  • Linux系统函数opendir、closedir、readdir详解及案例(自定义ls工具)
  • 便捷网站建设哪家便宜网站建没有前景
  • 接口测试 | Postman的高级用法的测试使用
  • TR3--Transformer之pytorch复现
  • Traccar本地文件包含漏洞(CVE-2025-61666)
  • 建站网站推荐icp域名备案查询系统
  • 智能美颜引擎:美颜SDK如何实现自适应芯片性能优化
  • Java中的boolean与Boolean
  • Flutter高级进阶教程(视频教程)
  • Rocketmq 分布式事务 两阶段提交
  • 骑行,团骑和独骑冲突吗?
  • 对网站和网页的认识鞍山信息网便民信息
  • 《算法通关指南---C++编程篇(2)》
  • 【论文速递】2025年第29周(Jul-13-19)(Robotics/Embodied AI/LLM)
  • 网站 模板更改网站备案
  • VR反诈一体机-VR预防诈骗模拟系统-VR防诈骗体验馆方案
  • 大型网站seo课程沈阳关键词优化费用
  • Kubernetes PVC 扩容完全指南:静态迁移 vs 动态扩容
  • 【题解】B2613【深基1.习5】打字速度
  • Elastic DevRel 通讯 — 2025 年 10 月
  • Java面试基础题
  • 博客标题:快速解决 VS Code 终端运行 petalinux-config 界面显示错乱问题
  • 强化学习【Monte Carlo Learning][MC Basic 算法]
  • 杭州网站开发制作公司小程序源码出售
  • 从0到1学习Qt -- 创建项目
  • dw做网站基础wap网站开发价格
  • 【实时Linux实战系列】实时应用的多版本共存与无缝升级
  • Linux小课堂: 文件操作核心命令深度解析(cp、mv 与 rm 命令)