当前位置: 首页 > news >正文

Torch -- 卷积学习day4 -- 完整项目流程

完整项目流程总结

1. 环境准备与依赖导入

import time
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
import wandb
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import *
import matplotlib.pyplot as plt

2. 数据准备与增强

# 数据增强变换
transform = transforms.Compose([transforms.RandomRotation(45),transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])
​
# 测试集变换
transformtest = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])
​
# 数据集加载
train_dataset = CIFAR10(root=datapath,train=True,download=True,transform=transform,
)
​
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=2,
)

3. 模型构建与初始化

# 获取ResNet18模型并调整全连接层
model = resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features=in_features, out_features=10)
​
# 加载预训练权重(如果有)
if os.path.exists(weightpath):weights_default = torch.load(weightpath)weights_default.pop("fc.weight", None)weights_default.pop("fc.bias", None)new_state_dict = model.state_dict()weights_default_process = {k: v for k, v in weights_default.items() if k in new_state_dict}new_state_dict.update(weights_default_process)model.load_state_dict(new_state_dict)
​
model.to(device)

4. 训练过程

# 初始化训练工具
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
​
# 可视化工具初始化
wandb.init(project="my-qianyi-project", config={...})
write1 = SummaryWriter(log_dir=log_dir)
write1.add_graph(model, input_to_model=torch.randn(1, 3, 32, 32).to(device))
​
# 训练循环
for epoch in range(epochs):model.train()# 训练代码...torch.save(model.state_dict(), weightpath)

5. 验证与评估

# 加载最佳模型进行验证
model.load_state_dict(torch.load(weightpath))
model.eval()
​
# 验证过程
# 保存预测结果到CSV
# 生成分类报告和混淆矩阵

6. 模型应

# 加载模型进行推理
def predict_image(image_path):# 图像预处理# 模型预测# 返回结果

7. 模型移植与部署

7.1 模型转换(PyTorch → ONNX/)

python

# 转换为ONNX格式
def convert_to_onnx(model, input_size, onnx_path):model.eval()dummy_input = torch.randn(1, *input_size).to(device)torch.onnx.export(model,dummy_input,onnx_path,export_params=True,opset_version=11,do_constant_folding=True,input_names=['input'],output_names=['output'],dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})print(f"Model converted to ONNX and saved to {onnx_path}")
​
# 使用示例
convert_to_onnx(model, (3, 32, 32), "model.onnx")
7.2 模型量化(减小模型大小,加速推理)

python

# 动态量化
def quantize_model(model):quantized_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)return quantized_model
​
# 使用示例
quantized_model = quantize_model(model)
torch.save(quantized_model.state_dict(), "quantized_model.pth")
7.3 减少参数数量
# 简单的权重剪枝
def prune_model(model, pruning_percentage=0.2):parameters_to_prune = []for name, module in model.named_modules():if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):parameters_to_prune.append((module, 'weight'))torch.nn.utils.prune.global_unstructured(parameters_to_prune,pruning_method=torch.nn.utils.prune.L1Unstructured,amount=pruning_percentage,)return model
​
# 使用示例
pruned_model = prune_model(model)
7.4 移动端部署(使用ONNX Runtime)
# 保存为LibTorch格式(C++可用)
example = torch.rand(1, 3, 32, 32).to(device)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
7.5 Web部署(使用ONNX.js)
# 首先转换为ONNX,然后使用ONNX.js在浏览器中运行
# 或者使用第三方工具如https://github.com/onnx/tensorflow-onnx
7.6 边缘设备部署(使用TensorRT、OpenVINO等)
# 使用NVIDIA TensorRT优化(需要先转换为ONNX)
# 或使用Intel OpenVINO工具包

8. 性能监控与优化

# 模型推理速度测试
def benchmark_model(model, input_size, num_runs=100):model.eval()input_tensor = torch.randn(1, *input_size).to(device)# GPU预热for _ in range(10):_ = model(input_tensor)# 计时start_time = time.time()for _ in range(num_runs):_ = model(input_tensor)end_time = time.time()avg_time = (end_time - start_time) / num_runsfps = 1 / avg_timeprint(f"Average inference time: {avg_time*1000:.2f} ms, FPS: {fps:.2f}")return avg_time, fps
​
# 使用示例
benchmark_model(model, (3, 32, 32))

这个完整的流程涵盖了从数据准备到模型部署的全过程,特别是新增的模型移植部分,提供了将训练好的模型部署到不同平台和设备的方法,这对于实际应用非常重要。

http://www.dtcms.com/a/341303.html

相关文章:

  • python numpy.random的基础教程(附opencv 图片转数组、数组转图片)
  • 3D max制作蝴蝶结详细步骤(新手可跟)♥️
  • 制造业原料仓储混乱?WMS 系统实现物料精准溯源,生产更顺畅_
  • 深度剖析Lua Table的运作方式
  • 透传 Attributes(详细解析)1
  • 服务器内存使用buff/cache的原理
  • Linux-----《Linux系统管理速通:界面切换、远程连接、目录权限与用户管理一网打尽》
  • 以AI技术为核心的变电设备声纹监测装置及方案特色解析
  • AI时代下阿里云基础设施的稳定性架构揭秘
  • 初试Docker Desktop工具
  • 服务器硬件电路设计之 SPI 问答(二):SPI 与 I2C 的特性博弈及多从机设计之道
  • Java ReentrantLock 核心用法
  • 算法提升树形数据结构-(线段树)
  • RAG拓展、变体、增强版(二)
  • Django管理后台结合剪映实现课件视频生成应用
  • SpringBoot+Vue打造动漫活动预约系统----后端
  • BM25 系列检索算法
  • Python Day32 JavaScript 数组与对象核心知识点整理
  • 用 Go 库 urfave/cli 轻松构建命令行程序
  • Linux上安装多个JDK版本,需要配置环境变量吗
  • STM32存储结构
  • Vue3 结合 html2canvas 生成图片
  • GISBox工具:FBX到3DTiles文件转换指南
  • SpringBoot - 公共字段自动填充的6种方案
  • 使用安卓平板,通过USB数据线(而不是Wi-Fi)来控制电脑(版本1)
  • Mac编译Android AOSP
  • Vue2+Vue3前端开发_Day3
  • vue3中,如何解决数字精度问题(big.js的使用)
  • 计算机毕设Spark项目实战:基于大数据技术的就业数据分析系统Django+Vue开发指南
  • SQL count(*)与 sum 区别