当前位置: 首页 > news >正文

Class10简洁实现

Class10简洁实现

import torch
from torch import nn
from d2l import torch as d2l
# 输入为28*28,输出为10类,第1、2隐藏层256神经元
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# 第1个隐藏层丢弃率为0.2,第2个为0.5
dropout1, dropout2 = 0.2, 0.5
# nn.Flatten():把28*28展平为784
net = nn.Sequential(nn.Flatten(),# 输入层->第1隐藏层nn.Linear(784, 256),# ReLU激活nn.ReLU(),# 在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),# 第1隐藏层->第2隐藏层nn.Linear(256, 256),# ReLU激活nn.ReLU(),# 在第二个全连接层之后添加一个dropout层nn.Dropout(dropout2),# 第2隐藏层->输出10类nn.Linear(256, 10))# 初始化权重函数
def init_weights(m):# 判断如果为线性if type(m) == nn.Linear:# 正态分布初始化,均值为0,标准差为0.01nn.init.normal_(m.weight, std=0.01)
# 若为nn.Linear,则调用init_weight函数进行初始化
net.apply(init_weights);
# 设置训练轮数,学习率,批次大小
num_epochs,lr,batch_size = 10,0.5,256
# 定义损失函数,并保留每个样本损失
loss = nn.CrossEntropyLoss(reduction='none')
# 加载训练集和测试集
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
# 设置SGD随机梯度下降优化器
trainer = torch.optim.SGD(net.parameters(),lr=lr)
# 调用训练主函数
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)
http://www.dtcms.com/a/292221.html

相关文章:

  • 图解Spring的循环依赖
  • 2025茶吧机语音控制集成方案
  • 深入解析Hadoop中的推测执行:原理、算法与策略
  • 【华为机试】684. 冗余连接
  • Python编程进阶知识之第三课处理数据(numpy)
  • LSTM+Transformer炸裂创新 精准度至95.65%
  • 【C++】复习重点-汇总2-面向对象(三大特性、类/对象、构造函数、继承与派生、多态、抽象类、this/对象指针、友元、运算符重载、static、类/结构体)
  • vscode gdb调试c语言过程
  • IDEA-自动格式化代码
  • IDEA全局Maven配置
  • 【IDEA】如何在IDEA中通过git创建项目?
  • 【C++】nlohmann/json
  • 哔哩哔哩视觉算法面试30问全景精解
  • Kafka单条消息长度限制详解及Java实战指南
  • 新品如何通过广告投放精准获取流量实现快速增长
  • 【RAG优化】PDF复杂表格解析问题分析
  • 北宋政治模拟(deepseek)
  • 力扣面试150题--寻找峰值
  • 如何为每个参数案例自动执行当前数据集
  • 双指针算法介绍及使用(上)
  • rk3568平台记录一次推流卡顿分析过程
  • Next.js项目目录结构详解:从入门到精通的最佳实践指南
  • 一文详解策略梯度算法(REINFORCE)—强化学习(8)
  • 新手向:基于Python的剪贴板历史增强工具
  • Jiasou TideFlow AIGC SEO Agent:全自动外链构建技术重构智能营销新标准
  • 数据库 × 缓存双写策略深度剖析:一致性如何保障?
  • Apache Ignite缓存基本操作
  • Redis原理之缓存
  • uni-calendar自定义签到打卡颜色
  • Java-79 深入浅出 RPC Dubbo Dubbo 动态路由架构详解:从规则设计到上线系统集成