当前位置: 首页 > news >正文

pyTorch框架-迁移学习-实现四种天气图片多分类问题

目录

1.导包

 2.加载原数据、创建训练与测试目录路径

 3.用transforms.Compose、torchvision.datasets.ImageFolder数据预处理

 4.加载预训练好的模型

5.固定与修改预训练模型的参数 

6.将模型拷到GPU上 

7.定义优化器与损失函数 

8.定义训练过程 

9.测试运行 

10.测试结果画图可视化 

10.1.训练与测试的损失函数对比

         10.2.训练与测试的准确率对比 

 11.模型保存


 

#迁移学习的常用三种模式
#(1)把别人训练好的模型拿过来用,删掉输出层,加上自己的一层(只有自己新加的这一层参数可训练,其他层的参数固定),形成新的网络
#(2)。。。同上。。。,后面几层的参数都能训练
#(3)。。。同上。。。,全部层的参数都能训练

#训练参数越多,训练速度越慢
#迁移学习:就是基于别人已经训练好的模型框架与参数,稍加改动进行训练

 

1.导包

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms

 2.加载原数据、创建训练与测试目录路径

base_dir = './dataset'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')

 3.用transforms.Compose、torchvision.datasets.ImageFolder数据预处理

transform = transforms.Compose([
    # 统一缩放到96 * 96
    transforms.Resize((96, 96)),
    transforms.ToTensor(),  #此代码的三个作用:(1)把数据变成0-1之间;(2)变成tensor数据;(3)把通道数换到长和宽的前面
    # 正则化  也是标准化处理
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_ds = torchvision.datasets.ImageFolder(train_dir, transform=transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=transform)

#创建dataloader
batch_size = 32
#参数drop_last=True 表示 最后剩余的数据若不满足一批数据的数量,直接删掉
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)

 4.加载预训练好的模型

# 加载预训练好的模型
#加bn层2015年提出的,vgg模型是2015年之前提出的,vgg16_bn模型是2015年之后提出的
#参数pretrained=True表示 使用加载模型的 已训练好的模型
#第一次跑 会下载
model = torchvision.models.vgg16(pretrained=True)
model
#avgpool : 平均值池化
#(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))  : 平均值池化 + flatten()数据展平
#可以通过features来索引 获取 每一层的参数
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation&#

相关文章:

  • 【Windows批处理】命令入门详解
  • Rust 2024介绍 | 开发环境搭建详细教程(rust 1.85.0)
  • 《Glance:一站式聚合信息,告别浏览器切换烦恼》
  • 国产芯片解析:龙讯USB Type-C/DP Transmitter多场景覆盖,定义高速互联新标杆
  • 21.OpenCV获取图像轮廓信息
  • 【js逆向】某日番动漫网视频地址解密
  • 车辆监控平台技术标准解析
  • Bert论文解析
  • 2019 CCF CSP-S2.树的重心
  • Linux驱动学习笔记(七)
  • IDEA加载项目时依赖无法更新
  • Visual Studio 2022 QT5.14.2 新建项目无法打开QT的ui文件,出现闪退情况
  • Headscale-Admin-Pro
  • Mysql 概念
  • 如何在大型项目中组织和管理 Vue 3 Hooks?
  • 如何让 -webkit-slider-thumb 生效
  • 火语言RPA--Sqlite-执行SQL
  • DAPP实战篇:规划下我们的开发线路
  • Jupyter notebook定制字体
  • 2025-04-06 Unity Editor 实践 1 —— Editor 窗体框架
  • 门户网站是什么意思啊/360优化大师下载
  • 二级域名建站/北京网站优化指导
  • 如何制作图片配文字/郑州seo关键词优化公司
  • 东阿做网站多少钱/网店代运营和推广销售
  • 比较冷门的视频网站做搬运/关键词排名优化公司成都
  • 《网站开发课程设计》设计报告/网站制作费用一览表