当前位置: 首页 > news >正文

图神经网络(篇一)-GraphSage

知识框架

简介

GraphSage(Graph Sample and aggregate)是2017年斯坦福大学提出的一种基于图的inductive(归纳)学习方法。利用节点特征信息和结构信息,从顶点的局部邻居采样并聚合邻居节点和顶点的特征,获取到顶点的Graph Embedding。相比已有的结果,可以为未见过的顶点生成embedding。除此之外,对于节点分类和链接预测问题的表现也比较突出。

知识先导

transductive learning vs inductive learning

  1. transductive learning:直推学习。从特殊到特殊,仅考虑当前数据。在图中学习目标是学习目标是直接生成当前节点的embedding,例如DeepWalk、LINE,把每个节点embedding作为参数,并通过SGD优化,又如GCN,在训练过程中使用图的拉普拉斯矩阵进行计算

  2. inductive learning:归纳学习。平时所说的一般的机器学习任务,从特殊到一般,目标是在未知数据上也有区分性。

对比GCN

模型

基本思想

优势

不足

GCN

把一个节点在图中的高维邻接信息降维到一个低维的向量表示

可以捕捉到graph的全局信息, 从而可以更好地表示node

是Transductive Learning(直推学习),训练的时候会用到验证集、测试集的信息必须要把全部的节点参与训练才能获得node embedding,对于新节点也需要重新训练,才可以产生节点embedding,从而影响了产出的速率

GraphSage

利用节点特征信息和结构信息,从顶点的局部邻居采样并聚合邻居节点和顶点的特征,获取到顶点的Graph Embedding

可以利用已有的节点信息产生新节点的embedding,同时可以保证生产的高效性可以应用在大规模的图学习上

聚合计算的时候没有考虑到邻居的重要性(GAT模型进行了相关的完善)只涉及了无向图,对于有向图需要进一步研究

已有工作

可以分为三方面的工作,如下所示:

Factorization-based embedding approaches

即基于随机游走和矩阵分解学习节点的embedding。代表工作:

  • transductive

  1. Grarep: Learning graph representations with global structural information. In KDD, 2015

  2. node2vec: Scalable feature learning for networks. In KDD, 2016

  3. Deepwalk: Online learning of social representations. In KDD, 2014

  4. Line: Large-scale information network embedding. In WWW, 2015

  5. Structural deep network embedding. In KDD, 2016

  • inductive

  1. Yang et al.的Planetoid-I算法

  • 不足

分类

不足

transductive

直接训练单个节点的节点embedding,本质上是transductive,对于新的节点,需要大量的额外训练(如随机梯度降低)

inductive

是半监督学习; Planetoid-I在推断的时候不使用任何图结构信息,而在训练的时候将图结构做为一种正则化的形式

图结构监督学习方法

如Graph kernel。代表工作:

  1. Discriminative embeddings of latent variable models for structured data. In ICML, 2016

  2. A new model for learning in graph domains

  3. Gated graph sequence neural networks. In ICLR, 2015

  4. The graph neural network model

  • 不足

上述工作的目的:对整个图(或子图)进行分类的。但是,本文的工做的重点是为单个节点生成有用的表征,所实现的目的不同

GCN的相关工作

  1. Spectral networks and locally connected networks on graphs. In ICLR, 2014

  2. Convolutional neural networks on graphs with fast localized spectral filtering. In NIPS, 2016

  3. Convolutional networks on graphs for learning molecular fingerprints. In NIPS,2015

  4. Semi-supervised classification with graph convolutional networks. In ICLR, 2016

  5. Learning convolutional neural networks for graphs. In ICML, 2016

  • 不足

上述工作中的大多数不能扩展到大型图,或者设计用于全图分类(或者二者都是)

解决的问题

问题:大规模动态图结构下的节点embedding生成问题

解决方案

整体实现逻辑

  • 整体流程图

  • 如上图所示,有3个环节

  1. 对邻居随机采样。目的:降低计算复杂度(图中一跳邻居采样数=3,二跳邻居采样数=5)

  2. 生成目标节点embedding。先聚合2跳邻居特征,生成一跳邻居embedding,再聚合一跳邻居embedding,生成目标节点embedding,从而获得二跳邻居信息

  3. 将embedding作为全连接层的输入,预测目标节点的标签

模型结构和方案设计

  • 模型结构

  1. 见方案的“整体实现框图”

  • 方案设计

(1)前向传播算法

1)伪代码

2)说明

  1. K:是聚合器的数量,也是权重矩阵的数量,也是网络的层数

  2. 4-5行是核心代码,介绍卷积层操作:聚合与节点v相连的邻居(采样)k-1层的embedding,得到第k层邻居聚合特征 ,与节点v第k-1层embedding 拼接,并通过全连接层转换,得到节点v在第k层的embedding

(2)采样

  1. GraphSage采用的定长采样

  2. 采样实现:定义需要的邻居数量S,采用有放回的重采样/负采样方法达到S

  3. 定长的原因:将节点和邻居拼成Tensor送到GPU中进行批训练

(3)聚合器

1)论文中介绍了4种满足排序不变量的聚合函数:平均、归纳式、LSTMPooling聚合。(因为邻居没有顺序,聚合函数需要满足排序不变量的特性,即输入顺序不会影响函数结果)

2)具体如下

a. 平均聚合

先对邻居embedding中每个维度取平均,然后与目标节点embedding拼接后进行非线性转换。

b. 归纳式聚合

直接对目标节点和所有邻居emebdding中每个维度取平均(替换伪代码中第5、6行),后再非线性转换

c. LSTM聚合

LSTM函数不符合“排序不变量”的性质,需要先对邻居随机排序,然后将随机邻居序列embedding(如下所示)作为LSTM输入

d. Pooling聚合

先对每个邻居节点上一层embedding进行非线性转换(等价单个全连接层,每一维度代表在某方面的表示(如信用情况)),再按维度应用 max/mean pooling,捕获邻居集上在某方面的突出的/综合的表现以此表示目标节点embedding

(4)损失函数定义

根据具体的应用情况,可以使用基于图的无监督损失和有监督损失

  • 无监督损失

1)希望节点u与“邻居”v的embedding也相似(对应公式第一项),而与“没有交集”的节点不相似(对应公式第二项)。

2)说明

a. zu为节点u通过GraphSAGE生成的embedding

b. 节点v是节点u随机游走访达“邻居”

c. vn~Pn(u)表示负采样:节点vn是从节点u的负采样分布Pn采样得到的,Q为采样样本数。

d. embedding之间相似度通过向量点积计算得到

  • 有监督损失

1)无监督损失函数的设定来学习节点embedding 可以供下游多个任务使用若仅使用在特定某个任务上,则可以替代上述损失函数符合特定任务目标,如交叉熵

(5)参数学习

  1. 通过前向传播得到节点u的embedding ,然后梯度下降(实现使用Adam优化器) 进行反向传播优化参数和聚合函数内参数

模型效果

实验目的

  1. 比较GraphSAGE 相比baseline 算法的提升效果;

  2. 比较GraphSAGE的不同聚合函数

数据集和任务

  1. Citation 论文引用网络(节点分类)

  2. Reddit web论坛 (节点分类)

  3. PPI 蛋白质网络 (graph分类)

比较方法

  1. 下述的方法进行比较时均采用LR进行分类

  2. 比较的方法

    1. Random:随机分类器

    2. Raw Features:手工特征(非图特征)

    3. DeepWalk:图拓扑特征

    4. DeepWalk+features:deepwalk+手工特征

    5. GraphSAGE四个变种:无监督生成embedding输入给LR,然后端到端有监督训练

GraphSAGE的参数设置

KS激活函数优化器每个节点随机游走的步长和步数负采样其它
2S1=25 S2=10ReLuAdam步长:5 步数:50参考word2vec, 按平滑degree进行, 对每个节点采样20个为了保证公平性, 所有版本都采用相同的minibatch迭代器、损失函数、邻居抽样器

运行时间和参数敏感性

  • 通过对比,选择合适的参数,具体结果如下图所示

  • 说明

  1. 计算时间:下图A中GraphSAGE中LSTM训练速度最慢,但相比DeepWalk,GraphSAGE在预测时间减少100-500倍(因为对于未知节点,DeepWalk要重新进行随机游走以及通过SGD学习embedding)

  2. 邻居抽样数量:下图B中邻居抽样数量递增,边际收益递减(F1),但计算时间也变大。 平衡F1和计算时间,将S1设为25。

  3. 聚合K跳内信息:在GraphSAGE, K=2 相比K=1 有10-15%的提升;但将K设置超过2,边际效果上只有0-5%的提升,但是计算时间却变大了10-100倍

分类F1指标

  • 整体效果如下图所示

  • 具体结论

  1. GraphSAGE相比baseline 效果大幅度提升

  2. GraphSAGE有监督版本比无监督效果好

  3. LSTM和pool的效果较好

  4. 尽管LSTM是为有序数据而不是无序集设计的,可是基于LSTM的聚合器显示了强大的性能

  5. 最后,能够看到无监督GraphSAGE的性能与彻底监督的版本相比具备较强的竞争力,这代表文中的框架能够在不进行特定于任务的微调(task-specific fine-tuning)的状况下实现强大的性能

不同聚合器之间的性能对比

  1. 实验中将设置六种不一样的实验,即(3个数据集)×(非监督vs.监督))

  2. 使用非参数Wilcoxon Signed-Rank检验来量化实验中不一样聚合器之间的差别,在适用的状况下报告T-statistic和p-value

  3. 从分类6的F1指标中,可知LSTM和pool的效果较好,没有显著差别

  4. GraphSAGE-LSTM比GraphSAGE-pool慢得多(≈2×),这可能使基于pooling的聚合器在整体上略占优势,即pooling的性能最好

代码实现

  1. 官方

    1. 官方网址:官方介绍

    2. 论文作者的开源实现:GraphSAGE-TensorFlow

    3. 开源实现的代码文件详解:详解

  2. 其它实现

    1. DeepCTR的TensorFlow实现:GraphSAGE-TensorFlow

参考资料

  1. GraphSAGE: GCN落地必读论文

  2. GNN 系列(三):GraphSAGE

  3. 【Graph Neural Network】GraphSAGE: 算法原理,实现和应用

  4. 网络表示学习: 淘宝推荐系统&&GraphSAGE

  5. [论文笔记]:GraphSAGE:Inductive Representation Learning on Large Graphs 论文详解 NIPS 2017

结尾

亲爱的读者朋友:感谢您在繁忙中驻足阅读本期内容!您的到来是对我们最大的支持❤️

正如古语所言:"当局者迷,旁观者清"。您独到的见解与客观评价,恰似一盏明灯💡,能帮助我们照亮内容盲区,让未来的创作更加贴近您的需求。

若此文给您带来启发或收获,不妨通过以下方式为彼此搭建一座桥梁: ✨ 点击右上角【点赞】图标,让好内容被更多人看见 ✨ 滑动屏幕【收藏】本篇,便于随时查阅回味 ✨ 在评论区留下您的真知灼见,让我们共同碰撞思维的火花

我始终秉持匠心精神,以键盘为犁铧深耕知识沃土💻,用每一次敲击传递专业价值,不断优化内容呈现形式,力求为您打造沉浸式的阅读盛宴📚。

有任何疑问或建议?评论区就是我们的连心桥!您的每一条留言我都将认真研读,并在24小时内回复解答📝。

愿我们携手同行,在知识的雨林中茁壮成长🌳,共享思想绽放的甘甜果实。下期相遇时,期待看到您智慧的评论与闪亮的点赞身影✨!

万分感谢🙏🙏您的点赞👍👍、收藏⭐🌟、评论💬🗯️、关注❤️💚~


 自我介绍:一线互联网大厂资深算法研发(工作6年+),4年以上招聘面试官经验(一二面面试官,面试候选人400+),深谙岗位专业知识、技能雷达图,已累计辅导15+求职者顺利入职大中型互联网公司。熟练掌握大模型、NLP、搜索、推荐、数据挖掘算法和优化,提供面试辅导、专业知识入门到进阶辅导等定制化需求等服务,助力您顺利完成学习和求职之旅(有需要者可私信联系)

友友们,自己的知乎账号为“快乐星球”,定期更新技术文章,敬请关注!  

http://www.dtcms.com/a/263646.html

相关文章:

  • CyclicBarrier(同步屏障)是什么?它的原理和用法是什么?
  • 新手向:从零开始Node.js超详细安装、配置与使用指南
  • Embeddings模型
  • 微服务介绍
  • Unity进阶课程【六】Android、ios、Pad 终端设备打包局域网IP调试、USB调试、性能检测、控制台打印日志等、C#
  • 【RTSP从零实践】4、使用RTP协议封装并传输AAC
  • 学习threejs,使用自定义GLSL 着色器,生成艺术作品
  • 电机参数测量
  • 自由学习记录(66)
  • JT808教程:消息的结构
  • react中在Antd3.x版本中 Select框在单选时 选中框的高度调整
  • Qt 实现Opencv功能模块切换界面功能
  • 【算法】动态规划:python实现 1
  • TensorFlow内核剖析:分布式TensorFlow架构解析与实战指南
  • mini-electron使用方法
  • 内部类与Lambda的衍生关系(了解学习内部类,Lambda一篇即可)
  • C# WPF + Helix Toolkit 实战:用两种方式打造“六面异色立方体”
  • QNN SDK学习笔记
  • 二十八、【环境管理篇】灵活应对:多测试环境配置与切换
  • python开发|yaml用法知识介绍
  • STM32F4操作内部FLASH简洁版
  • 【代码审计】安全审核常见漏洞修复策略
  • 位运算经典题解
  • 启用不安全的HTTP方法
  • 图像处理专业书籍以及网络资源总结
  • Java编程之状态模式
  • 《UE5_C++多人TPS完整教程》学习笔记40 ——《P41 装备(武器)姿势(Equipped Pose)》
  • 基于Socketserver+ThreadPoolExecutor+Thread构造的TCP网络实时通信程序
  • mac重复文件清理,摄影师同款清理方案
  • flv.js视频/直播流测试demo