当前位置: 首页 > news >正文

soft回归用内置函数

动手学深度学习v2-3.7-学习-笔记

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

分析:引入库,设置批量大小,加载Fashion MNIST数据集

  • torch:PyTorch 是一个流行的深度学习框架,提供了张量操作、自动求导等功能。

  • nn:PyTorch 的神经网络模块,包含各种神经网络层和模块。

  • d2ld2l 是李沐等人编写的《动手学深度学习》(Dive into Deep Learning)配套的 Python 包。它提供了许多方便的工具函数,包括数据加载、绘图等。

  • batch_size:批量大小(batch size)是指每次训练时输入模型的数据样本数量。这里设置为 256,意味着每次训练时会从数据集中抽取 256 个样本组成一个批量。

     

    # PyTorch不会隐式地调整输入的形状。因此,
    # 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
    net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
    
    def init_weights(m):
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, std=0.01)
    
    net.apply(init_weights);

    分析:定义一个简单的数据集,初始化权重

  • nn.Flatten():这是一个展平层,用于将输入数据从多维张量展平为一维张量。例如,对于 Fashion MNIST 数据集,输入图像的形状是 (28, 28),经过 nn.Flatten() 后,会变成形状为 (784,) 的一维张量。

  • nn.Linear(784, 10):这是一个全连接层(线性层)。它将输入的 784 个特征(展平后的图像数据)映射到 10 个输出类别(Fashion MNIST 数据集有 10 个类别)。

  • if type(m) == nn.Linear:判断当前层是否是全连接层(nn.Linear)。如果是,则对该层的权重进行初始化

  • nn.init.normal_(m.weight, std=0.01):使用正态分布初始化权重。std=0.01 表示权重的初始化标准差为 0.01。正态分布初始化是一种常见的权重初始化方法,有助于避免梯度消失或爆炸问题。
     

    loss = nn.CrossEntropyLoss(reduction='none')

    分析:nn.CrossEntropyLoss 是一个常用的损失函数,用于多分类问题。它结合了 nn.LogSoftmaxnn.NLLLoss(负对数似然损失)的功能,能够计算模型输出的概率分布与真实标签之间的交叉熵损失。

    trainer = torch.optim.SGD(net.parameters(), lr=0.1)

    分析:torch.optim.SGD 是 PyTorch 提供的一个优化器类,用于实现随机梯度下降算法。SGD 是一种常用的优化算法,通过迭代更新模型参数,以最小化损失函数

  • net:表示定义好的神经网络模型。

  • net.parameters():返回模型中所有可训练的参数(权重和偏置)。这些参数是 PyTorch 的 torch.nn.Parameter 对象,它们会被优化器用来更新模型。

import torch
from torch import nn
from d2l import torch as d2l

# 定义网络
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

# 初始化权重
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights);

# 加载数据集
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 定义损失函数
loss = nn.CrossEntropyLoss(reduction='none')  # 或者使用 reduction='mean'

# 定义优化器
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

# 训练模型
num_epochs = 10

for epoch in range(num_epochs):
    net.train()  # 将模型设置为训练模式
    total_loss = 0.0
    for X, y in train_iter:
        y_hat = net(X)  # 前向传播
        l = loss(y_hat, y).sum()  # 对损失值求和,使其成为一个标量
        l.backward()  # 反向传播
        trainer.step()  # 更新参数
        trainer.zero_grad()  # 清空梯度
        total_loss += l.item()
    print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_iter.dataset)}')

相关文章:

  • 软考-高项,知识点一览八 整合管理
  • CUDA Lazy Loading:优化GPU程序初始化与内存使用的利器
  • 【蓝桥杯】12111暖气冰场(多源BFS 或者 二分)
  • ‘闭包‘, ‘装饰器‘及其应用场景
  • 西门子200smart之modbus_TCP(做从站与第三方设备)通讯
  • 从头开始学C语言第二十九天——指针数组
  • JavaScript-日期对象与节点操作详解
  • Apache Flink技术原理深入解析:任务执行流程全景图
  • Rocky9.2 编译安装Intel WIFI系列无线网卡驱动
  • 华为终端将全面进入鸿蒙时代
  • LLM - CentOS上离线部署Ollama+Qwen2.5-coder模型完全指南
  • Mimalloc论文解析:小内存管理的极致追求与实践启示
  • 虚拟机访问主机的plc仿真
  • C++学习之网盘项目单例模式
  • Swift 经典链表面试题:如何在不访问头节点的情况下删除指定节点?
  • FPGA 以太网通信(四)网络视频传输系统
  • c#难点整理2
  • windows下利用Ollama + AnythingLLM + DeepSeek 本地部署私有智能问答知识库
  • CVPR 2025 | 文本和图像引导的高保真3D数字人高效生成GaussianIP
  • 美国国家数据浮标中心(NDBC)
  • 阿里上季度营收增7%:淘天营收创新高,AI产品营收连续七个季度三位数增长
  • 押井守在30年前创造的虚拟世界何以比当下更超前?
  • 欠债七十万后,一个乡镇驿站站长的中年心事
  • 前四个月社会融资规模增量累计为16.34万亿元,比上年同期多3.61万亿元
  • 外交部:中方对美芬太尼反制仍然有效
  • 孙卫东会见巴基斯坦驻华大使:支持巴印两国实现全面持久停火