当前位置: 首页 > news >正文

pyTorch-迁移学习-学习率衰减-四种天气图片多分类问题

目录

1.导包

 2.加载数据、拼接训练、测试数据的文件夹路径

 3.数据预处理

3.1 transforms.Compose数据转化

3.2分类存储的图片数据创建dataloader 

torchvision.datasets.ImageFolder

torch.utils.data.DataLoader

4.加载预训练好的模型(迁移学习) 

4.1固定、修改预训练好的模型参数 

4.2模型拷到GPU上 

4.3定义优化器与损失函数 

4.4学习率衰减 

4.5定义训练过程 

5.测试运行 

6.可视化 :训练与测试的损失函数、准确率


#学习率衰减属于超参数设。在深度学习中,超参数是指在模型训练之前设定的参数,它们不会在训练过程中自动学习到,而是需要人为设置的。
#学习率衰减是一种动态调整学习率的方法,通过在学习过程中逐渐减少学习率,以提高模型的训练效果和稳定性。
#具体来说,学习率衰减是通过在训练过程中逐渐减小学习率来实现的,这有助于在训练初期快速收敛,同时在后期避免过大的波动,从而更接近最优解 

#学习率衰减有两种方式,随着训练次数的增加,学习率逐渐变小

#学习率衰减,在训练过程中使用,   测试过程中不需要使用

1.导包

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms

import os

 2.加载数据、拼接训练、测试数据的文件夹路径

base_dir = './dataset'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')

 3.数据预处理

3.1 transforms.Compose数据转化

transform = transforms.Compose([
    # 统一缩放到96 * 96
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    # 正则化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

3.2分类存储的图片数据创建dataloader 

torchvision.datasets.ImageFolder

torch.utils.data.DataLoader

train_ds = torchvision.datasets.ImageFolder(train_dir, transform=transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=transform)

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)

4.加载预训练好的模型(迁移学习) 

# 加载预训练好的模型
model = torchvision.models.vgg16(pretrained=True)

model
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 

相关文章:

  • 基于ElasticSearch的向量检索技术实践
  • 设计模式 四、行为设计模式(1)
  • 基于层次建模与交叉注意力融合的医学视觉问答系统(HiCA-VQA)详解
  • ⑨数据中心-M-LAG技术配置
  • 8.1 公共控件12
  • 【学Rust写CAD】35 alpha_mul_256(alpha256.rs补充方法)
  • Mamba模型
  • 21 天 Python 计划:MySQL 表相关操作
  • #node.js后端项目的部署相关了解
  • 蓝桥杯每日刷题c++
  • 第4课:多智能体通信协议优化
  • 【区块链安全 | 第三十二篇】内联汇编
  • 13. C++入门基础***
  • 数据库架构
  • 双指针(5)—复写零
  • 层归一化详解及在 Stable Diffusion 中的应用分析
  • AI烘焙大赛中的算法:理解PPO、GRPO与DPO最简单的方式
  • 类和对象(下篇)(详解)
  • nginx中的try_files指令
  • UML组件图
  • 建筑工程网站建站方案/关键词优化步骤简短
  • 成都疫情解封最新消息/seo搜索排名优化方法
  • 石家庄网站建设智美/品牌营销策划机构
  • 郑州网站建设公司e00/视频号最新动作
  • html5高端装修公司网站源码/上海网络营销推广外包
  • 网站跳转怎么做360/友情链接怎么连