当前位置: 首页 > news >正文

反向传播

        如果你要用梯度下降算法来训练一个神经网络,应该怎么做?

        假设网络有一堆的参数:\theta =\begin{Bmatrix} w_{1},w_{2},...,b_{1},b_{2},... \end{Bmatrix}。首先选择一个初始的参数\theta ^{0},计算\theta ^{0}对损失函数的梯度,也就是计算神经网络里面的参数w_{1},w_{2},...,b_{2},b_{2},...对loss损失函数的导数,计算出后,更新参数\theta ^{1}=\theta ^{0}-\eta \bigtriangledown L(\theta ^{0});再计算\theta ^{1}对损失函数的梯度,再更新\theta ^{2},以此类推.......

        我们会定义一个loss损失函数,这个损失函数就是所有训练样本的预测值与真实值之间差值和,对损失函数进行梯度下降算法的公式如下:

         现在我们来看一下怎么对某一笔的样本计算梯度。

         对于上述的神经元,先考虑计算某一个神经元的梯度:

        通过前向过程计算可得z=x_{1}w_{1}+x_{2}w_{2}+b,对于\frac{\partial C}{\partial w}=\frac{\partial C}{\partial z}\frac{\partial z}{\partial w}。计算\frac{\partial z}{\partial w}是前向过程,计算\frac{\partial C}{\partial z}是反向过程。

        我们先来看下怎么计算\frac{\partial z}{\partial w}  。因为z=x_{1}w_{1}+x_{2}w_{2}+b\frac{\partial z}{\partial w_{1}}=x_{1}\frac{\partial z}{\partial w_{1}}=x_{2}。对于\frac{\partial z}{\partial w}就是看这个w前面接的是什么,那微分以后就是什么。w_{1}前面接的输入是x_{1},所以求导后就是x_{1}w_{2}前面接的输入是x_{2},所以求导后就是x_{2},就是这样的规律。

        假如给你如下图的神经网络,它里面有一大堆的参数,计算里面的\frac{\partial z}{\partial w},这件事非常容易。

        如果有人想问你: 这个\frac{\partial z}{\partial w}是多少,你看这个w=1前面接的输入是-1,你可以瞬间告诉他\frac{\partial z}{\partial w}=-1

        接下来,有人想问你:,对于这个w=-1,\frac{\partial z}{\partial w}是多少,你可以很快告诉他\frac{\partial z}{\partial w}=0.12

        知道了怎么计算\frac{\partial z}{\partial w},我们现在来看看怎么计算\frac{\partial C}{\partial z}。计算\frac{\partial C}{\partial z}你会觉得很困难,因为z通过激活函数后得到一个输出。

      假设激活函数时sigmoid函数,z通过sigmoid函数后得到a。我们知道\frac{\partial C}{\partial z}=\frac{\partial C}{\partial a}\frac{\partial a}{\partial z}a=\sigma (z)\frac{\partial a}{\partial z}就是sigmoid函数的偏微分。sigmoid函数如下图绿色线所示,它的微分如蓝色线所示:

因为a会影响后面的{z}'{z}'会影响C;a会影响后面{z}''{z}''会影响C。所以\frac{\partial C}{\partial a}=\frac{\partial {z}'}{\partial a}\frac{\partial C}{\partial {z}'}+\frac{\partial {z}''}{\partial a}\frac{\partial C}{\partial {z}"}。因为{z}'=aw_{3}+aw_{4}+...,所以我们能很快知道\frac{\partial {z}'}{\partial a}=w_{3}\frac{\partial {z}''}{\partial a}=w_{4},但是我们又很难计算\frac{\partial C}{\partial {z}'}\frac{\partial C}{\partial {z}''},因为神经网络后面可能会又其他的运算,在此,我们先假设知道\frac{\partial C}{\partial {z}'}\frac{\partial C}{\partial {z}''}这两项的值。 现在我们就可以计算\frac{\partial C}{\partial z}= {\sigma}'(z)\begin{bmatrix} w_{3}\frac{\partial C}{\partial {z}'}+w_{4}\frac{\partial C}{\partial {z}''} \end{bmatrix}的值。

        我们可以从另一个观点看待这个式子,如下图,其中{\sigma }'(z)是一个常数,因为z在计算前向过程的时候就被决定好了。

        回到上一个问题,我们要怎么算\frac{\partial C}{\partial {z}'}\frac{\partial C}{\partial {z}''}呐?

                第一个例子是我们假设橘色的这两个神经元是输出层,所以可以计算出\frac{\partial C}{\partial {z}'}=\frac{\partial C}{\partial y_{1}}\frac{\partial y_{1}}{\partial {z}'} ,\frac{\partial C}{\partial {z}''}=\frac{\partial C}{\partial y_{2}}\frac{\partial y_{2}}{\partial {z}''}

        假设橘色的神经元并不是整个神经网络的输出,它后面还有其他的层,那应该怎么算呢?

        如果我们知道\frac{\partial C}{\partial z_{a}}\frac{\partial C}{\partial z_{b}},我们就能计算出\frac{\partial C}{\partial {z}'},但我们现在无法计算出\frac{\partial C}{\partial z_{a}}\frac{\partial C}{\partial z_{b}},因为我们不知道后续的层是什么样的。我们可以再往下一层去看,如果绿色的神经元是输出层的话,计算\frac{\partial C}{\partial z_{a}}\frac{\partial C}{\partial z_{b}}就不成问题。

        实际上,我们是从输出层的\frac{\partial C}{\partial z}开始计算的:

        假设我们现在有6个神经元,现在我们要计算\frac{\partial C}{\partial z},如果先计算\frac{\partial C}{\partial z_{1}}\frac{\partial C}{\partial z_{2}},那就没有效率;如果先算\frac{\partial C}{\partial z_{5}}\frac{\partial C}{\partial z_{6}},就很有效率。

        算出\frac{\partial C}{\partial z_{5}}\frac{\partial C}{\partial z_{6}}后,就可以算出\frac{\partial C}{\partial z_{3}}\frac{\partial C}{\partial z_{4}},然后算出\frac{\partial C}{\partial z_{1}}\frac{\partial C}{\partial z_{2}}

        实际上,这个过程如下图所示:

相关文章:

  • 2、ubantu系统配置OpenSSH | 使用vscode或pycharm远程连接
  • 软件设计师考试《综合知识》CPU考点分析(2019-2023年)——求三连
  • 【QT 项目部署指南】使用 Inno Setup 打包 QT 程序为安装包(超详细图文教程)
  • 基于EFISH-SCB-RK3576/SAIL-RK3576的消防机器人控制器技术方案‌
  • Linux云计算训练营笔记day09(MySQL数据库)
  • 进度管理高分论文
  • 在 Hugo 博客中集成评论系统 Waline 与 浏览量统计
  • 基于“物理—事理—人理”的多源异构大数据融合探究
  • bfs搜索加标记连通区域id实现时间优化(空间换时间)
  • Go语言八股之Mysql事务
  • 扬州卓韵酒店用品:优质洗浴用品,提升酒店满意度与品牌形象
  • TCP(传输控制协议)建立连接的过程
  • Git/GitLab日常使用的命令指南来了!
  • 前端代码生成博客封面图片
  • 寻找两个正序数组的中位数 - 困难
  • 【BotSharp详细介绍——一步步实现MCP+LLM的聊天问答实例】
  • vscode c++编译onnxruntime cuda 出现的问题
  • 浏览器宝塔访问不了给的面板地址
  • 运维职业发展思维导图
  • 幼儿学前教育答辩词答辩技巧问题答辩自述稿
  • 浙江省台州市政协原副主席林虹被“双开”
  • 国家卫健委通报:吊销肖某医师执业证书,撤销董某莹四项证书
  • 证监会:2024年依法从严查办证券期货违法案件739件,作出处罚决定592件、同比增10%
  • 制造四十余年血腥冲突后,库尔德工人党为何自行解散?
  • 李强会见巴西总统卢拉
  • 独行侠以1.8%概率获得状元签,NBA原来真的有剧本?