当前位置: 首页 > news >正文

TensorFlow深度学习实战——图分类

TensorFlow深度学习实战——图分类

    • 0. 前言
    • 1. 数据加载与处理
    • 2. 构建图分类模型
    • 3. 模型训练与评估
    • 相关链接

0. 前言

图分类 (Graph classification) 是通过聚合整个图的所有节点特征,并对其应用一个或多个图卷积层来预测整个图的某些属性。例如在药物发现中,对分子进行分类以确定其是否具有特定的治疗特性。在本节中,我们将通过在蛋白质数据集上训练图分类模型。

1. 数据加载与处理

首先,导入所需库:

import dgl.data
import tensorflow as tffrom dgl.nn import GraphConv
from sklearn.model_selection import train_test_split

使用蛋白质 (PROTEINS) 数据集。PROTEINS 数据集包含一组图,每个图都包含节点特征和一个标签。每个图表示一个蛋白质分子,图中的每个节点代表分子中的一个原子,节点特征表示原子的化学属性,标签表示蛋白质分子是否为酶:

"""Loading Data
Label: binary, whether protein is an enzyme or not.
"""
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)print("node feature dimensionality:", dataset.dim_nfeats)
print("number of graph categories:", dataset.gclasses)
print("number of graphs in dataset:", len(dataset))

执行上述代码会在本地下载蛋白质数据集,并打印数据集的相关信息。可以看到,每个节点的特征向量大小为 3,图具有 2 种类别(酶或非酶),数据集中图的数量为 1113

node feature dimensionality: 3
number of graph categories: 2
number of graphs in dataset: 1113

将数据集划分为训练集、验证集和测试集。使用训练集训练图神经网络 (Graph Neural Network, GNN),使用验证集进行验证,并将训练完成的最终模型应用于测试集:

"""Split Dataset into Train and Test"""
tv_dataset, test_dataset = train_test_split(dataset, shuffle=True, test_size=0.2)
train_dataset, val_dataset = train_test_split(tv_dataset, test_size=0.1)
print(len(train_dataset), len(val_dataset), len(test_dataset))

数据集划分为 801 个训练图、89 个验证图和 223 个测试图。由于数据集较大,需要使用小批数据训练网络,以避免超出 GPU 内存。

2. 构建图分类模型

接下来,定义用于图分类的图神经网络 (Graph Neural Network, GNN),由两个堆叠在一起的 GraphConv 层组成,将节点编码为其潜表示。由于模型目标是预测每个图的类别,需要将所有节点表示聚合成一个图的潜表示,可以通过使用 dgl.mean_nodes() 取节点表示的平均值实现:

"""Define Model"""
class GraphClassifier(tf.keras.Model):def __init__(self, in_feats, h_feats, num_classes):super(GraphClassifier, self).__init__()self.conv1 = GraphConv(in_feats, h_feats, activation=tf.nn.relu)self.conv2 = GraphConv(h_feats, num_classes)def call(self, g, in_feat):h = self.conv1(g, in_feat)h = self.conv2(g, h)g.ndata["h"] = hreturn dgl.mean_nodes(g, "h")model = GraphClassifier(dataset.dim_nfeats, 16, dataset.gclasses)
graphs, labels = zip(*[dataset[i] for i in range(16)])
batched_graphs = dgl.batch(graphs)
batched_labels = tf.convert_to_tensor(labels)
pred = model(batched_graphs, batched_graphs.ndata["attr"])
print(pred.shape)

3. 模型训练与评估

设置训练参数,并定义 do_eval() 函数:

"""Training Loop"""
HIDDEN_SIZE = 16
BATCH_SIZE = 16
LEARNING_RATE = 1e-2
NUM_EPOCHS = 20def set_gpu_if_available():device = "/cpu:0"gpus = tf.config.list_physical_devices("GPU")if len(gpus) > 0:device = gpus[0]return devicedevice = set_gpu_if_available()def do_eval(model, dataset):total_acc, total_recs = 0, 0indexes = tf.data.Dataset.from_tensor_slices(range(len(dataset)))indexes = indexes.batch(batch_size=BATCH_SIZE)for batched_indexes in indexes:graphs, labels = zip(*[dataset[i] for i in batched_indexes])batched_graphs = dgl.batch(graphs)batched_labels = tf.convert_to_tensor(labels, dtype=tf.int64)batched_graphs = batched_graphs.to(device)logits = model(batched_graphs, batched_graphs.ndata["attr"])batched_preds = tf.math.argmax(logits, axis=1)acc = tf.reduce_sum(tf.cast(batched_preds == batched_labels, dtype=tf.float32))total_acc += acc.numpy().item()total_recs += len(batched_labels)return total_acc / total_recs

最后,定义并运行训练循环以训练 GraphClassifier 模型。使用学习率为 1e-2Adam 优化器,并使用 SparseCategoricalCrossentropy 作为损失函数,训练 20epoch

with tf.device(device):model = GraphClassifier(dataset.dim_nfeats, HIDDEN_SIZE, dataset.gclasses)optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)train_indexes = tf.data.Dataset.from_tensor_slices(range(len(train_dataset)))train_indexes = train_indexes.batch(batch_size=BATCH_SIZE)for epoch in range(NUM_EPOCHS):total_loss = 0for batched_indexes in train_indexes:with tf.GradientTape() as tape:graphs, labels = zip(*[train_dataset[i] for i in batched_indexes])batched_graphs = dgl.batch(graphs)batched_labels = tf.convert_to_tensor(labels, dtype=tf.int32)batched_graphs = batched_graphs.to(device)logits = model(batched_graphs, batched_graphs.ndata["attr"])loss = loss_fcn(batched_labels, logits)grads = tape.gradient(loss, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))total_loss += loss.numpy().item()val_acc = do_eval(model, val_dataset)print("Epoch {:3d} | train_loss: {:.3f} | val_acc: {:.3f}".format(epoch, total_loss, val_acc))

输结果如下,在训练过程中,损失减少,验证准确率提高:

模型训练过程监测

最后,在测试数据集上评估训练好的模型:

test_acc = do_eval(model, test_dataset)
print("test accuracy: {:.3f}".format(test_acc))

训练好的 GraphClassifier 模型在测试数据集上的准确率如下:

test accuracy: 0.677

可以看到,模该模型可以成功地将分子识别为酶或非酶,准确率约为 70%

相关链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络
TensorFlow深度学习实战(9)——构建VGG模型实现图像分类
TensorFlow深度学习实战(10)——迁移学习详解
TensorFlow深度学习实战(11)——风格迁移详解
TensorFlow深度学习实战(12)——词嵌入技术详解
TensorFlow深度学习实战(13)——神经嵌入详解
TensorFlow深度学习实战(14)——循环神经网络详解
TensorFlow深度学习实战(15)——编码器-解码器架构
TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(17)——主成分分析详解
TensorFlow深度学习实战(18)——K-means 聚类详解
TensorFlow深度学习实战(19)——受限玻尔兹曼机
TensorFlow深度学习实战(20)——自组织映射详解
TensorFlow深度学习实战(21)——Transformer架构详解与实现
TensorFlow深度学习实战(22)——从零开始实现Transformer机器翻译
TensorFlow深度学习实战(23)——自编码器详解与实现
TensorFlow深度学习实战(24)——卷积自编码器详解与实现
TensorFlow深度学习实战(25)——变分自编码器详解与实现
TensorFlow深度学习实战(26)——生成对抗网络详解与实现
TensorFlow深度学习实战(27)——CycleGAN详解与实现
TensorFlow深度学习实战(28)——扩散模型(Diffusion Model)
TensorFlow深度学习实战(29)——自监督学习(Self-Supervised Learning)
TensorFlow深度学习实战(30)——强化学习(Reinforcement learning,RL)
TensorFlow深度学习实战(31)——强化学习仿真库Gymnasium
TensorFlow深度学习实战(32)——深度Q网络(Deep Q-Network,DQN)
TensorFlow深度学习实战(33)——深度确定性策略梯度
TensorFlow深度学习实战(34)——TensorFlow Probability
TensorFlow深度学习实战(35)——概率神经网络
TensorFlow深度学习实战(36)——自动机器学习(AutoML)
TensorFlow深度学习实战(37)——深度学习的数学原理
TensorFlow深度学习实战(38)——常用深度学习库
TensorFlow深度学习实战(39)——机器学习实践指南
TensorFlow深度学习实战(40)——图神经网络(GNN)

http://www.dtcms.com/a/473396.html

相关文章:

  • SAP MM采购信息记录维护接口分享
  • 网站搭建装修风格大全2021新款简约
  • Mysql初阶第八讲:Mysql表的内外连接
  • SpringCloud 入门 - Gateway 网关与 OpenFeign 服务调用
  • uniapp 选择城市(城市列表选择)
  • AR小白入门指南:从零开始开发增强现实应用
  • 02_k8s资源清单
  • 2025年渗透测试面试题总结-109(题目+回答)
  • uniapp配置自动导入uni生命周期等方法
  • flink的Standalone-HA模式安装
  • Flink时态表关联:实现数据“时间旅行”的终极方案
  • 做哪类英文网站赚钱wordpress 页面 列表
  • nginx + spring cloud + redis + mysql + ELFK 部署
  • 【黑马点评 - 实战篇01】Redis项目实战(Windows安装Redis6.2.6 + 发送验证码 + 短信验证码登录注册 + 拦截器链 - 登录校验)
  • 汕头市通信建设管理局网站二网站手
  • FreeRTOS小记
  • 数据结构实战:顺序表全解析 - 从零实现到性能分析
  • 【C++进阶】继承上 概念及其定义 赋值兼容转换 子类默认成员函数的详解分析
  • 华为matebook16s 2022禁用触摸板和触摸屏操作
  • GridRow 和 Column 有啥区别
  • 030159网站建设与维护中国科技成就素材
  • Echarts 5.6.0 Grid 坐标系中 Y 轴可视化的优化之路
  • Java 线程池如何知道一个线程的任务已经执行完成
  • JVM字节码与类的加载(一):类的加载过程详解
  • 强军网网站建设网站需要备案才能建设吗
  • 耄大厨——AI厨师智能体(3-工具调用)
  • (二)黑马React(导航/账单项目)
  • SA-LSTM
  • 【Java并发】深入理解synchronized
  • Docker 安装 Harbor 教程