深度学习————模型保存与部署
第一部分:模型保存基础
什么是模型保存?
当你训练好一个深度学习模型后,它会拥有“学习到的参数”,这些参数(权重、偏置等)构成了模型的“知识”。如果不保存这些参数,那么训练好的模型在关闭程序后就会丢失。
所以,模型保存就是将训练好的参数(或整个模型)保存到磁盘上,供之后加载使用或部署。
两种主要保存方式(以 PyTorch 为例)
-
保存模型参数(推荐)
-
只保存模型的状态字典(
state_dict
),这是最推荐的方法。 -
更轻便,适合部署、版本管理。
-
加载时需要重新构建模型结构,然后加载参数。
-
-
保存整个模型结构与参数(不推荐)
-
使用
torch.save(model)
直接保存整个模型对象。 -
不可跨 Python 版本或环境,不利于调试与迁移。
-
常见保存格式
框架 | 推荐保存格式 | 说明 |
---|---|---|
PyTorch | .pth / .pt | .pth 一般用于 state_dict |
TensorFlow | .ckpt / .pb / SavedModel | SavedModel 用于部署 |
ONNX | .onnx | 便于跨框架、跨平台部署 |
保存路径与命名建议
-
路径统一、版本可控,如:
checkpoints/ ├── model_v1_2025-05-01.pth ├── model_best_val.pth └── model_latest.pth
-
可使用时间戳 + 性能指标命名,便于后续追踪:
model_acc87.4_epoch15.pth
版本管理建议
-
使用日志系统(如 TensorBoard、WandB)记录对应版本表现。
-
每次训练完成后保存多个模型:如最优验证集模型(best)、最后模型(last)。
-
大项目建议结合 Git 和 DVC(Data Version Control)管理模型文件。
第二部分:PyTorch 中的模型保存与加载实战
PyTorch 提供了非常灵活和强大的模型保存与加载机制,主要通过 state_dict
(模型参数字典)进行操作。下面我们详细讲解每一步并提供示例代码。
一、什么是 state_dict
?
state_dict
是一个 Python 字典,保存了模型中每一层的参数(权重和偏置等)。它的格式大致如下:
{'layer1.weight': tensor(...),'layer1.bias': tensor(...),...
}
每个模块(如 nn.Linear
, nn.Conv2d
)都将其参数注册在 state_dict
中。
🔹 二、保存模型参数(推荐)
保存代码:
import torch# 假设你有一个模型实例 model
torch.save(model.state_dict(), 'model.pth')
注意事项:
-
model.pth
只是文件名,扩展名可以是.pt
或.pth
,没有区别。 -
只保存参数,不包含模型结构,因此加载时需要手动定义结构。
三、加载模型参数
加载步骤分两步走:
-
重新定义模型结构;
-
加载参数到模型中。
# 1. 定义模型结构(必须与保存时一致) model = MyModel()# 2. 加载参数 model.load_state_dict(torch.load('model.pth'))# 3. 切换到评估模式(部署时必须) model.eval()
🔹 四、保存整个模型(不推荐)
torch.save(model, 'entire_model.pth')
然后加载:
model = torch.load('entire_model.pth')
model.eval()
缺点:
-
依赖于模型的类定义和 Python 环境;
-
一旦结构变动,加载可能出错;
-
不适合跨平台部署。
五、训练状态一起保存(含优化器)
训练中断后可恢复继续训练,需要同时保存模型和优化器状态。
# 保存
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,
}, 'checkpoint.pth')
加载时:
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
六、保存和加载到指定设备(如 GPU)
# 保存时无关设备
torch.save(model.state_dict(), 'model.pth')# 加载到 CPU
model.load_state_dict(torch.load('model.pth', map_location='cpu'))# 加载到 GPU
device = torch.device('cuda')
model.load_state_dict(torch.load('model.pth', map_location=device))
七、完整示例(含模型结构)
import torch
import torch.nn as nn# 模型定义
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 初始化模型与优化器
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())# 保存
torch.save({'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),
}, 'model_checkpoint.pth')# 加载
model2 = MyModel()
optimizer2 = torch.optim.Adam(model2.parameters())
checkpoint = torch.load('model_checkpoint.pth')
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
第三部分:模型部署基础概念
模型训练完成后,并不是终点。部署模型的目的,是将它放到现实世界中,为用户或系统提供服务。比如:
-
智能客服系统:用户发送一条消息,模型给出回复;
-
医疗图像诊断:上传CT图像,模型输出预测结果;
-
教学系统:上传作业照片,模型识别题目并自动评分。
本部分将介绍部署的核心概念及常见方式。
一、为什么需要部署?
目标 | 说明 |
---|---|
模型服务化 | 把训练好的模型变成一个可以实时调用的服务(如 API) |
多用户访问 | 支持多个用户、多个终端访问(Web、App等) |
实时推理 | 对输入进行实时预测,如语音识别、图像识别 |
系统集成 | 将模型集成进现有的软件系统、产品或平台 |
规模扩展 | 支持大规模并发调用,进行推理加速、负载均衡等 |
二、部署分类与对比
我们按部署场景将常见方式进行分类总结:
部署方式 | 描述 | 适合场景 | 优点 | 缺点 |
---|---|---|---|---|
本地部署 | 模型运行在本地电脑或服务器上 | 开发测试、小项目 | 简单易操作,无需网络 | 不易扩展,不适合多人使用 |
Web 服务部署 | 封装成 HTTP API / Web UI | 实际产品,后台系统 | 可远程访问,适合用户使用 | 部署较复杂,对安全性有要求 |
云端部署 | 部署到云服务器(如阿里云、AWS) | 大型项目、商业部署 | 可弹性伸缩,服务稳定 | 成本高,涉及 DevOps 知识 |
移动端部署 | 模型打包到手机或嵌入式设备 | 移动AI、边缘设备 | 离线可用,低延迟 | 受限于算力、平台兼容性 |
服务器集群部署 | 结合容器与负载均衡器部署多个模型 | 高并发、高可用场景 | 可自动扩缩、容错性好 | 依赖 Docker/K8s,配置复杂 |
三、部署方式常用工具和框架
场景 | 工具/平台示例 | 简述 |
---|---|---|
本地部署 | Flask、Gradio、Streamlit | 简单封装模型为 API 或 Web 界面 |
Web 后端部署 | FastAPI、Flask + Gunicorn | 可高性能提供 REST 接口 |
云服务部署 | HuggingFace Spaces、阿里云 ECS | 快速上线,适合演示和产品原型 |
模型导出与推理加速 | TorchScript、ONNX、TensorRT | 优化模型结构,提高推理速度 |
多模型管理 | MLflow、TorchServe、NVIDIA Triton | 模型托管、版本管理与部署平台 |
四、常见部署架构图示意(文字版)
用户 -> 浏览器 / App|V[ Web 前端 ] ←(Gradio / React + Flask 等)|V[ Web 后端 API ]|V[ 推理服务(模型加载) ]|V[ 模型参数 / 权重文件 ]
五、从训练到部署流程总览
-
训练模型:在本地或服务器完成训练;
-
保存模型:保存为
.pth
或.onnx
文件; -
封装接口:使用 Flask / Gradio / FastAPI 编写服务;
-
构建前端(可选):使用 HTML / React / Gradio 交互;
-
部署上线:本地测试通过后部署到服务器或平台;
-
用户使用:通过网页、App 等方式访问部署的服务。
第四部分:模型导出为部署格式(TorchScript 和 ONNX)
训练好的 PyTorch 模型需要导出成标准格式,才能跨平台、跨框架、高效地部署。TorchScript 和 ONNX 是 PyTorch 中最常用的导出格式。
本部分将详细讲解两者的概念、区别、导出方式及使用场景。
一、为什么要导出模型?
虽然 .pth
格式在 PyTorch 内部很方便使用,但部署时常常需要:
-
加快推理速度
-
在没有 Python 的环境中运行
-
与其他框架(如 TensorFlow、C++、移动端)兼容
-
更稳定、更可控的模型格式
这时就需要导出为中间格式,如 TorchScript 或 ONNX。
TorchScript 模型导出
什么是 TorchScript?
TorchScript 是 PyTorch 的一个中间表示,它允许模型以静态图的形式保存并运行。这使得:
-
可脱离 Python 环境运行
-
可通过
C++ API
部署 -
支持推理优化(如
torch.jit.optimize_for_inference
)
TorchScript 两种转换方式
1. 追踪法(Tracing)
适合无条件分支的模型。
import torch# 假设 model 是你训练好的模型
model.eval()example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)# 保存
traced_model.save("model_traced.pt")
2. 脚本法(Scripting)
适合包含 if/else、循环等逻辑的模型。
scripted_model = torch.jit.script(model)
scripted_model.save("model_scripted.pt")
💡 TorchScript 加载与推理
import torchmodel = torch.jit.load("model_traced.pt")
model.eval()output = model(torch.randn(1, 3, 224, 224))
ONNX 模型导出
什么是 ONNX?
ONNX(Open Neural Network Exchange)是一种通用模型格式,由微软和 Facebook 发起,支持多种深度学习框架,如:
-
PyTorch
-
TensorFlow
-
MXNet
-
OpenCV DNN
-
ONNX Runtime
-
TensorRT
PyTorch 转 ONNX 示例
import torchmodel.eval()
dummy_input = torch.randn(1, 3, 224, 224)torch.onnx.export(model, dummy_input,"model.onnx",input_names=["input"],output_names=["output"],dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},opset_version=11
)
ONNX 模型验证
你可以用 onnx
包验证导出是否成功:
import onnxonnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model) # 抛出异常说明有问题
推理:ONNX Runtime
import onnxruntime as ort
import numpy as npsession = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)output = session.run(None, {input_name: input_data})
TorchScript vs ONNX 对比总结
特性 | TorchScript | ONNX |
---|---|---|
支持框架 | PyTorch 本身 + C++ | 多框架(TensorRT、ONNX RT等) |
性能优化支持 | 是(官方提供优化接口) | 是(ONNX Runtime / TensorRT) |
转换复杂度 | 简单 | 稍复杂,需要注意版本/OP集 |
支持 Python 控制流 | 是 | 否(静态图模型) |
移植性 | 中(依赖 PyTorch 环境) | 强(适合工业部署) |
推荐场景 | 内部 PyTorch 部署 | 跨平台、商业部署 |
第五部分:模型部署方式详解(Gradio、Flask、ONNX Runtime等)
在本部分,我们将从实用角度出发,逐一讲解几种常用部署方式,并配合完整代码模板,帮助你快速上手部署一个推理服务。
方式一:使用 Gradio 快速构建 Web 界面
Gradio 是一个非常流行的 Python 库,用于快速构建交互式 Web 应用,适合演示、测试和初步上线。
1. 安装 Gradio
pip install gradio
2. 代码示例:图像分类模型部署(TorchScript)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image# 加载 TorchScript 模型
model = torch.jit.load("model_traced.pt")
model.eval()# 图像预处理函数
preprocess = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()
])# 推理函数
def predict(img):img = preprocess(img).unsqueeze(0)with torch.no_grad():output = model(img)probs = torch.nn.functional.softmax(output[0], dim=0)return {f"Class {i}": float(p) for i, p in enumerate(probs)}# 创建界面
iface = gr.Interface(fn=predict, inputs="image", outputs="label")
iface.launch()
启动后会自动打开浏览器访问地址,如:http://127.0.0.1:7860
方式二:使用 Flask 构建 RESTful 接口(API)
Flask 是 Python 中常用的 Web 框架,可以把模型封装成一个 HTTP 接口供前端或其他服务调用。
1. 安装 Flask
pip install flask
2. API 接口部署模板(适合 ONNX)
from flask import Flask, request, jsonify
import onnxruntime as ort
import numpy as npapp = Flask(__name__)# 初始化 ONNX 推理器
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name@app.route("/predict", methods=["POST"])
def predict():data = request.json["input"] # 输入是嵌套 listinput_array = np.array(data).astype(np.float32)output = session.run(None, {input_name: input_array})return jsonify({"output": output[0].tolist()})if __name__ == '__main__':app.run(debug=True, port=5000)
前端可以通过 POST
请求向 /predict
发送数据,返回 JSON 格式的模型输出。
方式三:部署到 HuggingFace Spaces(在线部署平台)
HuggingFace 提供免费的部署平台,支持 Gradio/Streamlit 应用的在线托管。
步骤:
-
在 https://huggingface.co/spaces 创建一个新的 Space;
-
选择 Gradio 模板;
-
上传你的代码文件(如
app.py
)和requirements.txt
; -
提交后等待构建,即可访问。
示例
gradio
torch
torchvision
方式四:ONNX Runtime + FastAPI + Docker(工业部署)
适合构建高性能、可扩展的 API 服务。
-
使用 FastAPI 替代 Flask(性能更高);
-
使用 Docker 打包(环境一致性);
-
使用 ONNX Runtime(加速推理);
若你感兴趣,我可以提供该方式的完整项目结构与部署脚本。
常见部署注意事项
问题/注意点 | 说明 |
---|---|
模型文件太大 | 可用 torch.quantization 压缩模型 |
GPU/CPU 版本不一致 | 部署前明确目标环境是否支持 CUDA |
接口响应慢 | 用 FastAPI + Uvicorn 替代 Flask |
高并发请求处理困难 | 使用 Gunicorn 或 Docker + Kubernetes |
数据预处理慢 | 把预处理逻辑也放在服务端完成 |
服务崩溃/异常退出 | 加入异常处理与日志记录系统 |
第六部分:高级部署与优化技巧(模型压缩、推理加速、Docker 打包、前端集成)
当你完成了模型部署的基本流程,进一步优化部署效果(速度、稳定性、易用性)就很关键了。下面我们从多个方面介绍进阶技巧。
一、模型压缩与推理加速
部署模型时,常常遇到模型太大、推理太慢、占用资源高等问题。可以通过以下几种方式进行模型压缩和推理加速。
1. 模型量化(Quantization)
将浮点数权重压缩成更小的数据类型(如 float16
或 int8
),大幅降低模型大小和推理耗时。
静态量化(Post-training)示例:
import torch.quantizationmodel_fp32 = ... # 已训练模型
model_fp32.eval()# 准备量化配置
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model_fp32, inplace=True)# 运行一次推理(用于收集统计信息)
_ = model_fp32(torch.randn(1, 3, 224, 224))# 转换为量化模型
quantized_model = torch.quantization.convert(model_fp32, inplace=False)# 保存
torch.jit.script(quantized_model).save("model_quant.pt")
注意:部分模块(如
BatchNorm
)不支持直接量化,需使用QuantStub
包装。
2. 使用 Torch-TensorRT(NVIDIA GPU 加速)
Torch-TensorRT 是 NVIDIA 供的一个库,可将 TorchScript 模型转换为 TensorRT 引擎
pip install torch-tensorrt -U
简单使用:
import torch_tensorrttrt_model = torch_tensorrt.compile(model, inputs=[torch.randn(1, 3, 224, 224).to("cuda")], enabled_precisions={torch.float16})
✅ 二、Docker 化部署(推荐生产环境使用)
Docker 可以把你的服务打包成镜像,确保环境一致性、可移植性。
1. 创建项目目录结构
deploy_app/
├── app.py # Flask / Gradio 应用
├── model.onnx # 导出的模型
├── requirements.txt # 所需 Python 包
└── Dockerfile # Docker 构建脚本
2. Dockerfile 示例
FROM python:3.10WORKDIR /appCOPY requirements.txt .
RUN pip install -r requirements.txtCOPY . .CMD ["python", "app.py"]
3. 构建并运行容器
docker build -t my_model_app .
docker run -p 5000:5000 my_model_app
若部署到云端(如阿里云、AWS),推荐结合 Nginx 反向代理与容器编排(如 docker-compose 或 Kubernetes)。
三、前端集成与美化建议
技术 | 优点 | 示例用途 |
---|---|---|
Gradio | 快速搭建交互界面 | 原型演示、测试用 |
Streamlit | 数据可视化友好 | 图像/表格/图表展示等 |
HTML + JS | 适合自定义界面、美化展示 | 嵌入 Web 系统、企业平台 |
React/Vue | 高度定制、适合商用产品 | 构建完整 Web 应用 |
四、完整部署案例:PyTorch → ONNX → Gradio → Docker → HuggingFace Spaces
总结与建议
部分 | 内容概览 |
---|---|
第一部分 | 模型保存格式(权重、结构、完整模型) |
第二部分 | 加载与恢复模型的多种方式 |
第三部分 | 部署的基本概念与分类 |
第四部分 | 模型导出为 TorchScript / ONNX |
第五部分 | 使用 Gradio / Flask / ONNX Runtime 部署 |
第六部分 | 模型压缩、推理加速、Docker 化、高级部署建议 |