当前位置: 首页 > news >正文

深度学习(斋藤)学习笔记(五)-反向传播2

上一篇关于反向传播的代码仅支持单变量的梯度计算,下面我们将扩展代码使其支持多个输入/输出。增加了对多输入函数(如 Add),以实现的计算。

1.关于前向传播可变长参数的改进-修改Function类

修改方法:

Function用于对输入输出做规定,帮助实现右图的效果(接受inputs 返回outputs):

2.关于反向传播可变长参数的改进

修改函数类的反向传播

修改Variable类的反向传播

改进前:

获取y.creator,获取输入creator.inputs,根据y.grads计算x.grads:creator.backward(y.grads)

2.3两步的解包和打包操作:

最后修改square方法:

完整代码

import numpy as np

class Variable:
    def __init__(self, data):
        if data is not None:
            if not isinstance(data, np.ndarray):
                raise TypeError('{} is not supported'.format(type(data)))
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func

    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            gys = [output.grad for output in f.outputs]  # 获取所有输出的梯度
            gxs = f.backward(*gys)                       # 调用 backward 方法
            if not isinstance(gxs, tuple):               # 确保 gxs 是元组
                gxs = (gxs,)
            for x, gx in zip(f.inputs, gxs):            # 为每个输入分配梯度
                x.grad = gx
                if x.creator is not None:
                    funcs.append(x.creator)

class Function:
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]               # 提取输入数据
        ys = self.forward(*xs)                      # 前向传播(解包)
        if not isinstance(ys, tuple):               # 确保 ys 是元组
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]  # 创建输出变量

        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs                        # 保存输入
        self.outputs = outputs                      # 保存输出
        return outputs if len(outputs) > 1 else outputs[0]  # 根据输出数量返回

    def forward(self, *xs):
        raise NotImplementedError()

    def backward(self, *gys):
        raise NotImplementedError()

# 实现具体的函数类
class Square(Function):
    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        x = self.inputs[0].data                     # 从 inputs 中获取数据
        gx = 2 * x * gy
        return gx

class Add(Function):
    def forward(self, x0, x1):
        y = x0 + x1
        return y

    def backward(self, gy):
        return gy, gy                               # 对两个输入返回相同的梯度

# 定义便捷函数
def square(x):
    return Square()(x)

def add(x0, x1):
    return Add()(x0, x1)

# 定义 as_array 函数
def as_array(x):
    if np.isscalar(x):
        return np.array(x)
    return x

# 测试代码
x = Variable(np.array(2.0))
y = Variable(np.array(3.0))
z = add(square(x), square(y))
z.backward()
print(z.data)    # 输出结果: 13.0 (2^2 + 3^2 = 4 + 9 = 13)
print(x.grad)    # 输出梯度: 4.0 (dz/dx = 2 * 2 = 4)
print(y.grad)    # 输出梯度: 6.0 (dz/dy = 2 * 3 = 6)

运行结果:

相关文章:

  • 平面机械臂运动学分析
  • 如何高效地找工作?
  • tomcat单机多实例部署
  • 2025年渗透测试面试题总结-腾某讯-技术安全实习生(题目+回答)
  • 使用XShell连接RHEL9并配置yum阿里源
  • 使用express创建服务器保存数据到mysql
  • linux安装nginx
  • 【前端基础】Day 10 CSS3-2D3D
  • Visual Studio Code for SAP (SAP PRESS) (Leon Hassan)
  • Vue中常见动画执行详解
  • 数据库高级面试题
  • 第六课:数据库集成:MongoDB与Mongoose技术应用
  • javaweb:Maven、SpringBoot快速入门、HTTP协议
  • OpenCV视频解码性能优化十连击(实测帧率提升300%)
  • Java数据结构:解构排序算法的艺术与科学(一)
  • 光通信产业链分析
  • 第五课:Express框架与RESTful API设计:技术实践与探索
  • 动物摄像头监测识别AI技术结合了摄像头监测与人工智能识别(新产品)
  • 机动车授权签字人考试题库及答案
  • 青少年编程与数学 02-010 C++程序设计基础 30课题、操作符重载
  • 南京建站推广公司/seo培训赚钱
  • perl网站建设/搜索引擎优化是什么
  • 网站建设服务费怎么做会计分录/聊城优化seo
  • 加工厂网站建设/信息流广告的特点
  • 寿光专业做网站的公司/西安优化排名推广
  • 50强网站建设公司/免费seo工具汇总