当前位置: 首页 > news >正文

MobileNetV3训练自定义数据集并通过C++进行推理模型部署

文章目录

  • 1 前言
  • 2 项目内容详细说明
    • 3.1 训练及模型转换
    • 3.2 模型测试(C++)
  • 3 代码
    • 3.1 train.py
    • 3.2 detect.py
    • 3.3 convert_to_onnx.py
    • 3.4 onnxpredict.py
    • 3.5 C++测试工程
  • 4 资源下载


1 前言

  本项目将实现MobileNetV3的自定义数据集训练,以及推理模型部署。本篇内容采用猫狗数据集,2分类模型。
  C++端预测结果如下所示。
在这里插入图片描述

2 项目内容详细说明

3.1 训练及模型转换

  在python环境下进行模型的训练与格式转换。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
  (1)train.py 实现模型训练;
  (2)detect.py 通过pt模型验证推理;
  (3)convert_to_onnx.py 将pt模型转换成onnx模型;
  (4)onnxpredict.py 通过onnx模型推理验证;

3.2 模型测试(C++)

  C++(Qt)环境下的onnx模型验证测试工程。
在这里插入图片描述

3 代码

3.1 train.py

  train.py实现如下所示。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import os
import time
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm# 设置设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 数据增强和归一化
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]),
}# 数据集路径
data_dir = '/media/ai/5c45cbac-396a-4328-b602-e47bc899eb89/ai/DX/ZheDang_MobileNetV3/datasets'
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')# 创建数据集
image_datasets = {'train': datasets.ImageFolder(train_dir, data_transforms['train']),'val': datasets.ImageFolder(val_dir, data_transforms['val'])
}# 创建数据加载器
dataloaders = {'train': DataLoader(image_datasets['train'], batch_size=32, shuffle=True, num_workers=4),'val': DataLoader(image_datasets['val'], batch_size=32, shuffle=False, num_workers=4)
}# 获取数据集信息
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
num_classes = len(class_names)print(f"类别: {class_names}")
print(f"训练集大小: {dataset_sizes['train']}")
print(f"验证集大小: {dataset_sizes['val']}")# 加载预训练的MobileNetV3模型
def create_model(num_classes=2):model = models.mobilenet_v3_small(pretrained=True)# 修改最后的分类层num_features = model.classifier[3].in_featuresmodel.classifier[3] = nn.Linear(num_features, num_classes)return model# 创建模型
model = create_model(num_classes=num_classes)
model = model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 训练函数
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0# 记录训练过程train_loss_history = []val_loss_history = []train_acc_history = []val_acc_history = []for epoch in range(num_epochs):print(f'Epoch {epoch}/{num_epochs - 1}')print('-' * 10)# 每个epoch都有训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train()  # 训练模式else:model.eval()  # 评估模式running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条dataloader = dataloaders[phase]pbar = tqdm(dataloader, desc=f'{phase} Epoch {epoch}')# 迭代数据for inputs, labels in pbar:inputs = inputs.to(device)labels = labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 反向传播 + 优化(仅在训练阶段)if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)# 更新进度条pbar.set_postfix({'loss': f'{loss.item():.4f}','acc': f'{torch.sum(preds == labels.data).item() / inputs.size(0):.4f}'})if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 记录历史if phase == 'train':train_loss_history.append(epoch_loss)train_acc_history.append(epoch_acc.cpu().numpy())else:val_loss_history.append(epoch_loss)val_acc_history.append(epoch_acc.cpu().numpy())# 深度复制最佳模型if phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint(f'训练完成于 {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')print(f'最佳验证准确率: {best_acc:.4f}')# 加载最佳模型权重model.load_state_dict(best_model_wts)return model, train_loss_history, val_loss_history, train_acc_history, val_acc_history# 开始训练
print("开始训练模型...")
num_epochs = 100
model, train_loss, val_loss, train_acc, val_acc = train_model(model, criterion, optimizer, scheduler, num_epochs=num_epochs
)# 保存模型
torch.save(model.state_dict(), 'mobileNetV3_best_model.pth')
print("模型已保存为 mobileNetV3_best_model.pth")# 绘制训练曲线
plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)
plt.plot(train_loss, label='Train Loss')
plt.plot(val_loss, label='Val Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(train_acc, label='Train Accuracy')
plt.plot(val_acc, label='Val Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()plt.tight_layout()
plt.savefig('training_curves.png')
plt.show()# 测试函数
def test_model(model, dataloader):model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in dataloader['val']:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f'测试准确率: {accuracy:.2f}%')return accuracy# 测试模型
print("测试模型性能...")
test_accuracy = test_model(model, dataloaders)
print(f"最终测试准确率: {test_accuracy:.2f}%")

3.2 detect.py

predict_image方法具体实现如下所示。

    def predict_image(self, image_path):"""预测单张图像"""try:# 加载和预处理图像image = Image.open(image_path).convert('RGB')input_tensor = self.transform(image).unsqueeze(0).to(self.device)# 预测with torch.no_grad():outputs = self.model(input_tensor)probabilities = torch.softmax(outputs, dim=1)confidence, predicted = torch.max(probabilities, 1)# 获取结果class_name = self.class_names[predicted.item()]confidence_score = confidence.item()return class_name, confidence_scoreexcept Exception as e:print(f"预测时出错: {e}")return None, None

3.3 convert_to_onnx.py

  convert_to_onnx.py实现见第4章。

3.4 onnxpredict.py

  onnxpredict.py实现见第4章。

3.5 C++测试工程

  C++实现工程实现见第4章。

4 资源下载

  本案例中涉及到的所有代码请到此处下载。

http://www.dtcms.com/a/391492.html

相关文章:

  • nvshmem源码学习(一)ibgda视角的整体流程
  • Redis群集的三种模式
  • 鸿蒙(南向/北向)
  • Spring IoCDI 快速入门
  • MySQL的C语言驱动核心——`mysql_real_connect()` 函数
  • C++线程池学习 Day06
  • React 样式CSS的定义 多种定义方式 前端基础
  • react+anddesign组件Tabs实现后台管理系统自定义页签头
  • Midscene 低代码实现Android自动化
  • ADB使用指南
  • FunCaptcha如何查找sitekey参数
  • 大模型如何让机器人实现“从冰箱里拿一瓶可乐”?
  • Python实现液体蒸发优化算法 (Evaporation Rate Water Cycle Algorithm, ER-WCA)(附完整代码)
  • MySQL 数据库的「超级钥匙」—`mysql_real_connect`
  • LeetCode 每日一题 3484. 设计电子表格
  • RAGAS深度解析:引领RAG评估新时代的开源技术革命
  • aave v3.4 利率计算详解
  • rook-ceph CRD资源配置时效问题
  • MySQL学习笔记-进阶篇
  • Rust 关键字
  • 排版使用latex排版还是word排版更容易通过mdpi remote sensing的审稿?
  • Qt QML ToolTip弹出方向控制问题探讨
  • [Windows] PDFQFZ(PDF加盖骑缝章) v1.31
  • 四网络层IP-子网掩码-路由表-真题
  • 安装QT6.9.2
  • 使用 NodePort
  • IP6163至为芯具备MPPT硬件算法的太阳能光伏降压DC-DC芯片
  • 从“道生一”理念看宇宙规律与现代科技之关联
  • CKS-CN 考试知识点分享(9) 关闭API凭据自动挂载
  • 初次接触MCP