当前位置: 首页 > news >正文

字节面试手撕题:神经网络模型损失函数与梯度推导

题目

设神经网络模型y = \sigma(w_0 + w_1 x),其中\sigma是Sigmoid函数:

(1)写出Sigmoid函数的表达式并求导,对于简单的二分类问题,损失函数采用什么?

(2)采用(1)中的损失函数,推导反向传播中损失函数对神经网络参数w_0,w_1的梯度表达式;

(3)对于单个样本,写出神经网络参数w_0,w_1更新的表达式;对于多个样本(如batch_size=10)呢?

(4)若训练中发现模型损失为NaN,问题可能出在哪里,有什么排查方法?

解答

(1)Sigmoid函数是一种常用的激活函数,用于将输入值映射到(0,1)区间,其表达式为:

\sigma(z) = \frac{1}{1 + e^{-z}}

Sigmoid函数的导数可以通过链式法则求得。令 y = \sigma(z),则导数如下:

\frac{d\sigma(z)}{dz} = \sigma(z) \cdot (1 - \sigma(z)) = y \cdot (1 - y)

推导过程:

设 \sigma(z) = \frac{1}{1 + e^{-z}},令 u = 1 + e^{-z},则 \sigma(z) = u^{-1}

\frac{d\sigma}{du} = -u^{-2} = -\frac{1}{(1 + e^{-z})^2},又\frac{du}{dz} = -e^{-z}

因此,\frac{d\sigma}{dz} = \frac{d\sigma}{du} \cdot \frac{du}{dz} = \left( -\frac{1}{(1 + e^{-z})^2} \right) \cdot (-e^{-z}) = \frac{e^{-z}}{(1 + e^{-z})^2}

注意到 \frac{e^{-z}}{(1 + e^{-z})^2} = \frac{1}{1 + e^{-z}} \cdot \frac{e^{-z}}{1 + e^{-z}} = \sigma(z) \cdot (1 - \sigma(z)),因为 \frac{e^{-z}}{1 + e^{-z}} = 1 - \sigma(z)

对于简单的二分类问题,损失函数采用二分类交叉熵损失(Binary Cross-Entropy Loss)。对于单个样本,损失函数 L 定义为:

L = - \left[ t \log(y) + (1 - t) \log(1 - y) \right]

其中:
t 是真实标签(取值为 0 或 1),
y 是模型的预测概率(本题中即 y = \sigma(w_0 + w_1 x))。

该损失函数衡量了预测概率 y 与真实标签 t 之间的差异,常用于逻辑回归和神经网络中的二分类任务。

(2)对于给定神经网络模型  y = \sigma(w_0 + w_1 x) ,其中\sigma(z) = \frac{1}{1 + e^{-z}}是 Sigmoid 函数,损失函数采用二分类交叉熵损失(对于单个样本):

L = - \left[ t \log(y) + (1 - t) \log(1 - y) \right]

其中 t 是真实标签(0 或 1), y 是预测值。

通过反向传播,损失函数对参数 w_0,w_1 的梯度表达式如下:

损失函数对 w_0 的梯度:\frac{\partial L}{\partial w_0} = y - t

损失函数对 w_1 的梯度:\frac{\partial L}{\partial w_1} = (y - t) \cdot x

推导说明:

推导过程使用链式法则:

1. 计算 \frac{\partial L}{\partial y}

\frac{\partial L}{\partial y} = -\frac{t}{y} + \frac{1-t}{1-y}

2. 计算 \frac{\partial y}{\partial z} (其中z=w_0+w_1x):利用 Sigmoid 函数的导数:

\frac{\partial y}{\partial z} = y(1 - y)

3. 计算 \frac{\partial L}{\partial z} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial z}

\frac{\partial L}{\partial z} = \left( -\frac{t}{y} + \frac{1-t}{1-y} \right) \cdot y(1-y) = y - t

4. 最后,计算对 w_0,w_1 的梯度:

\frac{\partial z}{\partial w_0} = 1,所以\frac{\partial L}{\partial w_0} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w_0} = (y - t) \cdot 1 = y - t

\frac{\partial z}{\partial w_1} = x,所以\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w_1} = (y - t) \cdot x

这些梯度表达式可用于梯度下降算法中更新参数 w_0,w_1。对于多个样本的情况,通常取所有样本梯度的平均值或求和。

(3)参数更新表达式

对于单个样本:

假设输入为 x,真实标签为 t,预测值为 y = \sigma(w_0 + w_1 x)。损失函数对参数的梯度为:

\frac{\partial L}{\partial w_0} = y - t

\frac{\partial L}{\partial w_1} = (y - t) \cdot x

参数更新表达式为:

w_0 \leftarrow w_0 - \eta \cdot (y - t)

w_1 \leftarrow w_1 - \eta \cdot (y - t) \cdot x

对于多个样本(batch_size=10):

假设有10个样本,输入为 x_i,真实标签为 t_i(i=1,2...10),预测值为 y_i = \sigma(w_0 + w_1 x_i)。损失函数对参数的梯度为平均梯度:

\frac{\partial L_{\text{batch}}}{\partial w_0} = \frac{1}{10} \sum_{i=1}^{10} (y_i - t_i)

\frac{\partial L_{\text{batch}}}{\partial w_1} = \frac{1}{10} \sum_{i=1}^{10} (y_i - t_i) \cdot x_i

参数更新表达式为:

w_0 \leftarrow w_0 - \eta \cdot \frac{1}{10} \sum_{i=1}^{10} (y_i - t_i)

w_1 \leftarrow w_1 - \eta \cdot \frac{1}{10} \sum_{i=1}^{10} (y_i - t_i) \cdot x_i

在实际实现中,通常使用批量梯度下降,即计算整个批次的梯度后更新参数。学习率 \eta 需要根据具体问题调整。

(4)模型训练中出现损失为NaN的原因及排查方法

可能的原因:

1. 学习率过高:过高的学习率可能导致参数更新步长过大,使得模型参数变得非常大,甚至溢出。

2. 梯度爆炸:梯度变得非常大,导致参数更新后出现非常大的数值,进而使后续计算出现NaN。

3. 数据问题:数据中存在缺失值(NaN)或无穷值(Inf),或者数据未经预处理(如归一化/标准化)导致数值范围过大。

4. 损失函数定义问题:例如,在计算交叉熵损失时,对预测值取对数,如果预测值为0或负数(注意:sigmoid输出在0到1之间,但若输入极大或极小,sigmoid输出可能非常接近0或1,取对数时可能得到负无穷,但通常不会出现负数,除非激活函数使用不当)但若使用其他激活函数可能导致输出为负。

5. 除零操作:在损失函数或梯度计算中可能存在除以零的风险。

6. 数值不稳定:例如,在softmax函数中,如果输入值非常大,可能导致指数运算溢出。

排查方法:

1. 降低学习率:尝试将学习率减小一个数量级,观察是否还会出现NaN。

2. 梯度裁剪:在反向传播过程中,对梯度进行裁剪,限制梯度的最大值,防止梯度爆炸。

3. 检查数据:确认输入数据中是否包含NaN或Inf。可以使用numpy或pandas检查数据。

4. 数据预处理:对输入数据进行归一化或标准化,使其数值范围在一个合理的区间(例如,0均值1方差,或者缩放到[0,1]区间)。

5. 检查损失函数:确保在计算损失函数时,对数值稳定性做了处理。例如,在计算交叉熵损失时,可以对预测值进行裁剪,避免取对数时出现无穷大(如将预测值限制在[epsilon, 1-epsilon]之间,其中epsilon是一个很小的正数)。

6. 使用调试工具:在训练过程中打印出损失、梯度、参数等中间变量的值,定位NaN第一次出现的位置。例如,可以在每个epoch后打印损失,或者使用调试器逐步执行。

7. 初始化参数:检查模型参数的初始化是否合适。不合适的初始化(如初始值过大)可能导致输出非常大。

8. 使用正则化:添加正则化项(如L2正则化)可能有助于防止参数变得过大。

针对二分类模型y = \sigma(w_0 + w_1 x)的特定情况,由于模型简单,可能的原因主要是学习率过高、数据问题或损失计算中的对数问题。可以检查学习率,确保输入数据没有异常,并在计算交叉熵损失时对预测值进行裁剪。

http://www.dtcms.com/a/354320.html

相关文章:

  • CSS(面试)
  • Mojomox-在线 AI logo 设计工具
  • 从“流量焦虑”到“稳定增长”:用内容平衡术解锁Adsense变现新密码
  • 电子器械如何统一管理系统权限?一场IT治理的“攻坚战”
  • 第二十九天:重载、重写和覆盖
  • 【网络】iptables MASQUERADE作用
  • 机器学习与Backtrader结合量化交易
  • 无人机抗干扰技术要点解析
  • O2OA移动办公 × Flutter:一次开发,跨平台交付
  • 【C++】深入解析构造函数初始化
  • Docker 镜像重命名【打上新的标签】
  • AI应用图文解说--百度智能云实现语音聊天
  • Python爬虫获取1688商品列表与图片信息
  • 【免驱】一款基于AI8H2K08U芯片的USB转RS485模块,嵌入式工程师调试好帮手!
  • Web 自动化测试常用函数实战(一)
  • 如何防御安全标识符 (SID) 历史记录注入
  • 嵌入式学习day38
  • 怎样选择合适的报表系统?报表系统的主要功能有什么
  • PLC_博图系列☞基本指令”S_PULSE:分配脉冲定时器参数并启动“
  • PyTorch闪电入门:张量操作与自动微分实战
  • Wxml2Canvas在组件中的渲染获取不到元素问题
  • vue 海康视频插件
  • Java Spring Boot 集成淘宝 SDK:实现稳定可靠的商品信息查询服务
  • AI鱼塘,有你画的鱼吗?
  • 代码随想录刷题Day44
  • IDEA连接阿里云服务器中的Docker
  • 嵌入式学习日志————DMA直接存储器存取
  • 微信开发者工具中模拟调试现场扫描小程序二维码功能
  • Centos 7.6离线安装docker
  • 元宇宙+RWA:2025年虚拟资产与真实世界的金融融合实验