当前位置: 首页 > news >正文

深度学习-神经网络(上篇)

一、神经网络概述

1. 基本概念

  • ​人工神经网络 (ANN)​​:一种模仿生物神经网络结构与功能的计算模型。
  • ​生物神经元机理​​:树突接收输入信号,细胞核处理并聚集电荷,达到电位阈值后通过轴突输出电信号。
  • ​人工神经元机理​​:对多个输入进行​​加权求和​​,再通过一个​​激活函数​​产生输出。
    输出 = 激活函数(Σ(输入 * 权重) + 偏置)

2. 网络结构

一个典型的全连接神经网络包含以下层次:

  • ​输入层​​:接收原始数据。
  • ​隐藏层​​:介于输入和输出层之间,进行特征变换。可以有多个。
  • ​输出层​​:产生最终的预测结果。

​结构特点​​:

  • 信息单向传播(前向传播)。
  • 同一层的神经元之间无连接。
  • 第N层的每个神经元与第N-1层的所有神经元全连接。
  • 每个连接都有其权重(w)和偏置(b)。

3. 深度学习与机器学习的关系

  • ​深度学习是机器学习的一个子集​​。
  • 主要区别在于​​特征工程​​:
    • ​传统机器学习​​:严重依赖人工特征工程。
    • ​深度学习​​:模型通过多层神经网络​​自动学习​​数据的层次化特征表示。

二、激活函数

​作用​​:为网络引入​​非线性因素​​,使得神经网络能够拟合任意复杂的函数。若无激活函数,多层网络等价于一个线性模型。

常见激活函数对比

激活函数公式输出范围特点优点缺点适用场景
​Sigmoid​f(x) = 1 / (1 + e⁻ˣ)(0, 1)S型曲线输出可视为概率1. 易产生​​梯度消失​​(导数范围(0, 0.25)
2. 非零中心
3. 计算含指数,较慢
​输出层​​(二分类)
​Tanh​f(x) = (eˣ - e⁻ˣ) / (eˣ + e⁻ˣ)(-1, 1)放大并平移的Sigmoid​零中心​​,收敛速度比Sigmoid快仍存在​​梯度消失​​问题​隐藏层​
​ReLU​f(x) = max(0, x)[0, +∞)简单阈值过滤1. 计算高效,缓解梯度消失(正区间)
2. 带来网络稀疏性
​神经元死亡​​(负区间梯度为0)​隐藏层​​(首选)
​Softmax​f(xᵢ) = eˣⁱ / Σⱼeˣʲ(0, 1) 且和为1多分类Sigmoid推广将输出归一化为​​概率分布​-​输出层​​(多分类)

激活函数选择指南

  • ​隐藏层​​:
    • ​优先使用 ReLU​​,注意学习率设置以防“死亡神经元”。
    • ReLU效果不佳时,可尝试 Leaky ReLU 等变体。
    • 少用 Sigmoid,可尝试 Tanh。
  • ​输出层​​:
    • ​二分类​​:Sigmoid
    • ​多分类​​:Softmax
    • ​回归​​:恒等函数(即无激活函数)

三、参数初始化方法

权重初始化的好坏直接影响模型的收敛速度和最终性能。

常见初始化方法

  1. ​简单初始化​​:

    • ​均匀分布初始化​​:torch.nn.init.uniform_()
    • ​正态分布初始化​​:torch.nn.init.normal_()
    • ​全零初始化​​:torch.nn.init.zeros_() -> ​​导致神经元对称失效,禁止使用​​。
    • ​全一初始化​​:torch.nn.init.ones_() -> 效果差,一般不使用。
    • ​固定值初始化​​:torch.nn.init.constant_()
  2. ​高级初始化(推荐)​​:

    • ​Kaiming (He) 初始化​​:为解决ReLU激活函数设计的初始化方法。
      • 正态分布:std = sqrt(2 / fan_in)
      • 均匀分布:limit = sqrt(6 / fan_in)
      • fan_in:该层输入神经元的个数。
      • PyTorch API: torch.nn.init.kaiming_normal_(), torch.nn.init.kaiming_uniform_()
    • ​Xavier (Glorot) 初始化​​:为解决Sigmoid/Tanh等S型激活函数设计。
      • 正态分布:std = sqrt(2 / (fan_in + fan_out))
      • 均匀分布:limit = sqrt(6 / (fan_in + fan_out))
      • fan_in:输入神经元个数,fan_out:输出神经元个数。
      • PyTorch API: torch.nn.init.xavier_normal_(), torch.nn.init.xavier_uniform_()

初始化方法选择

  • ​通常优先使用 Kaiming 或 Xavier 初始化​​。
  • PyTorch 中许多层已有合理的默认初始化,但自定义层时需手动初始化。

四、神经网络的搭建与参数计算

1. 模型搭建步骤(PyTorch)

在PyTorch中,通过继承 nn.Module 类来定义模型。

import torch.nn as nnclass MyModel(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(MyModel, self).__init__() # 必须调用父类初始化# 定义网络层self.linear1 = nn.Linear(input_dim, hidden_dim)self.linear2 = nn.Linear(hidden_dim, output_dim)# 初始化权重nn.init.xavier_normal_(self.linear1.weight)nn.init.kaiming_normal_(self.linear2.weight)def forward(self, x):# 定义前向传播路径x = torch.sigmoid(self.linear1(x))x = self.linear2(x) # 输出层通常不加激活函数,损失函数中会集成return x

2. 参数量的计算

神经网络的参数量主要指​​权重(w)​​ 和​​偏置(b)​​ 的数量。

  • 对于一个全连接层:参数总量 = (输入特征数 + 1) * 输出特征数
    • +1 代表偏置项。
  • ​示例​​:一个输入为3个特征,输出为2个神经元的层,参数量为 (3 + 1) * 2 = 8
  • 可以使用 torchsummary 库的 summary(model, input_size) 函数自动计算和打印模型总参数量和各层细节。

3. 输入输出形状

  • ​输入张量形状​​:[batch_size, in_features]
  • ​输出张量形状​​:[batch_size, out_features]
  • 训练时使用 DataLoaderbatch_size 组织数据。

4. 神经网络的优缺点

  • ​优点​​:
    • 精度高,在诸多领域性能领先。
    • 能够近似任意复杂函数。
    • 社区成熟,有大量框架和库支持。
  • ​缺点​​:
    • ​黑箱模型​​,解释性差。
    • 训练时间长,计算资源消耗大。
    • 网络结构复杂,需要大量调参。
    • 在小数据集上容易​​过拟合​​。


http://www.dtcms.com/a/389183.html

相关文章:

  • 【脑电分析系列】第18篇:传统机器学习在EEG中的应用 — SVM、LDA、随机森林等分类器
  • 理解长短期记忆神经网络(LSTM)
  • Kurt-Blender零基础教程:第2章:建模篇——第1节:点线面的选择与控制与十大建模操作
  • 鸿蒙5.0应用开发——V2装饰器@Monitor的使用
  • 八、Java-XML
  • 计算机在医疗领域应用的独特技术问题分析
  • HTB Intentions writeup(SQL二次注入也是注入)
  • 第一章 预训练:让模型“博闻强识”
  • 【数组】求两个匀速运动质点的相交或最小距离
  • 新手向:Python爬虫原理详解,从零开始的网络数据采集指南
  • OKZOO进军HealthFi:承接AIoT,引领Health-to-Earn
  • Halcon 相机标定
  • 腾讯混元发布集成翻译模型Hunyuan-MT-Chimera-7B,已开放体验
  • mybatis-plus扩展
  • 从x.ai到VSCode:一个AI编程助手的意外之旅
  • SQLite vs MySQL:核心SQL语法差异全面解析
  • 【每日算法】两数相加 LeetCode
  • ActiveMQ底层原理与性能优化
  • Ceph IO流程分段上传(1)——InitMultipart
  • 大数据毕业设计选题推荐-基于大数据的农作物产量数据分析与可视化系统-Hadoop-Spark-数据可视化-BigData
  • 【回归之作】学校实训作业:Day04面向对象思想编程
  • Ubuntu20.04或者Ubuntu24.04 TypeC-连接屏幕不显示问题
  • 【SQLSERVER】SQL Server 表导出与导入
  • postgresql和mongodb谁的地位更高
  • RK3588+复旦微JFM7K325T工业控制解决方案
  • RabbitMQ全方位解析
  • 云望无人机图传原理,无人机图传方式哪种好
  • 无人机50公里遥控模块技术要点与难点
  • 【三维重建】Octree-GS:基于LOD的3DGS实时渲染(TPAMI2025)
  • 《深度拆解3D开放世界游戏中角色攀爬系统与地形碰撞网格动态适配的穿透卡顿复合故障》