当前位置: 首页 > news >正文

Class5多层感知机的从零开始实现

Class5多层感知机的从零开始实现

import torch
from torch import nn
from d2l import torch as d2l
# 设置批量大小为256
batch_size = 256
# 初始化训练集和测试集迭代器,每次训练一个批量
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
# 构建一个单隐藏层的前馈神经网络(MLP)
# num_inoputs:输入维度(28*28展平)
# num_outputs:输出维度(10分类)
# num_hiddens:隐藏层神经元数量
num_inputs,num_outputs,num_hiddens = 784,10,256# 第1个全连接层(输入->隐藏层)
# 创建大小为[784,256]的权重矩阵
# nn.Parameter:参与反向传播和优化器更新
# torch.randn:标准正态分布,均值为0,标准差为1
# *0.01:控制初始值大小
W1 = nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens,requires_grad=True))# 第2个全连接层(隐藏层->输出层)
# randn适合初始化权重,zeros适合初始化偏置
W2 = nn.Parameter(torch.randn(num_hiddens,num_outputs,requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs,requires_grad=True))# 参数统一存放到列表中
params = [W1,b1,W2,b2]
# 手动定义relu函数
def relu(X):# 创建一个形状相同的全0张量a = torch.zeros_like(X)# 逐元素返回最大值return torch.max(X,a)
# 定义前馈神经网络(MLP)
def net(X):# X设置为[batch_size,784]的二维矩阵X = X.reshape((-1,num_inputs))# 隐藏层线性变换+ReLU激活函数[batch_size,256]H = relu(X @ W1 + b1)# 返回输出层[batch_size,10]return (H @ W2 + b2)# 添加伪装接口,使其兼容 d2l.train_ch3
net.train = lambda: None
net.eval = lambda: None
# 定义前馈神经网络(MLP)
def net(X):# X设置为[batch_size,784]的二维矩阵X = X.reshape((-1,num_inputs))# 隐藏层线性变换+ReLU激活函数[batch_size,256]H = relu(X @ W1 + b1)# 返回输出层[batch_size,10]return (H @ W2 + b2)# 添加伪装接口,使其兼容 d2l.train_ch3
net.train = lambda: None
net.eval = lambda: None
# 定义交叉熵损失函数
loss = nn.CrossEntropyLoss(reduction='none')
# 设置训练轮数和学习率
num_epochs,lr = 10,0.1
# 定义随机梯度下降优化器
updater = torch.optim.SGD(params,lr=lr)
# 调用d2l的训练函数
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,updater)
http://www.dtcms.com/a/266928.html

相关文章:

  • Linux awk 命令
  • 浅谈 webshell 构造之如何获取恶意函数
  • chrome插件合集
  • 4 位量化 + FP8 混合精度:ERNIE-4.5-0.3B-Paddle本地部署,重新定义端侧推理效率
  • 【LUT技术专题】CLUT代码讲解
  • 写一个Ununtu C++ 程序,调用ffmpeg API, 来判断一个数字电影的视频文件mxf 是不是Jpeg2000?
  • MSPM0G3507学习笔记(一) 重置版:适配逐飞库的ti板环境配置
  • 服装零售企业跨区域运营难题破解方案
  • 深度学习笔记29-RNN实现阿尔茨海默病诊断(Pytorch)
  • 25年Java后端社招技术场景题!
  • MyDockFinder 绿色便携版 | 一键仿Mac桌面,非常简单
  • 应用分发平台的重要性:构建、扩展和管理您的移动应用
  • VR 火化设备仿真系统具备哪些优势?​
  • MySQL 八股文【持续更新ing】
  • 机器学习路径规划中的 net 和 netlist 分别是什么?
  • 《推客分销系统架构设计:从零搭建高并发社交裂变引擎》
  • linux---------------进程信号(下)
  • 将制作的网站部署在公网
  • 电机转速控制系统算法分析与设计
  • 同步(Synchronization)和互斥(Mutual Exclusion)关系
  • 基于Apache MINA SSHD配置及应用
  • Python爬虫 模拟登录状态 requests版
  • 如何查看自己电脑的CUDA版本?
  • D3 面试题100道之(21-40)
  • 通过MaaS平台免费使用大模型API
  • Java 入门
  • 鸿蒙中判断两个对象是否相等
  • react案例动态表单(受控组件)
  • React 渲染深度解密:从 JSX 到 DOM 的初次与重渲染全流程
  • 深入解析XFS文件系统:原理、工具与数据恢复实战