当前位置: 首页 > news >正文

图像读取与模型保存--基于NWPU-RESISC45数据集的图像二分类实战

今天有点慢了,对吧,哎,上午睡了个懒觉,下午玩了一下。晚上争取再发一篇吧
之前是用torchvision里面的数据集,这一篇我们用自己的数据集来跑深度学习项目,并且保存模型

codes are here

文章目录

  • PyTorch图像分类实战笔记 🚀
    • 1. 数据准备 📊
      • 1.1 导入必要的库
      • 1.2 加载图片数据集
      • 1.3 利用`torchvision.datasets.ImageLoder`类创建Dataset和DataLoader
      • 1.4 数据可视化
    • 2. 构建CNN模型 🧠
      • 2.1 定义网络结构
      • 2.2 初始化模型
      • 2.3 定义损失函数和优化器
    • 3. 训练和评估 🔥
      • 3.1 训练函数
      • 3.2 测试函数
      • 3.3 训练循环
      • 3.4 可视化训练过程
    • ==4. 模型保存与加载 💾==`model.state_dict()`
      • 4.1 保存和加载模型权重
      • 4.2 保存和恢复检查点
      • 4.3 保存最优模型
    • 5. 关键点总结 🎯

PyTorch图像分类实战笔记 🚀

本笔记记录了使用PyTorch进行图像分类的完整流程,从数据加载到模型训练,再到模型保存和优化。让我们一步步来看!

1. 数据准备 📊

下数据集下到自家的了
数据集在这里下载
划分测试集和训练集后整理成如下结构
下载的数据集可以直接用,是划分好的
在这里插入图片描述

1.1 导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

1.2 加载图片数据集

使用ImageFolder类加载图像数据集,并定义数据预处理流程:

ImageFolder类是PyTorch中用于加载图像数据集的类,它可以将图像文件按照类别分成不同的子目录,
并将每个子目录中的图像文件都加载到一个类别的标签中。
最重要的参数就是文件夹路径,它指定了图像数据集的位置。

train_dir = r'D:\my_all_learning\deeplearning\2_class\train' #这里改成你们自己的地址
test_dir = r'D:\my_all_learning\deeplearning\2_class\test'# 还有一个重要的参数transform,它用于对图像进行预处理,比如裁剪、缩放、旋转、归一化等。
# 这里我们使用Compose函数将ToTensor和Normalize两个预处理操作组合起来。
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
])

1.3 利用torchvision.datasets.ImageLoder类创建Dataset和DataLoader

train_ds = torchvision.datasets.ImageFolder(train_dir, transform=transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=transform)BATCHSIZE = 16
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCHSIZE, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=BATCHSIZE)

1.4 数据可视化

# 查看图像
imgs, labels = next(iter(train_dl))
print(imgs.shape, labels.shape)  # 输出: torch.Size([16, 3, 256, 256]) torch.Size([16])# 显示单张图片
img = imgs[0]
plt.title(labels[0].item())
plt.imshow(img.permute(1,2,0))
plt.show()# 显示前6张图片
plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):img = (img.permute(1,2,0).numpy()+1)/2plt.subplot(2,3,i+1)plt.title(id_to_class.get(label.item()))plt.imshow(img)
plt.show()

2. 构建CNN模型 🧠

2.1 定义网络结构

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)  # 256-2=254self.pool = nn.MaxPool2d(2, 2)    # 127self.conv2 = nn.Conv2d(16, 32, 3) # 125-2=123self.conv3 = nn.Conv2d(32, 64, 3) # 池化层输出尺寸floor((size+2*padding - kernel_size)/stride) + 1# 所以最后是30*30self.fc1 = nn.Linear(64*30*30, 1024)self.fc2 = nn.Linear(1024, 128)self.fc3 = nn.Linear(128, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = self.pool(F.relu(self.conv3(x)))x = x.view(-1, 64*30*30)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

2.2 初始化模型

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net().to(device)

2.3 定义损失函数和优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) #我的lr肯定是调大了,这里建议是0.0005,然后下面的epoch多一点

3. 训练和评估 🔥

3.1 训练函数

(这里的代码跟前面的博客的是一样的)

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, correct = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()with torch.no_grad():correct += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_loss /= num_batchescorrect /= sizereturn train_loss, correct

3.2 测试函数

def test(dataloader, model):size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizereturn test_loss, correct

3.3 训练循环

epochs = 20
train_loss, train_acc = [], []
test_loss, test_acc = [], []for epoch in range(epochs):epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)epoch_test_loss, epoch_test_acc = test(test_dl, model)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)template = ("epoch:{:2d}, train_loss:{:5f}, train_acc:{:.1f}%, ""test_loss:{:.5f}, test_acc:{:.1f}%")print(template.format(epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))print("Done")  #不敢再跑了,CPU快干烧了 而且好像过拟合了😂
# 大家跑的时候调一调那些超参数

3.4 可视化训练过程

# 绘制损失曲线
plt.plot(range(1,epochs+1), train_loss, label='train_loss')
plt.plot(range(1,epochs+1), test_loss, label='test_loss')
plt.legend()
plt.show()# 绘制准确率曲线
plt.plot(range(1,epochs+1), train_acc, label='train_acc')
plt.plot(range(1,epochs+1), test_acc, label='test_acc')
plt.legend()
plt.show()

4. 模型保存与加载 💾model.state_dict()

4.1 保存和加载模型权重

# 保存模型权重
torch.save(model.state_dict(), 'model_weights.pth')# 加载模型权重
new_model = Net().to(device)
new_model.load_state_dict(torch.load('model_weights.pth'))

4.2 保存和恢复检查点

# 保存检查点
PATH = "model_checkpoint.pt"
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),
}, PATH)# 恢复检查点
model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.001)
checkpoint = torch.load(PATH)model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

4.3 保存最优模型

best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0for epoch in range(epochs):# ...训练代码...if epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model_wts = copy.deepcopy(model.state_dict())model.load_state_dict(best_model_wts)
model.eval()  # 切换到评估模式

5. 关键点总结 🎯

  1. 数据预处理:使用transforms进行标准化处理,这对CNN模型很重要
  2. 模型结构:典型的CNN结构包含卷积层、池化层和全连接层
  3. 训练技巧
    • 使用GPU加速训练(如果有)
    • 监控训练和验证集的损失和准确率
    • 保存最佳模型权重
  4. 模型保存
    • 可以只保存模型权重(.state_dict()
    • 也可以保存整个检查点(包含优化器状态等)
    • 最佳实践是保存验证集上表现最好的模型
http://www.dtcms.com/a/277090.html

相关文章:

  • stm32f103c8t6移植freeRTOS内存不足报错问题的解决办法
  • 浏览器渲染原理与性能优化全解析
  • 快速傅里叶变换(FFT)中的振幅和相位
  • 【计算机网络架构】环型架构简介
  • 在 C# 中调用 Python 脚本:实现跨语言功能集成
  • ADB 调试日志全攻略:如何开启与关闭 `ADB_TRACE` 日志
  • CS课程项目设计1:交互友好的井字棋游戏
  • 详解Linux下多进程与多线程通信(二)
  • 【QT】使用QSS进行界面美化
  • 异或为什么叫异或
  • 【读书笔记】《Effective Modern C++》第3章 Moving to Modern C++
  • Datawhale AI夏令营——基于带货视频评论的用户洞察挑战赛
  • 【PTA数据结构 | C语言版】简单计算器
  • 17.使用DenseNet网络进行Fashion-Mnist分类
  • LabVIEW调用外部DLL
  • 深度学习图像分类数据集—七种树叶识别分类
  • 零基础 “入坑” Java--- 十、继承
  • ARC 03 从Github Action job 到 runner pod
  • PPO(近端策略优化)
  • 华为HarmonyOS 5.0深度解析:跨设备算力池技术白皮书(2025全场景智慧中枢)
  • 【C++】list及其模拟实现
  • C++--List
  • AI交互中的礼貌用语:“谢谢“的效用与代价分析
  • 【操作系统-Day 5】通往内核的唯一桥梁:系统调用 (System Call)
  • MVC 参考手册
  • C++值类别与移动语义
  • linux shell从入门到精通(一)——初识Shell程序
  • opencv中contours的使用
  • Spring Boot RESTful API 设计指南:查询接口规范与最佳实践
  • Docker从环境配置到应用上云的极简路径