当前位置: 首页 > news >正文

深度学习3.7 softmax回归的简洁实现

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.7.1 初始化模型参数

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);

3.7.2 重新审视Softmax的实现

loss = nn.CrossEntropyLoss(reduction='none')

3.7.3 优化算法

# 在这里,我们(使用学习率为0.1的小批量随机梯度下降作为优化算法)
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

3.7.4 训练

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

在这里插入图片描述

3.7.5 预测

batch_size = 256 #迭代器批量
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)def predict_ch3(net, test_iter, n=6):  """Predict labels (defined in Chapter 3)."""for X, y in test_iter:  # 获取第一批测试数据breaktrues = d2l.get_fashion_mnist_labels(y)  # 真实标签转文本preds = d2l.get_fashion_mnist_labels(d2l.argmax(net(X), axis=1))  # 预测标签转文本titles = [true +'\n' + pred for true, pred in zip(trues, preds)]  # 组合标签d2l.show_images(d2l.reshape(X[0:n], (n, 28, 28)), 1, n, titles=titles[0:n])  # 可视化predict_ch3(net, test_iter)

在这里插入图片描述

相关文章:

  • 基于大模型的食管平滑肌瘤全周期预测与诊疗方案研究
  • Kaamel白皮书:Model Context Protocol (MCP) 隐私安全最佳实践
  • 沁恒CHV203中断嵌套导致修改线程栈-韦东山
  • 什么是IT人力外包?IT人力外包服务流程分为哪些步骤?
  • 序论文42 | patch+MLP用于长序列预测
  • Python基础语法:标识符,运算符,数据输入input(),数据输出print(),转义字符,续行符
  • CompletableFuture到底怎么用?
  • 飞算 JavaAI 的 “需求变更” 解决方案:让开发更灵活!
  • 如何解决PyQt从主窗口打开新窗口时出现闪退的问题
  • ai人才需要掌握什么
  • linux 桌面环境
  • JCE cannot authenticate the provider BC
  • 三国杀专业分析面板,立志成为桌游界的stockfish
  • Git多人协作与企业级开发模型
  • AXOP34032: 40V/40µA 轨到轨输入输出双通道运算放大器
  • 如何在windows10上英伟达gtx1060上部署通义千问-7B-Chat
  • 嵌入式:Linux系统应用程序(APP)启动流程概述
  • rk3588 驱动开发(三)第五章 新字符设备驱动实验
  • 算法设计与分析(基础)
  • 抽象类相关
  • 碧桂园服务:拟向杨惠妍全资持有的公司提供10亿元贷款,借款将转借给碧桂园用作保交楼
  • 陕西礼泉一村民被冒名贷款40余万,法院发现涉嫌经济犯罪驳回起诉
  • 被炒热的“高潮针”:超适应症使用,安全性和有效性存疑
  • 黔西游船倾覆事故84名落水人员已全部找到,10人不幸遇难
  • 特朗普:对所有在国外制作进入美国的电影征收100%关税
  • 这 3 种食物,不要放进微波炉!第 1 个就大意了