当前位置: 首页 > news >正文

PyTorch 激活函数

激活函数是神经网络中至关重要的组成部分,它们为网络引入了非线性特性,使得神经网络能够学习复杂模式。PyTorch 提供了多种常用的激活函数实现。

常用激活函数

1. ReLU (Rectified Linear Unit)

数学表达式:

PyTorch实现:

torch.nn.ReLU(inplace=False)

特点:

  • 计算简单高效

  • 解决梯度消失问题(正区间)

  • 可能导致"神经元死亡"(负区间梯度为0),ReLU 在输入为负时输出恒为 0,导致反向传播中梯度消失,相关权重无法更新‌14。若神经元长期处于负输入状态,则会永久“死亡”,失去学习能力‌。

示例:

relu = nn.ReLU()
input = torch.tensor([-1.0, 0.0, 1.0, 2.0])
output = relu(input)  # tensor([0., 0., 1., 2.])

2. LeakyReLU

数学表达式:

PyTorch实现:

torch.nn.LeakyReLU(negative_slope=0.01, inplace=False)

特点:

  • 解决了ReLU的"神经元死亡"问题,通过引入负区间的微小斜率(如 torch.nn.LeakyReLU(negative_slope=0.01)),保留负输入的梯度传播,避免神经元死亡‌。

  • negative_slope通常设为0.01

示例

leaky_relu = nn.LeakyReLU(negative_slope=0.1)
input = torch.tensor([-1.0, 0.0, 1.0, 2.0])
output = leaky_relu(input)  # tensor([-0.1000, 0.0000, 1.0000, 2.0000])

3. Sigmoid

数学表达式:

 PyTorch实现:

torch.nn.Sigmoid()

特点:

  • 输出范围(0,1),适合二分类问题

  • 容易出现梯度消失问题

  • 输出不以0为中心

示例:

sigmoid = nn.Sigmoid()
input = torch.tensor([-1.0, 0.0, 1.0, 2.0])
output = sigmoid(input)  # tensor([0.2689, 0.5000, 0.7311, 0.8808])

 

4. Tanh (Hyperbolic Tangent)

数学表达式:

PyTorch实现

torch.nn.Tanh()

特点:

  • 输出范围(-1,1),以0为中心

  • 比sigmoid梯度更强

  • 仍存在梯度消失问题

示例:

tanh = nn.Tanh()
input = torch.tensor([-1.0, 0.0, 1.0, 2.0])
output = tanh(input)  # tensor([-0.7616, 0.0000, 0.7616, 0.9640])

5. Softmax

数学表达式:

PyTorch实现:

torch.nn.Softmax(dim=None)

特点:

  • 输出为概率分布(和为1)

  • 常用于多分类问题的输出层

  • dim参数指定计算维度

示例:

softmax = nn.Softmax(dim=1)
input = torch.tensor([[1.0, 2.0, 3.0]])
output = softmax(input)  # tensor([[0.0900, 0.2447, 0.6652]])

其他激活函数

6. ELU (Exponential Linear Unit)

torch.nn.ELU(alpha=1.0, inplace=False)

7. GELU (Gaussian Error Linear Unit) 

torch.nn.GELU()

8. Swish

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

选择指南

  1. 隐藏层:通常首选ReLU及其变体(LeakyReLU、ELU等)

  2. 二分类输出层:Sigmoid

  3. 多分类输出层:Softmax

  4. 需要负输出的情况:Tanh或LeakyReLU

  5. Transformer模型:常用GELU

自定义激活函数

PyTorch可以轻松实现自定义激活函数:

class CustomActivation(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return torch.where(x > 0, x, torch.exp(x) - 1)

注意事项

  1. 梯度消失/爆炸问题

  2. 死亡神经元问题(特别是ReLU)

  3. 计算效率考虑

  4. 初始化方法应与激活函数匹配

http://www.dtcms.com/a/107887.html

相关文章:

  • PyQt5和OpenCV车牌识别系统
  • Java基础 4.2
  • Mysql 在什么样的情况下会产生死锁?
  • Python爬虫第2节-网页基础和爬虫基本原理
  • 2.Linux的权限理解
  • mysql docker容器启动遇到的问题整理
  • 华为面试,机器学习深度学习知识点:
  • Windows C++ 排查死锁
  • MIT6.S081 - Lab6 Copy-on-Write(写时复制)
  • 模拟集成电路设计与仿真 : Mismatch
  • 数据库 第一章 MYSQL基础(4)
  • 《汽车噪声控制》课程作业
  • 英飞凌高信噪比MEMS麦克风驱动人工智能交互
  • Pandas基础及series对象
  • Token是什么?
  • 时序数据库 InfluxDB(六)
  • Python爬虫第一战(爬取优美图库网页图片)
  • *快排延伸-自省排序
  • conda activate激活环境失败问题
  • 《雷神之锤 III 竞技场》快速求平方根倒数的计算探究
  • conda 激活环境vscode的Bash窗口
  • 数据清洗的具体内容
  • 【Linux】手动部署并测试内网穿透
  • Python基础语法 - 判断语句
  • ffmpeg命令整理
  • 从零开始学习Slam|ICP原理与应用
  • Sentinel实战(三)、流控规则之流控效果及流控小结
  • OpenIPC开源FPV之Adaptive-Link新版本算法v0.60.0
  • 强大而易用的JSON在线处理工具
  • python网络爬虫开发实战之Ajax数据提取