当前位置: 首页 > news >正文

深度学习(斋藤康毅)学习笔记(六)反向传播3

上一篇文章介绍了反向传播的自动化,但也存在一些问题,本章用于说明这些问题,并修改原有框架,使其支持复杂计算图的运行:

问题一:重复使用一个变量,梯度不会累计

也就是说,反向传播时gx=[1,1],f.inputs=[x,x],在循环中第二次赋值把第一次赋的值给覆盖了。

如何解决

问题二:重复计算导数,梯度不会自动清零

第一次计算,x的导数是2,第二次计算x的导数是3。但是两次的导数值累加了。

设置x.cleargrad() 手动清空梯度

问题三:不支持复杂计算图

循环funs中的函数
0(初始化)[D]
1[B,C]
2[B,A]
3[B]
4[A]
5[ ]

如何解决

代码验证

完整代码

import numpy as np

class Variable:
    def __init__(self, data):
        if data is not None:
            if not isinstance(data, np.ndarray):
                raise TypeError('{} is not supported'.format(type(data)))
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1

    def cleargrad(self):
        self.grad = None

    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        funcs = []
        seen_set = set()

        def addfunc(f):
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)#funcs按照从小到大的辈分排序

        addfunc(self.creator)

        while funcs:
            f = funcs.pop()
            gys = [output.grad for output in f.outputs]  # 获取所有输出的梯度
            gxs = f.backward(*gys)                       # 调用 backward 方法
            if not isinstance(gxs, tuple):               # 确保 gxs 是元组
                gxs = (gxs,)
            for x, gx in zip(f.inputs, gxs):            # 为每个输入分配梯度
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx
                if x.creator is not None:
                    addfunc(x.creator)

class Function:
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]               # 提取输入数据
        ys = self.forward(*xs)                      # 前向传播(解包)
        if not isinstance(ys, tuple):               # 确保 ys 是元组
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]  # 创建输出变量

        self.generation = max([x.generation for x in inputs])
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs                        # 保存输入
        self.outputs = outputs                      # 保存输出
        return outputs if len(outputs) > 1 else outputs[0]  # 根据输出数量返回

    def forward(self, *xs):
        raise NotImplementedError()

    def backward(self, *gys):
        raise NotImplementedError()

# 实现具体的函数类
class Square(Function):
    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        x = self.inputs[0].data                     # 从 inputs 中获取数据
        gx = 2 * x * gy
        return gx

class Add(Function):
    def forward(self, x0, x1):
        y = x0 + x1
        return y

    def backward(self, gy):
        return gy, gy                               # 对两个输入返回相同的梯度

# 定义便捷函数
def square(x):
    return Square()(x)

def add(x0, x1):
    return Add()(x0, x1)

# 定义 as_array 函数
def as_array(x):
    if np.isscalar(x):
        return np.array(x)
    return x

x = Variable(np.array(2.0))
a = square(x)
y = add(square(a) , square(a))
y.backward()
print(y.data)
print(x.grad)
'''
# 测试代码
x = Variable(np.array(2.0))
y = Variable(np.array(3.0))
z = add(square(x), square(y))
z.backward()
print(z.data)    # 输出结果: 13.0 (2^2 + 3^2 = 4 + 9 = 13)
print(x.grad)    # 输出梯度: 4.0 (dz/dx = 2 * 2 = 4)
print(y.grad)    # 输出梯度: 6.0 (dz/dy = 2 * 3 = 6)
'''

运行结果:

下一章将会深入实践深度学习中的内存管理部分,掌握代码技术

相关文章:

  • 面试中常问的mysql数据库指令【杭州多测师_王sir】
  • 盛铂科技 FlexDDS - NG波形发生器(直接数字信号合成器(DDS)):量子光学研究的得力助手
  • HTML学习笔记(全)
  • 第三章:go 依赖管理 go get / go get tidy
  • Windows应用访问 WSL中服务的5 种选择方案
  • 第一:goland安装
  • 嵌入式开发之串行数据处理
  • 计算机毕业设计SpringBoot+Vue.js疗养院管理系统(源码+文档+PPT+讲解)
  • AI如何重塑运维体系
  • fastapi房产销售系统
  • Elastic如何获取当前系统时间
  • Vue项目通过内嵌iframe访问另一个vue页面,获取token适配后端鉴权(以内嵌若依项目举例)
  • uniapp 微信小程序 升级 uniad插件版本号
  • 量子状态优化:探索量子计算的新维度
  • Grafana
  • Redis maven项目 jedis 客户端操作(一)
  • 《Python实战进阶》No13: NumPy 数组操作与性能优化
  • 点云软件VeloView开发环境搭建与编译
  • ubuntu22.04机器人开发环境配置
  • 使用Wireshark截取并解密摄像头画面
  • 中国建设传媒网官网/徐州seo顾问
  • 谷歌优化seo/网站优化公司认准乐云seo
  • 有了网站 怎么做排名优化/什么软件可以推广自己的产品
  • 丹江口网站开发/济南计算机培训机构哪个最好
  • 中小网站公司做的推广怎么样/关键词你们都搜什么
  • 企业如何在自己的网站上做宣传/江北seo