当前位置: 首页 > news >正文

吴恩达机器学习作业六:反向传播

数据集在作业一

反向传播

反向传播的本质是利用链式法则(Chain Rule)高效计算梯度。具体来说:

  1. 前向传播(Forward Pass):先将输入数据送入网络,逐层计算神经元的输出,最终得到模型的预测结果,并根据预测值与真实标签计算损失函数(衡量预测误差的指标)。

  2. 反向传播(Backward Pass):从损失函数开始,反向逐层计算损失对每个参数的梯度(即损失随参数变化的速率)。梯度的方向指示了参数需要调整的方向(增加或减少),梯度的大小指示了调整的幅度。

  3. 参数更新:根据反向传播得到的梯度,使用优化器(如梯度下降法)更新网络中的权重和偏置,以降低损失。

大家如果借助线性回归的梯度下降法来理解就是参数w在正向传播算出损失(这里只有一层,所以可以不用计算损失,只算梯度,因为这里的损失是和前面的参数更新有关),然后反向传播算出更新的w,再正向传播,反向传播,一直反复,达到次数。

具体算法

这里我们可以尝试数学推导来理解一下为什么说前一层损失可以用这一层的更新

可以把图丢给ai来分析理解。

代码

读取数据

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy.optimize import minimize# 读取数据
data=sio.loadmat("ex4data1.mat")
X=data['X']
y=data['y']
# print(X.shape,y.shape)(5000, 400) (5000, 1)
theta=sio.loadmat("ex4weights.mat")
theta1=theta['Theta1']
theta2=theta['Theta2']

数据预处理(one-hot)

# 数据预处理,对y进行one-hot编码,即将y的值转换成0-9的矩阵,每一列只有一个1代表所对应正确编码,其余为0
def one_hot(y,n):return np.array((y-1)==np.arange(n)).astype(int)
y=one_hot(y,10)
# print(y,y.shape)

由于我们要进行反向传播,那么我们必须有损失,那么我们通过将单个y值替换成向量有利于我们后面的计算。

序列化和解序列化参数

def serialize_params(theta1,theta2):return np.concatenate((theta1.flatten(),theta2.flatten()))
# print(serialize_params(theta1,theta2).shape)
def deserialize_params(serialized_params):theta1=serialized_params[:25*401].reshape(25,401)theta2=serialized_params[25*401:].reshape(10,26)return theta1,theta2theta_serialized=serialize_params(theta1,theta2)
theta1,theta2=deserialize_params(theta_serialized)
# print(theta1.shape,theta2.shape)

这里的序列化和解序列化就是把高维的参数变成低维并连接在一起,这样可以满足一些函数的输入要求,而后者是为了把参数变成正常状态。

激活函数

def sigmoid(z):return 1/(1+np.exp(-z))

前向传播

def forward_propagate(X,theta):m=X.shape[0]theta1,theta2=deserialize_params(theta)a1=np.insert(X,0,1,axis=1)z2=np.dot(a1,theta1.T)a2=sigmoid(z2)a2=np.insert(a2,0,1,axis=1)z3=np.dot(a2,theta2.T)a3=sigmoid(z3)return a1,z2,a2,z3,a3                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              

损失函数(分为带正则项与不带正则项)

# 不带正则项
def cost_function(theta,X,y):m=X.shape[0]a1,z2,a2,z3,a3=forward_propagate(X,theta)J=-1/m*np.sum(y*np.log(a3)+(1-y)*np.log(1-a3))return J# 带正则项
# 不改变theta1的第一列
def cost_function_reg(theta,X,y,lamda):m=X.shape[0]theta1,theta2=deserialize_params(theta)a1,z2,a2,z3,a3=forward_propagate(X,theta)# 防止 log(0) 或 log(1)a3 = np.clip(a3, 1e-10, 1 - 1e-10)J=cost_function(theta,X,y)+(lamda/(2*m))*(np.sum(theta1[:,1:]**2)+np.sum(theta2[:,1:]**2))return J# print(cost_function_reg(theta1,theta2,X,y,1))

之前已经讲过正则化了,这里就不赘述了。

反向传播(梯度)

# 激活函数梯度
def sigmoid_gradient(z):z = np.clip(z, -500, 500)  # 防止数值溢出return sigmoid(z)*(1-sigmoid(z))
# 无正则化梯度
def cost_function_grad(theta,X,y):a1,z2,a2,z3,a3=forward_propagate(X,theta)d3=a3-ytheta1,theta2=deserialize_params(theta)d2=np.dot(d3,theta2[:,1:])*sigmoid_gradient(z2)D2=np.dot(d3.T,a2)/len(X)D1=np.dot(d2.T,a1)/len(X)return serialize_params(D1,D2)
# 有正则化梯度
def cost_function_grad_reg(theta,X,y,lamda):theta1,theta2=deserialize_params(theta)D=cost_function_grad(theta,X,y)D1,D2=deserialize_params(D)reg1=theta1[:,1:]*lamda/len(X)reg2=theta2[:,1:]*lamda/len(X)D1[:,1:]=D1[:,1:]+reg1D2[:,1:]=D2[:,1:]+reg2return serialize_params(D1,D2)

训练参数

def nn_train(X,y):init_theta = np.random.uniform(-0.1, 0.1, 10285)res=minimize(fun=cost_function_reg,x0=init_theta,args=(X,y,10),method='TNC',jac=cost_function_grad_reg,options={'maxfun':300})return res.x

测试

theta=nn_train(X,y)
_,_,_,_,a3=forward_propagate(X,theta)
y_pred=np.argmax(a3,axis=1)+1
y_original=np.argmax(y,axis=1)+1
# print(y_pred.shape,y_original.shape)acc=np.mean(y_pred==y_original)
print(acc)#0.9418

总结

读入数据——数据预处理——序列化和反序列化构建——前向传播和激活函数——损失函数——梯度(反向传播)——调用优化器训练参数——测试

http://www.dtcms.com/a/357842.html

相关文章:

  • 三一重工AI预测性维护破局:非计划停机减少60%,技师转型与数字孪生技术搅动制造业
  • 单点登录(SSO)
  • 2.ImGui-搭建一个外部绘制的窗口环境(使用ImGui绘制一个空白窗口)
  • 从零开始学Shell编程:从基础到实战案例
  • 再来,一次内存溢出
  • 【人工智能99问】参数调整技术(31/99)
  • 【Spring Cloud Alibaba】前置知识(一)
  • RAG教程6:cohere rerank重排
  • 物理AI:连接数字智能与物理世界的下一代人工智能范式
  • 函数的逆与原象
  • 【完整源码+数据集+部署教程】传送带建筑材料识别系统源码和数据集:改进yolo11-AFPN-P345
  • vue3 表单项不对齐的解决方案
  • gpu与cpu各厂商的优劣
  • 【系列01】端侧AI:构建与部署高效的本地化AI模型
  • 【编号513】2025年全国地铁矢量数据
  • PCIe 6.0的速度奥秘:数学视角下的编码革命与信号完整性突破
  • 永磁同步电机无速度算法--传统脉振方波注入法(2)
  • Linux系统编程—进程概念
  • 疯狂星期四文案网第54天运营日记
  • 动态规划--Day03--打家劫舍--198. 打家劫舍,213. 打家劫舍 II,2320. 统计放置房子的方式数
  • Android系统框架知识系列(十九):Android安全架构深度剖析 - 从内核到应用的全栈防护
  • 深入解析Paimon MergeFunction
  • 图解帕累托前沿(pareto frontier)
  • 嵌入式Linux驱动开发:i.MX6ULL按键中断驱动(非阻塞IO)
  • stm32单片机使用tb6612驱动编码器电机并测速的驱动代码详解—详细参考开发手册(可移植+开发手册)
  • 文本嵌入模型的本质
  • 《ArkUI 记账本开发:状态管理与数据持久化实现》
  • 分布式锁在支付关闭订单场景下的思考
  • Product Hunt 每日热榜 | 2025-08-29
  • 逻辑漏洞 跨站脚本漏洞(xss)