当前位置: 首页 > news >正文

TensorFlow深度学习实战——自定义图神经网络层

TensorFlow深度学习实战——自定义图神经网络层

    • 0. 前言
    • 1. 自定义层和消息传递
    • 2. 使用 MPNN 实现自定义 GraphSAGE 层
    • 相关链接

0. 前言

我们已经学习了如何构建并训练图神经网络 (Graph Neural Network, GNN) 以解决常见的图机器学习任务。实现过程中,我们选择使用预构建的 Deep Graph Library (DGL) 图卷积层,但针对具体问题,我们可能会需要 DGL 库中未提供的网络层。DGL 提供了一个消息传递 API,用于构建自定义图神经网络层。在本节中,我们介绍如何使用消息传递 API 构建自定义图卷积层。

1. 自定义层和消息传递

尽管 Deep Graph Library (DGL) 提供了许多预构建的图神经网络层,但有时这些网络层可能无法完全满足需求,因此需要构建自定义层。
研究表明,所有图神经网络层都基于图中节点之间消息传递的共同基本概念。因此,为了构建自定义图神经网络 (Graph Neural Network, GNN) 层,我们需要了解消息传递的工作原理,也称为消息传递神经网络 (Message Passing Neural Network, MPNN) 框架:
mu→v(l)=M(l)(hv(l−1),hu(l−1),eu→v(l−1))mv(l)=∑u∈N(v)mu→v(l)hv(l)=U(l)(hv(l−1),mv(l))m_{u\rightarrow v}^{(l)}=M^{(l)}(h_v^{(l-1)},h_u^{(l-1)},e_{u\rightarrow v}^{(l-1)})\\ m_v^{(l)}=\sum_{u\in N(v)}m_{u\rightarrow v}^{(l)}\\ h_v^{(l)}=U^{(l)}(h_v^{(l-1)},m_v^{(l)}) muv(l)=M(l)(hv(l1),hu(l1),euv(l1))mv(l)=uN(v)muv(l)hv(l)=U(l)(hv(l1),mv(l))
每个图中的节点 uuu 都有一个隐藏状态(初始时是其特征向量),用 huh_uhu 表示。对于每对相邻节点 uuuvvv,即由边 eu→ve_{u\rightarrow v}euv 连接的节点,应用一个消息函数 MMM,消息函数 MMM 应用于图中的每个节点。然后,将 MMM 的输出和所有邻居节点的输出聚合以生成消息 mmm 。其中 ∑\sum 被称为聚合函数 (reduce function)。需要注意的是,尽管这里用求和符号 ∑\sum 表示聚合函数,但它可以是任意聚合函数。最后,使用获得的消息和节点的先前状态更新节点 vvv 的隐藏状态,此步骤中应用的函数 UUU 称为更新函数。
消息传递算法重复指定次数。之后,进入读取阶段,从每个节点提取特征向量,以表示整个图。例如,在节点分类中,节点的最终特征向量可能表示节点的类别。
在本节中,我们将使用 MPNN 框架实现 GraphSAGE 层。虽然 DGL 提供了 dgl.nn.SAGEConv,但本节主要目的是展示如何使用 MPNN 创建自定义图神经网络层。GraphSAGE 层的消息传递步骤如下所示:
hN(v)k←AVG(huk−1,∀u∈N(v))hvk←RELU(Wk⋅CONCAT(hvk−1,hN(v)k))h_{N(v)}^k\leftarrow AVG(h_u^{k-1},\forall u \in N(v))\\ h_v^k\leftarrow RELU(W^k\cdot CONCAT(h_v^{k-1},h_{N(v)}^k)) hN(v)kAVG(huk1,uN(v))hvkRELU(WkCONCAT(hvk1,hN(v)k))

2. 使用 MPNN 实现自定义 GraphSAGE 层

DGL 函数 update_all 通过传递 message_fnreduce_fn 参数指定聚合函数,而 tf.concatDense 层则表示最终的更新函数:

import dgl
import dgl.data
import dgl.function as fn
import tensorflow as tf"""Message passing and GNNs"""
class CustomGraphSAGE(tf.keras.layers.Layer):"""Graph convolution module used by the GraphSAGE model.Parameters----------in_feat : intInput feature size.out_feat : intOutput feature size."""def __init__(self, in_feat, out_feat):super(CustomGraphSAGE, self).__init__()# A linear submodule for projecting the input and neighbor feature to the output.self.linear = tf.keras.layers.Dense(out_feat, activation=tf.nn.relu)def call(self, g, h):"""Forward computationParameters----------g : GraphThe input graph.h : TensorThe input node feature."""with g.local_scope():g.ndata["h"] = h# update_all is a message passing API.g.update_all(message_func=fn.copy_u('h', 'm'),reduce_func=fn.mean('m', 'h_N'))h_N = g.ndata['h_N']h_total = tf.concat([h, h_N], axis=1)return self.linear(h_total)

其中,update_all 函数的 message_func 参数用于将节点的当前特征向量复制到消息向量 m 中,然后平均每个节点邻域中的所有消息向量,从而实现了 GraphSAGE 方程。
一旦计算出邻域向量 h_N,它会与输入特征向量 h 连接,然后通过一个带 ReLU 激活的 Dense 层,这与 GraphSAGE 方程描述一致。至此,我们已经使用 CustomGraphSAGE 对象实现了 GraphSAGE 层。
接下来,将其放入一个图神经网络中,以查看效果。定义 CustomGNN 模型类,包含了两个自定义的 SAGEConv 网络层:

class CustomGNN(tf.keras.Model):def __init__(self, g, in_feats, h_feats, num_classes):super(CustomGNN, self).__init__()self.g = gself.conv1 = CustomGraphSAGE(in_feats, h_feats)self.relu1 = tf.keras.layers.Activation(tf.nn.relu)self.conv2 = CustomGraphSAGE(h_feats, num_classes)def call(self, in_feat):h = self.conv1(self.g, in_feat)h = self.relu1(h)h = self.conv2(self.g, h)return h

使用 CustomGNN 模型对 CORA 数据集进行节点分类:

"""Training Loop"""
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]LEARNING_RATE = 1e-2
NUM_EPOCHS = 200def evaluate(model, features, labels, mask, edge_weights=None):if edge_weights is None:logits = model(features, training=False)else:logits = model(features, edge_weights, training=False)logits = logits[mask]labels = labels[mask]indices = tf.math.argmax(logits, axis=1)acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))return acc.numpy().item()def train(g, model, optimizer, loss_fcn, num_epochs, use_edge_weights=False):features = g.ndata["feat"]labels = g.ndata["label"]if use_edge_weights:edge_weights = g.edata["w"]train_mask = g.ndata["train_mask"]val_mask = g.ndata["val_mask"]test_mask = g.ndata["test_mask"]for epoch in range(num_epochs):with tf.GradientTape() as tape:if not use_edge_weights:logits = model(features)else:logits = model(features, edge_weights)loss_value = loss_fcn(labels[train_mask], logits[train_mask])grads = tape.gradient(loss_value, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))if not use_edge_weights:acc = evaluate(model, features, labels, val_mask)else:acc = evaluate(model, features, labels, val_mask, edge_weights=edge_weights)if epoch % 10 == 0:print("Epoch {:5d} | loss: {:.3f} | val_acc: {:.3f}".format(epoch, loss_value.numpy().item(), acc))if not use_edge_weights:acc = evaluate(model, features, labels, test_mask)else:acc = evaluate(model, features, labels, test_mask, edge_weights=edge_weights)print("Test accuracy: {:.3f}".format(acc))model = CustomGNN(g, g.ndata['feat'].shape[1], 16, dataset.num_classes)
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)train(g, model, optimizer, loss_fcn, NUM_EPOCHS)

以上代码假设图是一个无权重图,即节点之间的边具有相同的权重。这种情况在 CORA 数据集中是成立的,其中每条边代表一篇论文对另一篇论文的引用。而在其它场景中,边的权重可能基于边被调用的次数,例如,在社交网络中,节点可以表示用户,边可以表示用户之间的关系,边的权重可以表示用户互动频率。
为了处理加权边,需要对消息函数进行调整,以便边的权重发挥作用。也就是说,如果一个边在节点 uuu 和邻居节点 vvv 之间出现了 kkk 次,应该将这条边考虑 kkk 次。实现自定义 GraphSAGE 层处理加权边:

"""More customization"""
class CustomWeightedGraphSAGE(tf.keras.layers.Layer):"""Graph convolution module used by the GraphSAGE model with edge weights.Parameters----------in_feat : intInput feature size.out_feat : intOutput feature size."""def __init__(self, in_feat, out_feat):super(CustomWeightedGraphSAGE, self).__init__()# A linear submodule for projecting the input and neighbor feature to the output.self.linear = tf.keras.layers.Dense(out_feat, activation=tf.nn.relu)def call(self, g, h, w):"""Forward computationParameters----------g : GraphThe input graph.h : TensorThe input node feature.w : TensorThe edge weight."""with g.local_scope():g.ndata['h'] = hg.edata['w'] = wg.update_all(message_func=fn.u_mul_e('h', 'w', 'm'),reduce_func=fn.mean('m', 'h_N'))h_N = g.ndata['h_N']h_total = tf.concat([h, h_N], axis=1)return self.linear(h_total)class CustomWeightedGNN(tf.keras.Model):def __init__(self, g, in_feats, h_feats, num_classes):super(CustomWeightedGNN, self).__init__()self.g = gself.conv1 = CustomWeightedGraphSAGE(in_feats, h_feats)self.relu1 = tf.keras.layers.Activation(tf.nn.relu)self.conv2 = CustomWeightedGraphSAGE(h_feats, num_classes)def call(self, in_feat, edge_weights):h = self.conv1(self.g, in_feat, edge_weights)h = self.relu1(h)h = self.conv2(self.g, h, edge_weights)return h

CustomWeightedGraphSAGE 层接收一个额外的边属性 w,该属性包含边权重,可以通过以下方法在 CORA 数据集中模拟:

g.edata["w"] = tf.cast(tf.random.uniform((g.num_edges(), 1), minval=3, maxval=10, dtype=tf.int32),dtype=tf.float32)
g.edata["w"]model = CustomWeightedGNN(g, g.ndata['feat'].shape[1], 16, dataset.num_classes)
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)train(g, model, optimizer, loss_fcn, NUM_EPOCHS, use_edge_weights=True)"""Even more customization by User Defined Function
Not tensorflow related, but useful to mention.
"""
def u_mul_e_udf(edges):return {'m' : edges.src['h'] * edges.data['w']}def mean_udf(nodes):return {'h_N': nodes.mailbox['m'].mean(1)}

CustomWeightedGraphSAGE 中,message_func 从简单地将特征向量 h 复制到消息向量 m,变为将 hw 相乘以生成消息向量 m。其他部分与 CustomGraphSAGE 相同。新的 CustomWeightedGraphSAGE 层可以直接替换 CustomGraphSAGE 进行调用。

相关链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络
TensorFlow深度学习实战(9)——构建VGG模型实现图像分类
TensorFlow深度学习实战(10)——迁移学习详解
TensorFlow深度学习实战(11)——风格迁移详解
TensorFlow深度学习实战(12)——词嵌入技术详解
TensorFlow深度学习实战(13)——神经嵌入详解
TensorFlow深度学习实战(14)——循环神经网络详解
TensorFlow深度学习实战(15)——编码器-解码器架构
TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(17)——主成分分析详解
TensorFlow深度学习实战(18)——K-means 聚类详解
TensorFlow深度学习实战(19)——受限玻尔兹曼机
TensorFlow深度学习实战(20)——自组织映射详解
TensorFlow深度学习实战(21)——Transformer架构详解与实现
TensorFlow深度学习实战(22)——从零开始实现Transformer机器翻译
TensorFlow深度学习实战(23)——自编码器详解与实现
TensorFlow深度学习实战(24)——卷积自编码器详解与实现
TensorFlow深度学习实战(25)——变分自编码器详解与实现
TensorFlow深度学习实战(26)——生成对抗网络详解与实现
TensorFlow深度学习实战(27)——CycleGAN详解与实现
TensorFlow深度学习实战(28)——扩散模型(Diffusion Model)
TensorFlow深度学习实战(29)——自监督学习(Self-Supervised Learning)
TensorFlow深度学习实战(30)——强化学习(Reinforcement learning,RL)
TensorFlow深度学习实战(31)——强化学习仿真库Gymnasium
TensorFlow深度学习实战(32)——深度Q网络(Deep Q-Network,DQN)
TensorFlow深度学习实战(33)——深度确定性策略梯度
TensorFlow深度学习实战(34)——TensorFlow Probability
TensorFlow深度学习实战(35)——概率神经网络
TensorFlow深度学习实战(36)——自动机器学习(AutoML)
TensorFlow深度学习实战(37)——深度学习的数学原理
TensorFlow深度学习实战(38)——常用深度学习库
TensorFlow深度学习实战(39)——机器学习实践指南
TensorFlow深度学习实战(40)——图神经网络(GNN)

http://www.dtcms.com/a/560584.html

相关文章:

  • 车陂手机网站开发学校网站群建设必要
  • 【Elasticsearch入门到落地】18、Elasticsearch实战:Java API详解高亮、排序与分页
  • Java Web学习 第1篇前端基石HTML 入门与核心概念解析
  • Kafka4.1.0 队列模式尝鲜
  • transformer记录一(输入步骤讲解)
  • 做生存分析的网站有哪些网站背景怎么弄
  • Tomcat 新手避坑指南:环境配置 + 启动问题 + 乱码解决全流程
  • 整理、分类、总结与介绍Vue前端开发日常常用的第三方库/框架/插件-收藏
  • 第九天~在Arxml中定义一对XCP-PDU用于测量标定
  • Tomcat 配置问题速查表
  • 第九天~AUTOSAR网络管理NM-PDU详解:在Arxml中定义唤醒节点的NM-PDU
  • 在centos 7上配置FIP服务器的详细教程!!!
  • 做网站三网多少钱wordpress 贴吧主题
  • 无锡网站建设营销型诸城公司做网站
  • 【Docker】容器网络探索(二):实战理解 host 网络
  • 《数据结构风云》:二叉树遍历的底层思维>递归与迭代的双重视角
  • Java EE初阶 --多线程2
  • 论文精读(七):结合大语言模型和领域知识库的证券规则规约方法
  • Linux shell sed 命令基础
  • 选 Redis Stream 还是传统 MQ?队列选型全攻略(适用场景、优缺点与实践建议)
  • 【JVM】详解 Java内存模型(JMM)
  • 做网站工作室广告网站建设
  • 小语种网站制作广州网站建设哪里有
  • 广州学做网站上饶网站建设多少钱
  • GO写的http服务,清空cookie
  • 响应式企业网站模板望京网站建设公司
  • 最新聊天记录做图网站ip软件点击百度竞价推广
  • 关于学校网站建设申请报告深圳市网络seo推广价格
  • 公司网站后台怎么上传图片百度西安分公司地址
  • Go语言设计模式:组合模式详解