当前位置: 首页 > news >正文

pyTorch-迁移学习-图片数据增强-四种天气图片的多分类问题

目录

1.导包

 2.加载数据、拼接训练与测试数据的文件夹路径

3数据预处理 

3.1数据增强 

3.2用分类存储的图片数据创建dataloader

4.加载预训练好的模型 (迁移学习)

4.1固定、修改预训练好的模型 

5.将模型拷到GPU上 

6.定义优化器与损失函数 

7.学习率衰减 

8.定义训练过程 

9.运行测试 

10.可视化:训练与测试的损失函数、准确率对比 


 

#模型优化,不一定非要改模型的参数,也可以通过学习率衰减(超参数设置)、数据增强等方法进行优化
#在这个项目中,使用数据增强,减少了过拟合 

1.导包

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms

import os

 2.加载数据、拼接训练与测试数据的文件夹路径

base_dir = './dataset'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')

3数据预处理 

3.1数据增强 

#图片数据增强的常用方法
transforms.RandomCrop    # 随机位置的裁剪 , CenterCrop 中间位置裁剪
transforms.RandomRotation # 随机旋转
transforms.RandomHorizontalFlip() # 水平翻转
transforms.RandomVerticalFlip() # 垂直翻转
transforms.ColorJitter(brightness) # 亮度
transforms.ColorJitter(contrast) # 对比度
transforms.ColorJitter(saturation) # 饱和度
transforms.ColorJitter(hue)   #图像抖动
transforms.RandomGrayscale() # 随机灰度化.

# 数据增强只会加在训练数据上.  不一定使用了数据增强,训练效果就一定好!!!
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),   #原论文中的统一的尺寸参数要求
    transforms.RandomCrop(192),      #从原图中切出来的尺寸大小
    transforms.RandomHorizontalFlip(), #水平翻转
    transforms.RandomVerticalFlip(),   #垂直翻转
    transforms.RandomRotation(0.4),    #随机旋转   0.4是旋转的角度比例
#     transforms.ColorJitter(brightness=0.5),
#     transforms.ColorJitter(contrast=0.5),
    transforms.ToTensor(),
    # 正则化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
#测试数据不进行图片的数据增强
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # 正则化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

3.2用分类存储的图片数据创建dataloader

train_ds = torchvision.datasets.ImageFolder(train_dir, transform=train_transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=test_transform)

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)

4.加载预训练好的模型 (迁移学习)

# 加载预训练好的模型
model 

相关文章:

  • 群体智能优化算法-白鲨优化算法(White Shark Optimizer,WSO,含Matlab源代码)
  • JS中的WeakMap
  • 思考 - 操作系统
  • 路由器工作在OSI模型的哪一层?
  • babel-runtime 如何缩小打包体积
  • usbip学习记录
  • 基于springboot微信小程序课堂签到及提问系统(源码+lw+部署文档+讲解),源码可白嫖!
  • 自动提取pdf公式 ➕ 输出 LaTeX
  • C++ 指针类型转换全面解析与最佳实践
  • PyTorch标注工具
  • 【C++】Chapter04<STL部分>:STL标准模板库概要
  • 【团体程序涉及天梯赛】L1~L2实战反思合集(C++)
  • Java并发编程高频面试题
  • Dubbo(41)如何排查Dubbo的服务不可用问题?
  • OpenCV阈值处理详解
  • 企业数据分析何时该放弃Excel?
  • No module named ‘keras.engine‘
  • mysql8.0.29 win64下载
  • SpringCloud的简单介绍
  • Jmeter脚本使用要点记录
  • 央媒:安徽凤阳鼓楼坍塌楼宇部分非文物,系违规复建的“假古董”
  • 【社论】进一步拧紧过紧日子的制度螺栓
  • 纽约市长称墨海军帆船撞桥已致2人死亡,撞桥前船只疑似失去动力
  • 有关“普泽会”,俄官方表示:有可能
  • 小米汽车回应部分SU7前保险杠形变
  • 一个留美学生的思想转向——裘毓麐的《游美闻见录》及其他