当前位置: 首页 > news >正文

pytorch学习笔记-Loss的使用、在神经网络中加入Loss、优化器(optimizer)的使用

博主最近真要累鼠了…
anyway上号更新一点,预计下周就能把该系列学完了= =

Loss的使用

具体关注一下官网上对于形状的说明,如果报错就看看是不是形状不符合要求,其他没啥

import torch
from torch.nn import L1Loss,MSELoss, CrossEntropyLoss
from torch import nn#L1Loss
inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)l1_loss = L1Loss()
res = l1_loss(inputs, targets)print("L1Loss:",res)#MSE
mse_loss = MSELoss()
res = mse_loss(inputs, targets)print("mseLoss:",res)#crossentropy
x = torch.tensor([0.1,0.2,0.3])#每一类的概率
y = torch.tensor([1])#目标类别编号# Input: Shape(N,C), N:bt_size C:class
x = torch.reshape(x,(1,3))
# print(x)cross_loss = CrossEntropyLoss()
res = cross_loss(x,y)
print("crossentropyLoss:",res)# L1Loss: tensor(0.6667)
# mseLoss: tensor(1.3333)
# crossentropyLoss: tensor(1.1019)

在神经网络中加入Loss

了解Loss的本质就是计算真值和目标值之间的差距后,怎么在神经网络中引入loss也蛮自然的:

import torch
import torch.nn as nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential, CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
from torch.utils.data import DataLoaderdata_transforms = transforms.Compose([transforms.ToTensor()
])test_data = datasets.CIFAR10(root="./dataset",train=False,transform=data_transforms)dataloader = DataLoader(test_data,batch_size=64)class myModule(nn.Module):def __init__(self):super().__init__()self.model1 = Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self, x):x = self.model1(x)return xmy_module = myModule()loss = CrossEntropyLoss()for data in dataloader:imgs, targets = dataoutputs = my_module(imgs)res_loss = loss(outputs,targets)print(res_loss)

优化器的使用

引入Loss的目的是为了更好的进行参数更新,因此需要引入优化器
事实上引入这一步后也就基本知道了模型如何进行训练了
一般在一次学习中就进行了多次更新,需要进行多次学习,注意的点就是每次梯度计算优化前需要先将上一轮计算得到的梯度清零,因为上一批次的对本次的结果意义不大,剩下的就是用法:

import torch
import torch.nn as nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential, CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch import optimdata_transforms = transforms.Compose([transforms.ToTensor()
])test_data = datasets.CIFAR10(root="./dataset",train=False,transform=data_transforms)dataloader = DataLoader(test_data,batch_size=64)class myModule(nn.Module):def __init__(self):super().__init__()self.model1 = Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self, x):x = self.model1(x)return xmy_module = myModule()loss = CrossEntropyLoss()#设置优化器
optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)for epoch in range(5):running_loss = 0.0 #计算每一个epoch的lossfor data in dataloader:imgs, targets = dataoutputs = my_module(imgs)res_loss = loss(outputs,targets)# print(res_loss)running_loss+=res_lossoptimizer.zero_grad()#每一次要重新清零上一轮的梯度res_loss.backward()optimizer.step()#进行优化lossprint(running_loss)# tensor(360.6527, grad_fn=<AddBackward0>)
# tensor(355.2374, grad_fn=<AddBackward0>)
# tensor(336.8181, grad_fn=<AddBackward0>)
# tensor(320.9179, grad_fn=<AddBackward0>)
# tensor(312.2012, grad_fn=<AddBackward0>)
http://www.dtcms.com/a/330464.html

相关文章:

  • 基于SpringBoot+Vue的轻手工创意分享平台(WebSocket即时通讯、协同过滤算法、Echarts图形化分析)
  • 依托AR远程协助,沟通协作,高效流畅
  • 七、SpringBoot工程日志设置
  • [前端算法]动态规划
  • 【保姆级教程】CentOS 7 部署 FastDFS 全流程(避坑指南)
  • 【Docker】安装kafka案例
  • 深入解析 Spring IOC 容器在 Web 环境中的启动机制
  • ActiveReports 19.1 Crack
  • 新手向:Python条件语句(if-elif-else)使用指南
  • 初识HTML
  • 云原生俱乐部-k8s知识点归纳(1)
  • AI 编程实践:用 Trae 快速开发 HTML 贪吃蛇游戏
  • 游戏行业DevOps实践:维塔士集团基于Atlassian工具与龙智服务构建全球化游戏开发协作平台
  • LLM 中 语音编码与文本embeding的本质区别
  • 网络流初步
  • 版本更新!FairGuard-Mac加固工具已上线!
  • 【Unity3D实例-功能-移动】角色行走和奔跑的相互切换
  • Unity2022 + URP + Highlight plus V21配置和使用
  • Linux下使用Samba 客户端访问 Samba 服务器的配置(Ubuntu Debian)
  • 一颗TTS语音芯片给产品增加智能语音播报能力
  • 【无标题】卷轴屏手机前瞻:三星/京东方柔性屏耐久性测试进展
  • python自学笔记8 二维和三维可视化
  • 【深度学习】深度学习基础概念与初识PyTorch
  • 【C#补全计划】泛型约束
  • 从0开始的中后台管理系统-7(订单列表功能实现,调用百度地图打点以及轨迹图动态展示)
  • 数据结构--------堆
  • 18.14 全量微调实战手册:7大核心配置提升工业级模型训练效率
  • 阿里云RDS SQL Server实例之间数据库迁移方案
  • 通信算法之313:FPGA中实现滑动相关消耗DSP资源及7045/7035的乘法器资源
  • 工具栏扩展应用接入说明