当前位置: 首页 > news >正文

深度学习————模型保存与部署

第一部分:模型保存基础

什么是模型保存?

当你训练好一个深度学习模型后,它会拥有“学习到的参数”,这些参数(权重、偏置等)构成了模型的“知识”。如果不保存这些参数,那么训练好的模型在关闭程序后就会丢失。

所以,模型保存就是将训练好的参数(或整个模型)保存到磁盘上,供之后加载使用或部署


 两种主要保存方式(以 PyTorch 为例)

  1. 保存模型参数(推荐)

    • 只保存模型的状态字典(state_dict),这是最推荐的方法。

    • 更轻便,适合部署、版本管理。

    • 加载时需要重新构建模型结构,然后加载参数。

  2. 保存整个模型结构与参数(不推荐)

    • 使用 torch.save(model) 直接保存整个模型对象。

    • 不可跨 Python 版本或环境,不利于调试与迁移。


 常见保存格式

框架推荐保存格式说明
PyTorch.pth / .pt.pth 一般用于 state_dict
TensorFlow.ckpt / .pb / SavedModelSavedModel 用于部署
ONNX.onnx便于跨框架、跨平台部署


保存路径与命名建议

  • 路径统一、版本可控,如:

    checkpoints/
    ├── model_v1_2025-05-01.pth
    ├── model_best_val.pth
    └── model_latest.pth
    
  • 可使用时间戳 + 性能指标命名,便于后续追踪:

    model_acc87.4_epoch15.pth
    

 版本管理建议

  • 使用日志系统(如 TensorBoard、WandB)记录对应版本表现。

  • 每次训练完成后保存多个模型:如最优验证集模型(best)、最后模型(last)。

  • 大项目建议结合 Git 和 DVC(Data Version Control)管理模型文件。

第二部分:PyTorch 中的模型保存与加载实战

PyTorch 提供了非常灵活和强大的模型保存与加载机制,主要通过 state_dict(模型参数字典)进行操作。下面我们详细讲解每一步并提供示例代码。


 一、什么是 state_dict

state_dict 是一个 Python 字典,保存了模型中每一层的参数(权重和偏置等)。它的格式大致如下:

{'layer1.weight': tensor(...),'layer1.bias': tensor(...),...
}

每个模块(如 nn.Linear, nn.Conv2d)都将其参数注册在 state_dict 中。


🔹 二、保存模型参数(推荐)

保存代码:

import torch# 假设你有一个模型实例 model
torch.save(model.state_dict(), 'model.pth')

注意事项:

  • model.pth 只是文件名,扩展名可以是 .pt.pth,没有区别。

  • 只保存参数,不包含模型结构,因此加载时需要手动定义结构。


 三、加载模型参数

加载步骤分两步走:

  1. 重新定义模型结构;

  2. 加载参数到模型中。

    # 1. 定义模型结构(必须与保存时一致)
    model = MyModel()# 2. 加载参数
    model.load_state_dict(torch.load('model.pth'))# 3. 切换到评估模式(部署时必须)
    model.eval()
    

🔹 四、保存整个模型(不推荐)

torch.save(model, 'entire_model.pth')

然后加载:

model = torch.load('entire_model.pth')
model.eval()

缺点:

  • 依赖于模型的类定义和 Python 环境;

  • 一旦结构变动,加载可能出错;

  • 不适合跨平台部署。


五、训练状态一起保存(含优化器)

训练中断后可恢复继续训练,需要同时保存模型和优化器状态。

# 保存
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,
}, 'checkpoint.pth')

加载时:

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

 六、保存和加载到指定设备(如 GPU)

# 保存时无关设备
torch.save(model.state_dict(), 'model.pth')# 加载到 CPU
model.load_state_dict(torch.load('model.pth', map_location='cpu'))# 加载到 GPU
device = torch.device('cuda')
model.load_state_dict(torch.load('model.pth', map_location=device))

七、完整示例(含模型结构)

import torch
import torch.nn as nn# 模型定义
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 初始化模型与优化器
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())# 保存
torch.save({'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),
}, 'model_checkpoint.pth')# 加载
model2 = MyModel()
optimizer2 = torch.optim.Adam(model2.parameters())
checkpoint = torch.load('model_checkpoint.pth')
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])

 

第三部分:模型部署基础概念

模型训练完成后,并不是终点。部署模型的目的,是将它放到现实世界中,为用户或系统提供服务。比如:

  • 智能客服系统:用户发送一条消息,模型给出回复;

  • 医疗图像诊断:上传CT图像,模型输出预测结果;

  • 教学系统:上传作业照片,模型识别题目并自动评分。

本部分将介绍部署的核心概念及常见方式。


 一、为什么需要部署?

目标说明
模型服务化把训练好的模型变成一个可以实时调用的服务(如 API)
多用户访问支持多个用户、多个终端访问(Web、App等)
实时推理对输入进行实时预测,如语音识别、图像识别
系统集成将模型集成进现有的软件系统、产品或平台
规模扩展支持大规模并发调用,进行推理加速、负载均衡等

 二、部署分类与对比

我们按部署场景将常见方式进行分类总结:

部署方式描述适合场景优点缺点
本地部署模型运行在本地电脑或服务器上开发测试、小项目简单易操作,无需网络不易扩展,不适合多人使用
Web 服务部署封装成 HTTP API / Web UI实际产品,后台系统可远程访问,适合用户使用部署较复杂,对安全性有要求
云端部署部署到云服务器(如阿里云、AWS)大型项目、商业部署可弹性伸缩,服务稳定成本高,涉及 DevOps 知识
移动端部署模型打包到手机或嵌入式设备移动AI、边缘设备离线可用,低延迟受限于算力、平台兼容性
服务器集群部署结合容器与负载均衡器部署多个模型高并发、高可用场景可自动扩缩、容错性好依赖 Docker/K8s,配置复杂

 三、部署方式常用工具和框架

场景工具/平台示例简述
本地部署Flask、Gradio、Streamlit简单封装模型为 API 或 Web 界面
Web 后端部署FastAPI、Flask + Gunicorn可高性能提供 REST 接口
云服务部署HuggingFace Spaces、阿里云 ECS快速上线,适合演示和产品原型
模型导出与推理加速TorchScript、ONNX、TensorRT优化模型结构,提高推理速度
多模型管理MLflow、TorchServe、NVIDIA Triton模型托管、版本管理与部署平台

 四、常见部署架构图示意(文字版)

用户 -> 浏览器 / App|V[ Web 前端 ]        ←(Gradio / React + Flask 等)|V[ Web 后端 API ]|V[ 推理服务(模型加载) ]|V[ 模型参数 / 权重文件 ]

 五、从训练到部署流程总览

  1. 训练模型:在本地或服务器完成训练;

  2. 保存模型:保存为 .pth.onnx 文件;

  3. 封装接口:使用 Flask / Gradio / FastAPI 编写服务;

  4. 构建前端(可选):使用 HTML / React / Gradio 交互;

  5. 部署上线:本地测试通过后部署到服务器或平台;

  6. 用户使用:通过网页、App 等方式访问部署的服务。

第四部分:模型导出为部署格式(TorchScript 和 ONNX)

训练好的 PyTorch 模型需要导出成标准格式,才能跨平台、跨框架、高效地部署。TorchScriptONNX 是 PyTorch 中最常用的导出格式。

本部分将详细讲解两者的概念、区别、导出方式及使用场景。


 一、为什么要导出模型?

虽然 .pth 格式在 PyTorch 内部很方便使用,但部署时常常需要:

  • 加快推理速度

  • 在没有 Python 的环境中运行

  • 与其他框架(如 TensorFlow、C++、移动端)兼容

  • 更稳定、更可控的模型格式

这时就需要导出为中间格式,如 TorchScript 或 ONNX。


 TorchScript 模型导出

 什么是 TorchScript?

TorchScript 是 PyTorch 的一个中间表示,它允许模型以静态图的形式保存并运行。这使得:

  • 可脱离 Python 环境运行

  • 可通过 C++ API 部署

  • 支持推理优化(如 torch.jit.optimize_for_inference


 TorchScript 两种转换方式

1. 追踪法(Tracing)

适合无条件分支的模型。

import torch# 假设 model 是你训练好的模型
model.eval()example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)# 保存
traced_model.save("model_traced.pt")
2. 脚本法(Scripting)

适合包含 if/else、循环等逻辑的模型。

scripted_model = torch.jit.script(model)
scripted_model.save("model_scripted.pt")

💡 TorchScript 加载与推理

import torchmodel = torch.jit.load("model_traced.pt")
model.eval()output = model(torch.randn(1, 3, 224, 224))

ONNX 模型导出

什么是 ONNX?

ONNX(Open Neural Network Exchange)是一种通用模型格式,由微软和 Facebook 发起,支持多种深度学习框架,如:

  • PyTorch

  • TensorFlow

  • MXNet

  • OpenCV DNN

  • ONNX Runtime

  • TensorRT


PyTorch 转 ONNX 示例

import torchmodel.eval()
dummy_input = torch.randn(1, 3, 224, 224)torch.onnx.export(model, dummy_input,"model.onnx",input_names=["input"],output_names=["output"],dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},opset_version=11
)

ONNX 模型验证

你可以用 onnx 包验证导出是否成功:

import onnxonnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)  # 抛出异常说明有问题

 推理:ONNX Runtime

import onnxruntime as ort
import numpy as npsession = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)output = session.run(None, {input_name: input_data})

 TorchScript vs ONNX 对比总结

特性TorchScriptONNX
支持框架PyTorch 本身 + C++多框架(TensorRT、ONNX RT等)
性能优化支持是(官方提供优化接口)是(ONNX Runtime / TensorRT)
转换复杂度简单稍复杂,需要注意版本/OP集
支持 Python 控制流否(静态图模型)
移植性中(依赖 PyTorch 环境)强(适合工业部署)
推荐场景内部 PyTorch 部署跨平台、商业部署

第五部分:模型部署方式详解(Gradio、Flask、ONNX Runtime等)

在本部分,我们将从实用角度出发,逐一讲解几种常用部署方式,并配合完整代码模板,帮助你快速上手部署一个推理服务。


 方式一:使用 Gradio 快速构建 Web 界面

Gradio 是一个非常流行的 Python 库,用于快速构建交互式 Web 应用,适合演示、测试和初步上线。


 1. 安装 Gradio
pip install gradio

2. 代码示例:图像分类模型部署(TorchScript)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image# 加载 TorchScript 模型
model = torch.jit.load("model_traced.pt")
model.eval()# 图像预处理函数
preprocess = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()
])# 推理函数
def predict(img):img = preprocess(img).unsqueeze(0)with torch.no_grad():output = model(img)probs = torch.nn.functional.softmax(output[0], dim=0)return {f"Class {i}": float(p) for i, p in enumerate(probs)}# 创建界面
iface = gr.Interface(fn=predict, inputs="image", outputs="label")
iface.launch()

 启动后会自动打开浏览器访问地址,如:http://127.0.0.1:7860


方式二:使用 Flask 构建 RESTful 接口(API)

Flask 是 Python 中常用的 Web 框架,可以把模型封装成一个 HTTP 接口供前端或其他服务调用。


 1. 安装 Flask
pip install flask

 2. API 接口部署模板(适合 ONNX)
from flask import Flask, request, jsonify
import onnxruntime as ort
import numpy as npapp = Flask(__name__)# 初始化 ONNX 推理器
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name@app.route("/predict", methods=["POST"])
def predict():data = request.json["input"]  # 输入是嵌套 listinput_array = np.array(data).astype(np.float32)output = session.run(None, {input_name: input_array})return jsonify({"output": output[0].tolist()})if __name__ == '__main__':app.run(debug=True, port=5000)

前端可以通过 POST 请求向 /predict 发送数据,返回 JSON 格式的模型输出。


 方式三:部署到 HuggingFace Spaces(在线部署平台)

HuggingFace 提供免费的部署平台,支持 Gradio/Streamlit 应用的在线托管。


步骤:
  1. 在 https://huggingface.co/spaces 创建一个新的 Space;

  2. 选择 Gradio 模板;

  3. 上传你的代码文件(如 app.py)和 requirements.txt

  4. 提交后等待构建,即可访问。

示例

gradio
torch
torchvision

 方式四:ONNX Runtime + FastAPI + Docker(工业部署)

适合构建高性能、可扩展的 API 服务。

  • 使用 FastAPI 替代 Flask(性能更高);

  • 使用 Docker 打包(环境一致性);

  • 使用 ONNX Runtime(加速推理);

 若你感兴趣,我可以提供该方式的完整项目结构与部署脚本。


 常见部署注意事项

问题/注意点说明
模型文件太大可用 torch.quantization 压缩模型
GPU/CPU 版本不一致部署前明确目标环境是否支持 CUDA
接口响应慢FastAPI + Uvicorn 替代 Flask
高并发请求处理困难使用 Gunicorn 或 Docker + Kubernetes
数据预处理慢把预处理逻辑也放在服务端完成
服务崩溃/异常退出加入异常处理与日志记录系统

 

第六部分:高级部署与优化技巧(模型压缩、推理加速、Docker 打包、前端集成)

当你完成了模型部署的基本流程,进一步优化部署效果(速度、稳定性、易用性)就很关键了。下面我们从多个方面介绍进阶技巧。


 一、模型压缩与推理加速

部署模型时,常常遇到模型太大、推理太慢、占用资源高等问题。可以通过以下几种方式进行模型压缩推理加速


1. 模型量化(Quantization)

将浮点数权重压缩成更小的数据类型(如 float16int8),大幅降低模型大小和推理耗时。

静态量化(Post-training)示例:

import torch.quantizationmodel_fp32 = ...  # 已训练模型
model_fp32.eval()# 准备量化配置
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model_fp32, inplace=True)# 运行一次推理(用于收集统计信息)
_ = model_fp32(torch.randn(1, 3, 224, 224))# 转换为量化模型
quantized_model = torch.quantization.convert(model_fp32, inplace=False)# 保存
torch.jit.script(quantized_model).save("model_quant.pt")

 注意:部分模块(如 BatchNorm)不支持直接量化,需使用 QuantStub 包装。


2. 使用 Torch-TensorRT(NVIDIA GPU 加速)

Torch-TensorRT 是 NVIDIA 供的一个库,可将 TorchScript 模型转换为 TensorRT 引擎

pip install torch-tensorrt -U

简单使用:

import torch_tensorrttrt_model = torch_tensorrt.compile(model, inputs=[torch.randn(1, 3, 224, 224).to("cuda")], enabled_precisions={torch.float16})

✅ 二、Docker 化部署(推荐生产环境使用)

Docker 可以把你的服务打包成镜像,确保环境一致性、可移植性。


1. 创建项目目录结构
deploy_app/
├── app.py               # Flask / Gradio 应用
├── model.onnx           # 导出的模型
├── requirements.txt     # 所需 Python 包
└── Dockerfile           # Docker 构建脚本

2. Dockerfile 示例
FROM python:3.10WORKDIR /appCOPY requirements.txt .
RUN pip install -r requirements.txtCOPY . .CMD ["python", "app.py"]

3. 构建并运行容器
docker build -t my_model_app .
docker run -p 5000:5000 my_model_app

 若部署到云端(如阿里云、AWS),推荐结合 Nginx 反向代理与容器编排(如 docker-compose 或 Kubernetes)。


 三、前端集成与美化建议

技术优点示例用途
Gradio快速搭建交互界面原型演示、测试用
Streamlit数据可视化友好图像/表格/图表展示等
HTML + JS适合自定义界面、美化展示嵌入 Web 系统、企业平台
React/Vue高度定制、适合商用产品构建完整 Web 应用

 四、完整部署案例:PyTorch → ONNX → Gradio → Docker → HuggingFace Spaces


总结与建议

部分内容概览
第一部分模型保存格式(权重、结构、完整模型)
第二部分加载与恢复模型的多种方式
第三部分部署的基本概念与分类
第四部分模型导出为 TorchScript / ONNX
第五部分使用 Gradio / Flask / ONNX Runtime 部署
第六部分模型压缩、推理加速、Docker 化、高级部署建议

相关文章:

  • Word2Vec详解
  • IDEA+AI 深度融合:重构高效开发的未来模式
  • Unity实用技能-UI定位总结
  • 从秒开到丝滑体验!WebAssembly助力ZKmall商城重构 B2B2C 商城性能基线
  • AI大语言模型评测体系演进与未来展望
  • Python类方法解析:从字节序列重构Vector2d实例
  • 从虚拟仿真到行业实训再到具身智能--华清远见嵌入式物联网人工智能全链路教学方案
  • 物联网简介:万物互联的未来图景
  • 国标GB28181视频平台EasyGBS校园监控方案:多场景应用筑牢安全防线,提升管理效能
  • Windows中PDF TXT Excel Word PPT等Office文件在预览窗格无法预览的终级解决方法大全
  • Kafka 消息堆积与慢消费问题排查及优化实践
  • ALTER COLLATION使用场景
  • 深入解析PyTorch中MultiheadAttention的参数key_padding_mask与attn_mask
  • 分布式与集群:概念、区别与协同
  • disryptor和rabbitmq
  • RabbitMQ-如何选择消息队列?
  • 大语言模型(LLM)如何通过“思考时间”(即推理时的计算资源)提升推理能力
  • Java设计模式之组合模式:从入门到精通(保姆级教程)
  • 【NLP】37. NLP中的众包
  • Better Faster Large Language Models via Multi-token Prediction 原理
  • 欧阳娜娜等20多名艺人被台当局列入重要查核对象,国台办回应
  • 聚焦智能浪潮下的创业突围,“青年草坪创新创业湃对”走进北杨人工智能小镇
  • 读懂城市|成都高新区:打造“人尽其才”的“理想之城”
  • 蔡建忠已任昆山市副市长、市公安局局长
  • 特朗普政府涉税改法案遭众议院预算委员会否决
  • 大陆非遗项目打铁花、英歌舞将在台演出