当前位置: 首页 > news >正文

pyTorch-迁移学习-图片数据增强-四种天气图片的多分类问题

目录

1.导包

 2.加载数据、拼接训练与测试数据的文件夹路径

3数据预处理 

3.1数据增强 

3.2用分类存储的图片数据创建dataloader

4.加载预训练好的模型 (迁移学习)

4.1固定、修改预训练好的模型 

5.将模型拷到GPU上 

6.定义优化器与损失函数 

7.学习率衰减 

8.定义训练过程 

9.运行测试 

10.可视化:训练与测试的损失函数、准确率对比 


 

#模型优化,不一定非要改模型的参数,也可以通过学习率衰减(超参数设置)、数据增强等方法进行优化
#在这个项目中,使用数据增强,减少了过拟合 

1.导包

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms

import os

 2.加载数据、拼接训练与测试数据的文件夹路径

base_dir = './dataset'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')

3数据预处理 

3.1数据增强 

#图片数据增强的常用方法
transforms.RandomCrop    # 随机位置的裁剪 , CenterCrop 中间位置裁剪
transforms.RandomRotation # 随机旋转
transforms.RandomHorizontalFlip() # 水平翻转
transforms.RandomVerticalFlip() # 垂直翻转
transforms.ColorJitter(brightness) # 亮度
transforms.ColorJitter(contrast) # 对比度
transforms.ColorJitter(saturation) # 饱和度
transforms.ColorJitter(hue)   #图像抖动
transforms.RandomGrayscale() # 随机灰度化.

# 数据增强只会加在训练数据上.  不一定使用了数据增强,训练效果就一定好!!!
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),   #原论文中的统一的尺寸参数要求
    transforms.RandomCrop(192),      #从原图中切出来的尺寸大小
    transforms.RandomHorizontalFlip(), #水平翻转
    transforms.RandomVerticalFlip(),   #垂直翻转
    transforms.RandomRotation(0.4),    #随机旋转   0.4是旋转的角度比例
#     transforms.ColorJitter(brightness=0.5),
#     transforms.ColorJitter(contrast=0.5),
    transforms.ToTensor(),
    # 正则化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
#测试数据不进行图片的数据增强
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # 正则化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

3.2用分类存储的图片数据创建dataloader

train_ds = torchvision.datasets.ImageFolder(train_dir, transform=train_transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=test_transform)

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)

4.加载预训练好的模型 (迁移学习)

# 加载预训练好的模型
model 
http://www.dtcms.com/a/117908.html

相关文章:

  • 群体智能优化算法-白鲨优化算法(White Shark Optimizer,WSO,含Matlab源代码)
  • JS中的WeakMap
  • 思考 - 操作系统
  • 路由器工作在OSI模型的哪一层?
  • babel-runtime 如何缩小打包体积
  • usbip学习记录
  • 基于springboot微信小程序课堂签到及提问系统(源码+lw+部署文档+讲解),源码可白嫖!
  • 自动提取pdf公式 ➕ 输出 LaTeX
  • C++ 指针类型转换全面解析与最佳实践
  • PyTorch标注工具
  • 【C++】Chapter04<STL部分>:STL标准模板库概要
  • 【团体程序涉及天梯赛】L1~L2实战反思合集(C++)
  • Java并发编程高频面试题
  • Dubbo(41)如何排查Dubbo的服务不可用问题?
  • OpenCV阈值处理详解
  • 企业数据分析何时该放弃Excel?
  • No module named ‘keras.engine‘
  • mysql8.0.29 win64下载
  • SpringCloud的简单介绍
  • Jmeter脚本使用要点记录
  • volatile关键字用途说明
  • 打印网络内的层名称与特征图大小
  • 数据操作语言
  • 初探:OutSystems的运行原理是什么?
  • R语言赋能气象水文科研:从多维数据处理到学术级可视化
  • Python爬虫HTTP代理使用教程:突破反爬的实战指南
  • 隐私计算的崛起:数据安全的未来守护者
  • ollama+open-webui本地部署自己的模型到d盘+两种open-webui部署方式(详细步骤+大量贴图)
  • obj.name 和 obj[name]的区别?【前端】
  • 【Yonyou-BIP】平台档案删除时报自建应用实体错误