TensorFlow深度学习实战(40)——图神经网络(GNN)
TensorFlow深度学习实战(40)——图神经网络(GNN)
- 0. 前言
- 1. 图的基本概念
- 2. 图机器学习
- 3. 图卷积
- 4. 常见图神经网络
- 4.1 图卷积网络
- 4.2 图注意力网络
- 4.3 GraphSAGE
- 4.4 图同构网络
- 4.5 常见图应用
- 5. 图神经网络展望
- 5.1 异构图
- 5.2 时空图
- 小结
- 系列链接
0. 前言
在本节中,我们将介绍图神经网络 (Graph Neural Network
, GNN
),GNN
非常适合处理图数据。许多现实生活中的问题,如社交媒体、生物化学、学术文献等,本质上都可以抽象为图数据。本节中,我们将从数学角度讲解图的基本概念,然后解释图卷积的核心思想,然后我们将介绍一些基于基本图卷积技术变体的 GNN
层。我们将介绍 GNN
的三个主要应用领域,包括节点分类、图分类和链路预测,并使用 TensorFlow
和 Deep Graph Library
(DGL
) 解决这些任务。最后,我们将分析其它类型的图数据,如异构图和动态图。
1. 图的基本概念
图 GGG 是一种数据结构,由一组顶点(也称节点) VVV 和一组连接这些顶点的边 EEE 组成:
G=(V,E)G=(V,E) G=(V,E)
图可以等价地表示为一个大小为 (n,n)(n, n)(n,n) 的邻接矩阵 AAA,其中 nnn 是顶点集合 VVV 的数量,邻接矩阵的元素 A[i,j]A[i, j]A[i,j] 表示顶点 iii 和顶点 jjj 之间的边。因此,如果顶点 iii 和顶点 jjj 之间存在边,则 A[i,j]=1A[i, j] = 1A[i,j]=1,否则为 000。对于加权图,边具有权重,可以使用邻接矩阵中的元素表示边的权重。边可以是有向的或无向的,在无向图中,边没有方向,边连接的两个节点之间的关系是对称的;在有向图中,边有方向,从一个节点指向另一个节点。例如,表示节点 xxx 和 yyy 之间联系的边是无向的,因为 xxx 是 yyy 的朋友意味着 yyy 也是 xxx 的朋友,而有向边可以用于社交网络中,xxx 关注 yyy 并不意味着 yyy 关注 xxx。对于无向图,A[i,j]=A[j,i]A[i, j] = A[j, i]A[i,j]=A[j,i]。
邻接矩阵 AAA 的 nnn 次乘积,即 AnA^nAn 表示了节点之间的 nnn 跳连接。
图与矩阵的等价性是双向的,这表示邻接矩阵可以转换回图表示。由于机器学习 (Machine Learning
, ML
) 方法,使用张量形式的输入数据,这种等价性意味着图可以高效地表示为各种机器学习算法的输入。
每个节点也可以与其特征向量关联,假设特征向量的大小为 fff,那么节点集合 AAA 可以表示为 (n,f)(n, f)(n,f)。边也可以具有特征向量。由于图与矩阵之间的等价性,图通常表示为高效的基于张量的结构。
2. 图机器学习
机器学习 (Machine Learning
, ML
) 任务的目标是学习从输入空间 xxx 到输出空间 yyy 的映射 FFF。传统的机器学习方法需要特征工程来定义合适的特征,而深度学习 (Deep Learning
, DL
) 方法则可以从训练数据中自动提取特征。深度学习利用一个具有随机权重 θ\thetaθ 的模型 MMM,将任务形式化为一个关于参数 θ\thetaθ 的优化问题:
minθL(y,F(x))\underset {\theta}{min}\mathcal L(y,F(x)) θminL(y,F(x))
并使用梯度下降法在多个迭代中更新模型权重,直到参数收敛,图神经网络 (Graph Neural Network
, GNN
) 同样遵循这一基本框架:
θ←θ−η∇θL\theta\leftarrow\theta-\eta\nabla_\theta\mathcal L θ←θ−η∇θL
ML
和 DL
通常针对特定的结构进行优化。例如,在处理表格数据时,选择全连接网络,在处理图像数据时选择卷积神经网络 (Convolutional Neural Network, CNN),在处理文本或时间序列等序列数据时选择循环神经网络 (Recurrent Neural Network, RNN)。
图是拓扑复杂、大小不确定的结构,并且不满足置换不变性(即实例之间不是相互独立的)。因此,需要特殊的工具来处理图数据。Deep Graph Library
(DGL
) 是一个跨平台的处理图数据的库,支持 MX-Net
、PyTorch
和 TensorFlow
,通过可配置的后端实现,是目前最强大且易于使用的图数据处理库之一。
可以使用 pip
命令安装 DGL
:
$ pip install dgl
为了在 Tensorflow
中使用 DGL
,还需要将环境变量 DGLBACKEND
设置为 TensorFlow
。在命令行中,可以通过以下命令实现:
$ export DGLBACKEND=tensorflow
在 notebook
中,可以使用魔法命令实现:
%env DGLBACKEND=tensorflow
如果要在 GPU
环境中使用 DGL
,则需要考虑具体的 cuda
版本,可以在DGL官网上根据实际情况复制命令行进行安装:
3. 图卷积
卷积算子能够有效地以特定方式聚合在二维平面上相邻的像素值,在计算机视觉任务中取得了成功,其一维变体在自然语言处理和音频处理领域同样取得了成功。卷积神经网络连续应用卷积和池化操作,能够学习足够多的全局特征,以成功完成训练任务。
从另一个角度来看,可以将图像(或图像的每个通道)视为一个网格状的图,其中相邻的像素以特定方式相互连接。同样,单词或音频信号序列可以视为一个线性图,其中相邻的词元相互连接。在这两种情况下,深度学习架构逐步在输入图的相邻节点上应用卷积和池化操作,直到学会执行任务。每一步卷积都包含了一个更远距离邻居的信号,例如,第一个卷积合并来自距离为 1
(直接)邻居的信号,第二个卷积合并来自距离为 2
的邻居的信号,依此类推。
下图展示了 CNN
中 3 x 3
卷积与相应的图卷积操作之间的等效关系。卷积算子将卷积核(实质上是一组九个可学习的模型参数)应用于输入,并通过加权求和将其结合起来。通过将像素邻域视为以中间像素为中心的九个节点的图,可以实现相同的效果。
在这种结构上进行图卷积 (graph convolution
) 实际上是对节点特征进行加权求和,这与 CNN
中的卷积算子相同:
CNN
和图卷积的卷积操作对应的方程如下所示。可以看到,在 CNN
中,卷积可以被视为输入像素及其每个邻居的加权线性组合,每个像素通过应用的卷积核得到自身的权重。另一方面,图卷积也是输入节点及其所有邻居的加权线性组合:
CNNhv(l+1)=σ(∑u∈N(v)Wluhu(l)+Blhv(l))Graphhv(l+1)=σ(Wl∑u∈N(v)hu(l)∣N(v)∣+Blhv(l))CNN\ \ \ \ \ h_v^{(l+1)}=\sigma (\sum_{u\in N(v)}W_l^uh_u^{(l)}+B_lh_v^{(l)}) \\ Graph\ \ \ \ \ h_v^{(l+1)}=\sigma (W_l\sum_{u\in N(v)}\frac {h_u^{(l)}}{|N(v)|}+B_lh_v^{(l)}) CNN hv(l+1)=σ(u∈N(v)∑Wluhu(l)+Blhv(l))Graph hv(l+1)=σ(Wlu∈N(v)∑∣N(v)∣hu(l)+Blhv(l))
因此,图卷积可以视作传统卷积的一种变体。接下来,我们将这些卷积组合起来构建不同类型的图卷积网络 (graph convolution network
, GCN
) 层。
4. 常见图神经网络
4.1 图卷积网络
图卷积网络 (Graph Convolution Network
, GCN
) 是 Kipf
和 Welling
提出的图卷积层,作为一种可扩展的半监督学习方法,适用于图结构数据。GCN
对节点特征向量 XXX 和邻接矩阵 AAA 进行的操作,并能够用于邻接矩阵 AAA 中的信息不在数据 XXX 中的情况,例如在引用网络中文档之间的引用链接,或在知识图谱中的关系。
GCN
结合了每个节点的特征向量与其邻居的特征向量,使用随机初始化的权重开始训练。因此,对于每个节点,计算邻居节点特征的总和:
Xi′=update(Xi,aggregate([Xj,j∈N(i)]))X_i'=update(X_i,aggregate([X_j,j\in N(i)])) Xi′=update(Xi,aggregate([Xj,j∈N(i)]))
其中,更新 (update
) 和聚合 (aggregate
) 是不同类型的求和函数。这种对节点特征的投影称为消息传递机制。一次消息传递迭代等同于对每个节点的直接邻居进行图卷积。如果希望从更远的节点获取信息,可以多次重复这一操作。
以下方程描述了 GCN
在层 l+1l+1l+1 处节点i的输出。其中,N(i)N(i)N(i) 是节点i的邻居集合(包括自身),cijc_{ij}cij 是节点度数平方根的乘积,σ\sigmaσ 是激活函数。b(l)b^{(l)}b(l) 项是可选的偏置项:
hi(l+1)=σ(b(l)+∑j∈N(i)1cijhj(l)W(l))h_i^{(l+1)}=\sigma (b^{(l)}+\sum_{j\in N(i)}\frac 1{c_{ij}}{h_j^{(l)}}W^{(l)}) hi(l+1)=σ(b(l)+j∈N(i)∑cij1hj(l)W(l))
接下来,我们将介绍图注意力网络 (Graph Attention Network
, GAT
),GAT
是 GCN
的一种变体,其中系数通过注意力机制进行学习,而并未显式定义。
4.2 图注意力网络
图注意力网络 (Graph Attention Network
, GAT
) 由 Velickovic
等人提出,与 GCN
类似,GAT
执行邻居特征的局部平均。不同之处在于,GAT
不显式指定归一化项 cijc_{ij}cij,而是通过对节点特征的自注意力机制学习。GAT
的归一化项表示为 α\alphaα,它是基于邻居节点的隐藏特征和学习到的注意力向量计算得出。GAT
的核心思想是优先考虑来自相似邻居节点的特征信号,而非来自不相似节点的特征信号。
每个邻居 jjj 在节点 iii 的邻域 N(i)N(i)N(i) 得到其注意力系数向量 αij\alpha_{ij}αij,GAT
在层 i+1i+1i+1 处节点i的输出如下,注意力 α\alphaα 是使用注意力模型和前馈网络计算的:
hi(l+1)=∑j∈N(i)αijW(l)hj(l)αijl=softmax(eijl)h_i^{(l+1)}=\sum_{j\in N(i)}\alpha_{ij}W^{(l)}h_j^{(l)}\\ \alpha_{ij}^l=softmax(e_{ij}^l) hi(l+1)=j∈N(i)∑αijW(l)hj(l)αijl=softmax(eijl)
GCN
和 GAT
架构适用于小到中等规模的网络,而 GraphSAGE
架构更适用于大型网络。
4.3 GraphSAGE
上述卷积方法要求图中的所有节点在训练过程中都是可见的,因为它们是传递式的,无法自然地推广到未见过的节点。GraphSAGE
是一个通用的归纳框架,能够为训练期间未见过的节点生成嵌入,GraphSAGE
通过从节点的局部邻域进行采样和聚合来实现。
GraphSAGE
采样邻居的子集,而不是使用全部邻居。可以通过随机游走定义节点邻域,并汇总重要性分数以确定最佳样本。聚合函数可以是 MEAN
、POOL
和 LSTM
等。MEAN
聚合简单地取邻居向量的元素均值;LSTM
聚合更具表现力,但本质上是顺序的且非对称的,可以应用于从节点邻居的随机排列中派生的无序集合;POOL
聚合中,每个邻居向量独立地通过一个全连接神经网络,并对邻居集中的聚合信息进行最大池化。
以下方程展示了如何从第 lll 层的节点 iii 和其邻居 N(i)N(i)N(i) 生成第 l+1l+1l+1 层的输出:
hN(i)(l+1)=aggregate([hjl∀j∈N(i)])hi(l+1)=σ(W⋅concat(hj(l),hN(i)(l+1)))hi(l+1)=norm(hi(l+1))h_{N(i)}^{(l+1)}=aggregate([h_j^l\forall j\in N(i)])\\ h_i^{(l+1)}=\sigma(W\cdot concat(h_j^{(l)},h_{N(i)}^{(l+1)}))\\ h_i^{(l+1)}=norm(h_i^{(l+1)}) hN(i)(l+1)=aggregate([hjl∀j∈N(i)])hi(l+1)=σ(W⋅concat(hj(l),hN(i)(l+1)))hi(l+1)=norm(hi(l+1))
我们已经了解了如何使用 GNN
处理大规模网络,接下来,我们将探讨如何通过图同构网络 (Graph Isomorphism Network
, GIN
) 来最大化 GNN
的表示能力。
4.4 图同构网络
图同构网络 (Graph Isomorphism Network
, GIN
) 作为一种具有更强表达能力的网络,能够区分在拓扑上相似但不完全相同的一对图。GCN
和 GraphSAGE
无法区分某些图结构,且 SUM
聚合比 MEAN
和 MAX
聚合在区分图结构方面表现更好。因此,GIN
网路使用了一种比 GCN
和 GraphSAGE
更好的邻居聚合表示方式。
以下方程展示了节点 iii 在层 l+1l+1l+1 的输出。其中,函数 fθf_{\theta}fθ 是一个激活函数,aggregate
是一个聚合函数,如 SUM
、MAX
或 MEAN
,而 ϵ\epsilonϵ 是一个在训练过程中的可学习参数:
hi(l+1)=fθ((1+ϵ)hil+aggregate(hjl,j∈N(i)))h_i^{(l+1)}=f_{\theta}((1+\epsilon)h_i^l+aggregate(h_j^l,j\in N(i))) hi(l+1)=fθ((1+ϵ)hil+aggregate(hjl,j∈N(i)))
4.5 常见图应用
GNN
的常见应用,包括节点分类、图分类和链接预测。在相应博客中,使用 TensorFlow
和 DGL
构建并训练 GNN
以完成节点分类、图分类和链接预测任务。
5. 图神经网络展望
我们已讨论了在各种图任务中处理静态同构图,这涵盖了现实世界的多数应用场景。但实际上,某些图可能既不是同构的,也不是静态的,且可能无法轻易简化为静态同构图。在本节中,我们将探讨如何处理异构图和时间序列图。
5.1 异构图
异构图包含不同类型的节点和边。这些不同类型的节点和边可能也包含不同类型的属性,包括可能具有不同维度的表示。例如,包含作者和论文的引用图、包含用户和产品的推荐图以及可以包含多种不同实体的知识图谱。
可以通过为每种边类型手动实现消息传递和更新函数来在异构图上使用消息传递图神经网络 (Message Passing Neural Network
, MPNN
) 框,每种边类型由三元组(源节点类型、边类型和目标节点类型)定义。DGL
使用 dgl.heterograph() API
支持异构图。
与异构图相关的学习任务与同质图类似,包括节点分类和回归、图分类以及链路预测,关系图卷积网络 (Relational GCN
, R-GCN
) 是用于处理异构图的流行架构。
5.2 时空图
时空图用于处理随时间变化的动态图。尽管 GNN
模型主要集中于不随时间变化的静态图,但引入时间维度使我们能够建模社交网络、金融交易和推荐系统中的动态性,这些现象本质上是动态的。在这类系统中,动态性是其重要特性。
动态图可以表示为一系列时间事件,例如节点和边的添加和删除。这些事件流输入到一个编码器网络中,网络学习图中每个节点的时间依赖编码。解码器在编码上进行训练,以支持下游的特定任务,如链接预测。
从高层次来看,时空图网络 (Temporal Graph Network
, TGN
) 编码器基于节点间的交互和随时间的更新创建节点的压缩表示。每个节点的当前状态存储在 TGN
内存中,作为 RNN
的隐藏状态 hth_tht;每个节点 iii 和时间步 ttt 有一个单独的状态向量 ht(t)h_t(t)ht(t)。
类似于在 MPNN
框架中的消息传递功能,计算两个节点 iii 和 jjj 的消息 mim_imi 和 mjm_jmj,使用状态向量及其交互作为输入。消息和状态向量通过一个内存更新器进行结合,通常实现为 RNN
。研究发现,时空图网络在未来边预测和动态节点分类任务上,准确性和速度均优于其对应的静态图神经网络。
小结
在本节中,我们介绍了图神经网络 (Graph Neural Network
, GNN
),图神经网络不仅能够从节点特征中学习,还能从节点之间的交互中学习。我们介绍了图卷积网络的工作原理以及它们与计算机视觉中的卷积之间的相似性。我们介绍了一些常见的图卷积网络,使用 DGL
解决常见的图任务,如节点分类、图分类和链路预测。最后,我们探讨了图神经网络的一些新兴方向,即异构图和时空图。
系列链接
TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络
TensorFlow深度学习实战(9)——构建VGG模型实现图像分类
TensorFlow深度学习实战(10)——迁移学习详解
TensorFlow深度学习实战(11)——风格迁移详解
TensorFlow深度学习实战(12)——词嵌入技术详解
TensorFlow深度学习实战(13)——神经嵌入详解
TensorFlow深度学习实战(14)——循环神经网络详解
TensorFlow深度学习实战(15)——编码器-解码器架构
TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(17)——主成分分析详解
TensorFlow深度学习实战(18)——K-means 聚类详解
TensorFlow深度学习实战(19)——受限玻尔兹曼机
TensorFlow深度学习实战(20)——自组织映射详解
TensorFlow深度学习实战(21)——Transformer架构详解与实现
TensorFlow深度学习实战(22)——从零开始实现Transformer机器翻译
TensorFlow深度学习实战(23)——自编码器详解与实现
TensorFlow深度学习实战(24)——卷积自编码器详解与实现
TensorFlow深度学习实战(25)——变分自编码器详解与实现
TensorFlow深度学习实战(26)——生成对抗网络详解与实现
TensorFlow深度学习实战(27)——CycleGAN详解与实现
TensorFlow深度学习实战(28)——扩散模型(Diffusion Model)
TensorFlow深度学习实战(29)——自监督学习(Self-Supervised Learning)
TensorFlow深度学习实战(30)——强化学习(Reinforcement learning,RL)
TensorFlow深度学习实战(31)——强化学习仿真库Gymnasium
TensorFlow深度学习实战(32)——深度Q网络(Deep Q-Network,DQN)
TensorFlow深度学习实战(33)——深度确定性策略梯度
TensorFlow深度学习实战(34)——TensorFlow Probability
TensorFlow深度学习实战(35)——概率神经网络
TensorFlow深度学习实战(36)——自动机器学习(AutoML)
TensorFlow深度学习实战(37)——深度学习的数学原理
TensorFlow深度学习实战(38)——常用深度学习库
TensorFlow深度学习实战(39)——机器学习实践指南