当前位置: 首页 > news >正文

Gumbel-Softmax函数

Gumbel-Softmax函数

背景动机

在许多任务中,我们需要从一个离散分布中采样,例如从one-hot编码中选出某个类别:

  • 但是离散采样操作是不可导的,这使得无法通过反向传播更新参数
  • Gumbel-Softmax提供了一种近似采样的方法,它是可微的,因此可以端到端训练神经网络

什么任务需要离散分布采样

  • 神经架构搜索:在给定多个可能的网络组件中选择一个子结构
  • 强化学习:代理在每一步需要从有限的动作空间中选择一个动作,如走上/下/走/右

Gumbel-Max Trick

如果我们希望从一个离散的概率分布z=[z1,z2....,zK]z=[z_1,z_2....,z_K]z=[z1,z2....,zK]中采样一个类别,可以通过以下方式实现:
y=argmax[log(zi)+gi] y = argmax[log(z_i)+g_i] y=argmax[log(zi)+gi]
其中gi=Gumbel(0,1)g_i=Gumbel(0,1)gi=Gumbel(0,1),这个过程称为Gumbel-Max trick,可以视为在logits上加上噪声后取最大值。
然而,argmax操作显然不可导
为了使Gumbel-Max变为可导,我们将argmax用softmax来近似
yi=exp((log(zi)+gi)/τ)∑j=1Kexp((log(zi)+gi)/τ) y_i = \frac{exp((log(z_i)+g_i)/\tau)}{\sum_{j=1}^Kexp((log(z_i)+g_i)/\tau)} yi=j=1Kexp((log(zi)+gi)/τ)exp((log(zi)+gi)/τ)

  • gi=Gumbel(0,1)是Gumbel噪声g_i = Gumbel(0,1)是Gumbel噪声gi=Gumbel(0,1)Gumbel噪声
  • τ>0\tau>0τ>0是温度参数
    τ→∞\tau\rightarrow\inftyτ, Gumbel-Softmax输出趋近于平均分布(更平滑)
    τ→0\tau\rightarrow0τ0,输出趋近于one-hot(更像真实采样,但梯度不稳定)
    因此,训练时通常采用:高温度开始,逐渐降低温度
    在这里插入图片描述

为什么说Gumbel-Softmax模拟了采样行为?

从一个离散概率分布z=[0.1,0.7,0.2]z=[0.1,0.7,0.2]z=[0.1,0.7,0.2]中采样,指的是:根据概率值,随机选择一个类别(one-hot)作为结果。
有70%概率选择第2类: [0,1,0]
有10%概率选择第1类:[1,0,0]
但是采样过程不可导。
Gumbel-Max Trick = 真实采样
y=argmax[log(zi)+gi] y = argmax[log(z_i)+g_i] y=argmax[log(zi)+gi]

举个例子:

  1. logits = [2.0, 1.0, 0.1]
    Softmax 后输出:
[0.57, 0.31, 0.12]  # 每次都一样,不是真采样
  1. Gumbel-Softmax(多次运行)
    每次加上 Gumbel 噪声再 softmax,例如:
Sample 1: [0.97, 0.02, 0.01]
Sample 2: [0.03, 0.91, 0.06]
Sample 3: [0.05, 0.10, 0.85]

这些近似 one-hot 输出,就模拟了多次“真实采样”的过程。
为什么要使用log函数?
压缩大值,使噪声占比大

Gumbel 分布 (耿贝尔分布)

Gumbel 分布是一种极值分布,用于建模“最大值”的概率分布。
📌 标准 Gumbel(0,1) 分布的定义:
一个随机变量 g 服从标准 Gumbel 分布,当其概率密度函数(PDF)为:
f(g)=exp(−(g+e−g)) f(g) = exp(-(g+e^{-g})) f(g)=exp((g+eg))
如何采样Gumbel噪声
g=−log(−log(U)) g = - log(-log(U)) g=log(log(U))

def sample_gumbel(shape, eps=1e-20):U = torch.rand(shape)return -torch.log(-torch.log(U + eps) + eps)

总结:Gumbel-Softmax在Softmax的基础上增加了噪声扰动性,从而达到离散分布采样的作用

http://www.dtcms.com/a/313138.html

相关文章:

  • AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年8月3日第155弹
  • 数据与信息的边界:非法获取计算机信息系统数据罪的司法困境与出路
  • 【十九、Javaweb-day19-Linux概述】
  • python---可变类型、不可变类型
  • Pytorch 报错-probability tensor contains either ‘inf‘, ‘nan‘ or element < 0 解决方案
  • Arrays.asList() add方法报错java.lang.UnsupportedOperationException
  • 8月3日星期日今日早报简报微语报早读
  • 多线程(四) ~ wait,join,sleep及单例与工厂模式
  • 图像识别区分指定物品与其他物体
  • 【华为机试】815. 公交路线
  • NumPy库学习(三):numpy在人工智能数据处理的具体应用及方法
  • 机器学习sklearn:支持向量机svm
  • Vue3 其它Composition API
  • Linux网络编程 --- 多路转接select
  • 推送本地项目到Gitee远程仓库
  • Selenium Web 自动化
  • 优选算法 力扣 202.快乐数 快慢双指针 解决带环问题 C++解题思路 每日一题
  • ThinkPHP5x,struts2等框架靶场复现
  • Coin Combinations II(Dynamic Programming)
  • LLM - AI大模型应用集成协议三件套 MCP、A2A与AG-UI
  • 用 Eland 在 Elasticsearch Serverless 部署 Learning-to-Rank 排序模型
  • 数据,正在成为AI大模型最后的护城河
  • leetcode 2106. 摘水果 困难
  • Rust 同步方式访问 REST API 的完整指南
  • 道格拉斯-普克算法 - 把一堆复杂的线条变得简单,同时尽量保持原来的样子
  • python---赋值、浅拷贝、深拷贝
  • 【C 学习】03-你的第一个C程序
  • 上位机知识篇---脚本文件
  • Linux环境下使用Docker搭建多服务环境
  • Corrosion2靶场