当前位置: 首页 > news >正文

LEARNING ON LARGE-SCALE TEXT-ATTRIBUTED GRAPHS VIA VARIATIONAL INFERENCE

LEARNING ON LARGE-SCALE TEXT-ATTRIBUTED GRAPHS VIA VARIATIONAL INFERENCE

ICLR23

推荐指数:#paper/⭐⭐#​

作者的写作手法感觉就是有点把模型往数学化的方式去写,内容其实就相当于一个LM与GCN互相提供伪标签监督。利用KL散度来优化有点意思

动机

关键挑战

  • 传统方法同时训练LLMs和GNNs时,计算复杂度高,难以扩展到大规模图数据。

方法创新:GLEM框架

  1. 交替训练机制

    • E步:固定GNN参数,GNN预测的伪标签与观察到的标签一起用于LM训练,使其学习与图拓扑一致的文本表示。
    • M步:固定LLM参数,LM提供文本嵌入和伪标签给GNN,提升结构建模能力。
  2. 模块解耦:避免同时训练两个大模型,显著降低计算开销。

  3. 互增强机制:LLM和GNN通过交替更新互相提供监督信号(伪标签),逐步提升整体性能。

模型细节

image

伪似然变分框架

这是一个基于伪似然变分框架的方法,该方法用于模型设计,提供了一种原则性和灵活性的公式化方式。具体来说,这个方法的目标是最大化观测到的节点标签的对数似然函数,即 p ( y L ∣ s V , A ) p(y_L | s_V, A) p(yLsV,A),其中 y L y_L yL 是已标记节点的标签集合, s V s_V sV 是所有节点的文本特征 A A A​ 是图的邻接矩阵。

直接优化这个函数通常是困难的,因为存在未观测到的节点标签 y U y_U yU。为了解决这个问题,该框架不直接优化对数似然函数,而是优化一个称为证据下界(Evidence Lower Bound,ELBO)的量。ELBO 的表达式如下:

l o g p ( y L ∣ s V , A ) ≥ E q ( y U ∣ s U ) [ log ⁡ p ( y L , y U ∣ s V , A ) − log ⁡ q ( y U ∣ s U ) ] \\log p(y_L | s_V, A) \geq \mathbb{E}_{q(y_U | s_U)}[\log p(y_L, y_U | s_V, A) - \log q(y_U | s_U)] logp(yLsV,A)Eq(yUsU)[logp(yL,yUsV,A)logq(yUsU)]

这里, q ( y U ∣ s U ) q(y_U | s_U) q(yUsU)是一个变分分布,上述不等式对任何 q q q都成立。ELBO 可以通过交替优化分布 q q q(即 E-step)和分布 p p p(即 M-step)来优化。

  • E-step(期望步) :优化变分分布 q q q,目的是最小化 q q q 和后验分布 p ( y U ∣ s V , A , y L ) p(y_U | s_V, A, y_L) p(yUsV,A,yL) 之间的 Kullback-Leibler (KL) 散度,从而收紧上述下界。
  • M-step(最大化步) :优化目标分布 ( p ),以最大化伪似然函数:

E q ( y U ∣ s U ) [ log ⁡ p ( y L , y U ∣ s V , A ) ] ≈ E q ( y U ∣ s U ) [ ∑ n ∈ V log ⁡ p ( y n ∣ s V , A , y V ∖ n ) ] \mathbb{E}_{q(y_U | s_U)}[\log p(y_{L}, y_{U} | s_{V}, A)] \approx \mathbb{E}_{q(y_U | s_U)}\left[\sum_{n \in V}\log p(y_{n} | s_{V}, A, y_{V \setminus n})\right] Eq(yUsU)[logp(yL,yUsV,A)]Eq(yUsU)[nVlogp(ynsV,A,yVn)]

这个过程通过交替执行 E-step 和 M-step 来实现,直到收敛。这种方法允许模型在不需要直接处理未观测标签的情况下,有效地学习节点表示。简而言之,这种方法通过优化一个下界来间接优化对数似然函数,这个下界可以通过交替优化两个分布来逐步提高,从而使得模型能够更好地处理未观测数据,并提高学习效果。

具体的两个模型的介绍

在这部分内容中,文章详细阐述了GLEM方法中使用的两种分布—— q q q p p p的参数化过程,以及它们如何用于节点标签分布的建模和优化。

分布 q q q的参数化

分布 q q q的目标是利用文本信息 s U s_U sU 来定义节点标签分布,这相当于一个语言模型(LM)。在 GLEM 中,采用均场形式(mean-field form),假设不同节点的标签是独立的,每个节点的标签只依赖于它自己的文本信息。这导致了以下形式的分解:
q θ ( y U ∣ s U ) = ∏ n ∈ U q θ ( y n ∣ s n ) . q_{\theta}(y_{U}|s_{U}) = \prod_{n \in U}q_{\theta}(y_{n}|s_{n}). qθ(yUsU)=nUqθ(ynsn).这里, θ \theta θ是模型参数​** U U U是未标记节点的集合。在这里,每个项 q θ ( y n ∣ s n ) q_{\theta}(y_{n}|s_{n}) qθ(ynsn)可以通过基于 Transformer 的语言模型 q θ q_{\theta} qθ来建模,该模型通过注意力机制有效地模拟细粒度的标记交互。**

分布 p p p的参数化

分布 p p p定义了一个条件分布 p ϕ ( y n ∣ s V , A , y V \ n ) p_{\phi}(y_{n}|s_{V}, A, y_{V\backslash n}) pϕ(ynsV,A,yV\n),旨在利用节点特征 s V s_{V} sV、图结构 A A A和其他节点标签 y V \ n y_{V\backslash n} yV\n来表征每个节点的标签分布。因此, p ϕ ( y n ∣ s V , A , y V \ n ) p_{\phi}(y_{n}|s_{V}, A, y_{V\backslash n}) pϕ(ynsV,A,yV\n)被建模为一个由 ϕ \phi ϕ参数化的 G N N p ϕ GNN p_{\phi} GNNpϕ,以有效地模拟节点间的结构交互。

由于节点文本 s V s_V sV是离散变量,不能直接被 GNN 使用,因此在实践中,我们首先使用 L M q θ LM q_{\theta} LMqθ 对节点文本进行编码,然后使用获得的嵌入作为 G N N p ϕ GNN p_{}\phi GNNpϕ的节点文本的替代。

E-STEP: LM OPTIMIZATION

  1. 目标:在E-step中,固定GNN,更新LM以最大化证据下界。这样做的目的是将不同节点之间的全局语义相关性提取到LM中。最大化关于 LM 的证据下限等同于最小化后验分布和变分分布之间的 KL 散度
  2. 优化方法:直接优化KL散度是困难的,因为KL散度依赖于难以处理的变分分布的熵。为了克服这个挑战,作者采用了wake-sleep算法来最小化反向KL散度,从而得到一个更易于处理的目标函数。
  3. 目标函数:目标函数是关于LM q θ q_\theta qθ的,目的是最大化这个函数。这个函数的形式是:(KL的前面,第一项相当于GNN,第二项相当于LM)
    − K L ( p ϕ ( y U ∣ s V , A , y L ) ∣ ∣ q θ ( y U ∣ s U ) ) = ∑ n ∈ U E p ϕ ( y n ∣ s V , A , y L ) [ log ⁡ q θ ( y n ∣ s n ) ] + const -KL(p_\phi(y_U|s_V, A, y_L)||q_\theta(y_U|s_U)) = \sum_{n \in U} \mathbb{E}_{p_\phi(y_n|s_V, A, y_L)}[\log q_\theta(y_n|s_n)] + \text{const} KL(pϕ(yUsV,A,yL)∣∣qθ(yUsU))=nUEpϕ(ynsV,A,yL)[logqθ(ynsn)]+const
    这个目标函数更容易处理,因为我们不再需要考虑 q θ ( y U ∣ s U ) q_\theta(y_U|s_U) qθ(yUsU)的熵。
  4. 分布计算:唯一的困难在于计算分布 p ϕ ( y n ∣ s V , A , y L ) p_\phi(y_n|s_V, A, y_L) pϕ(ynsV,A,yL)。在原始GNN中,我们基于周围节点标签 y V \ n y_{V\backslash n} yV\n来预测节点 ( n ) 的标签分布。然而,在 p ϕ ( y n ∣ s V , A , y L ) p_\phi(y_n|s_V, A, y_L) pϕ(ynsV,A,yL)中,我们只基于观察到的节点标签 y L y_L yL,其他节点的标签是未指定的,因此我们不能直接用GNN计算这个分布。
  5. 解决方案:为了解决这个问题,作者提出用LM预测的伪标签来标注图中所有未标记的节点,从而可以近似分布:
    p ϕ ( y n ∣ s V , A , y L ) ≈ p ϕ ( y n ∣ s V , A , y L , y ^ U \ n ) p_\phi(y_n|s_V, A, y_L) \approx p_\phi(y_n|s_V, A, y_L, \hat{y}_{U\backslash n}) pϕ(ynsV,A,yL)pϕ(ynsV,A,yL,y^U\n)
    其中 y ^ U \ n \hat{y}_{U\backslash n} y^U\n是未标记节点的伪标签集合。
  6. 最终目标函数:结合上述目标函数和标记节点,得到训练LM的最终目标函数:
    O ( q ) = α ∑ n ∈ U E p ( y n ∣ s V , A , y L , y ^ U \ n ) [ log ⁡ q ( y n ∣ s n ) ] + ( 1 − α ) ∑ n ∈ L log ⁡ q ( y n ∣ s n ) \mathcal{O}(q) = \alpha \sum_{n \in U} \mathbb{E}_{p(y_n|s_V, A, y_L, \hat{y}_{U\backslash n})}[\log q(y_n|s_n)] + (1-\alpha) \sum_{n \in L} \log q(y_n|s_n) O(q)=αnUEp(ynsV,A,yL,y^U\n)[logq(ynsn)]+(1α)nLlogq(ynsn)
    其中 α \alpha α是一个超参数。直观上,第二项是一个监督目标,使用给定的标记节点进行训练。同时,第一项可以看作是一个知识蒸馏过程,通过强制LM基于邻域文本信息预测标签分布来训练LM

M-STEP: GNN OPTIMIZATION

目标:在GNN阶段,目标是固定语言模型 q θ q_\theta qθ并优化图神经网络 p ϕ p_\phi pϕ​以最大化伪似然(pseudo-likelihood)。

  1. 方法

    • 使用语言模型为所有节点生成节点表示 h V h_V hV​,并将这些表示作为文本特征输入到图神经网络中进行消息传递。
    • 利用语言模型 q θ q_\theta qθ为每个未标记节点 n ∈ U n \in U nU预测一个伪标签 y ^ n \hat{y}_n y^n并将所有伪标签 { y ^ n } n ∈ U \{\hat{y}_n\}_{n \in U} {y^n}nU组合成 y ^ U \hat{y}_U y^U​。
  2. 伪似然重写:结合节点表示和LM q θ q_\theta qθ​的伪标签,伪似然可以重写为:
    O ( ϕ ) = β ∑ n ∈ U log ⁡ p ϕ ( y ^ n ∣ s V , A , y L , y ^ U ∖ n ) + ( 1 − β ) ∑ n ∈ L log ⁡ p ϕ ( y n ∣ s V , A , y L ∖ n , y ^ U ) \mathcal{O}(\phi) = \beta \sum_{n \in U}\log p_{\phi}(\hat{y}_{n} | s_{V}, A, y_{L} , \hat{y}_{U \setminus n}) + (1 - \beta) \sum_{n \in L}\log p_{\phi}(y_{n} | s_{V} , A, y_{L \setminus n}, \hat{y}_{U}) O(ϕ)=βnUlogpϕ(y^nsV,A,yL,y^Un)+(1β)nLlogpϕ(ynsV,A,yLn,y^U)
    其中, β \beta β是一个超参数,用于平衡两个项的权重。第一项:可以看作是一个知识蒸馏过程,通过所有伪标签将LM捕获的知识注入到GNN中。第二项:是一个监督损失,使用观察到的节点标签进行模型训练。

一旦训练完成,E-step中的LM(记为GLEM-LM)和M-step中的GNN(记为GLEM-GNN)都可以用来进行节点标签预测。

实验

image

  • 结果

    • GLEM-GNN:在所有三个数据集上取得了新的最佳性能,证明了其在节点分类任务中的有效性。
    • GLEM-LM:通过结合图结构信息,显著提升了语言模型的性能。
    • 可扩展性:GLEM 能够适应大型语言模型(如 DeBERTa-large),并且在效率和性能之间取得了良好的平衡。

相关文章:

  • Go语言中使用viper绑定结构体和yaml文件信息时,标签的使用
  • NIO-Reactor模型梳理与demo实现
  • Linux 第三次脚本作业
  • 如何使用智能指针来管理动态分配的内存
  • 函数中的形参和实参(吐槽)
  • R 语言科研绘图 --- 散点图-汇总
  • 记录 idea 启动 tomcat 控制台输出乱码问题解决
  • 嵌入式Linux内核底层调试技术Kprobes
  • N32G003查看设备重启原因
  • 洛谷P1135多题解
  • Pytorch使用手册-音频数据增强(专题二十)
  • 显卡(Graphics Processing Unit,GPU)架构详细解读
  • Linux 第二次脚本作业
  • [设计模式] Builder 建造者模式
  • [Windows] 全国油价实时查询,可具体到城市
  • TCP/UDP调试工具推荐:Socket通信图解教程
  • vscode settings(二):文件资源管理器编辑功能主题快捷键
  • 字符串中字母的大小写转换
  • 【模板】Linux中cmake使用编译c++程序
  • 【JavaEE进阶】Spring DI
  • 中国强镇密码丨洪泽湖畔的蒋坝,如何打破古镇刻板印象
  • 中央宣传部、全国总工会联合发布2025年“最美职工”先进事迹
  • 体坛联播|欧冠半决赛阿森纳主场不敌巴黎,北京男篮险胜山西
  • 中国人保一季度业绩“分化”:财险净利增超92%,寿险增收不增利
  • 企业取消“大小周”引热议,半月谈:不能将显性加班变为隐性加班
  • 金融街:去年净亏损约110亿元,今年努力实现经营稳健和财务安全