当前位置: 首页 > news >正文

使用 Flask 实现本机 PyTorch 模型部署:从服务端搭建到客户端调用

目录

前言

一、部署前准备

1.1 环境要求

1.2 必备文件

二、服务端搭建:让模型 “听候指令”

2.1 服务端完整代码(server.py)

2.2 服务端启动与验证

三、客户端开发:向服务端 “发送请求”

3.1 客户端完整代码(client.py)

3.2 客户端运行与结果示例

四、常见问题排查

4.1 服务端启动失败

4.2 客户端连接失败

4.3 模型推理报错

五、扩展与优化建议

总结


前言

在机器学习项目中,训练好的模型只有部署到实际环境中才能发挥价值。对于本机测试或小规模应用场景,Flask 框架是实现模型部署的轻量优选 —— 它能快速搭建 HTTP 服务,让模型以接口形式接收请求、返回预测结果,无需复杂的服务器配置。

本文将以 ResNet18 图像分类模型为例,完整讲解如何用 Flask 实现 “本机模型部署”:从服务端代码编写(模型加载、接口定义),到客户端代码开发(图像上传、结果解析),再到常见问题排查,确保新手也能一步到位跑通流程。

一、部署前准备

在开始编写代码前,需先确认环境和依赖是否齐全,避免后续因版本或包缺失导致报错。

1.1 环境要求

  • Python 版本:3.7~3.9(PyTorch 对高版本 Python 兼容性可能不稳定)
  • 核心依赖包:
# 安装Flask(Web服务框架)
pip install flask
# 安装PyTorch+TorchVision(模型加载与图像预处理)
pip install torch torchvision
# 安装PIL(图像读取处理)和requests(客户端请求)
pip install pillow requests

1.2 必备文件

  • 训练好的模型权重文件:本文使用best.pth(ResNet18 微调后权重,需确保与代码中类别数匹配)
  • 测试图像:准备 1~2 张用于验证的图像(如 JPG/PNG 格式)
  • 代码结构:建议按如下目录组织,避免路径混乱

二、服务端搭建:让模型 “听候指令”


服务端的核心作用是:加载预训练模型、定义预测接口、监听本机请求。当客户端发送图像请求时,服务端会完成图像预处理、模型推理,并返回 JSON 格式的预测结果。

2.1 服务端完整代码(server.py)

import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models# 1. 初始化Flask应用
app = flask.Flask(__name__)  # __name__定位应用根路径,用于查找静态资源
model = None  # 全局变量存储模型,避免重复加载
use_gpu = False  # 本机部署默认用CPU(若有GPU可设为True,需确保PyTorch支持CUDA)# 2. 加载预训练模型
def load_model():"""加载ResNet18模型,替换全连接层适配自定义分类任务"""global model# 加载ResNet18基础网络(pretrained=False表示不加载默认预训练权重,用自己的best.pth)model = models.resnet18(pretrained=False)# 获取全连接层输入特征数,替换为自定义类别数(本文以102类为例)num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 输出维度=类别数# 加载训练好的权重文件(需确保best.pth路径正确)checkpoint = torch.load('best.pth', map_location=torch.device('cpu'))  # 强制CPU加载,避免GPU报错model.load_state_dict(checkpoint['state_dict'])  # 加载权重参数# 设为评估模式(禁用Dropout、BatchNorm等训练特有的层)model.eval()# 若启用GPU且设备支持,将模型移至CUDAif use_gpu and torch.cuda.is_available():model = model.cuda()print("模型加载完成,等待请求...")# 3. 图像预处理函数(需与训练时保持一致)
def prepare_image(image, target_size=(224, 224)):"""将客户端传入的图像转为模型可接受的Tensor格式"""# 统一图像为RGB格式(避免灰度图/透明图报错)if image.mode != 'RGB':image = image.convert('RGB')# 预处理 pipeline:Resize→ToTensor→Normalize(ResNet默认预处理参数)preprocess = transforms.Compose([transforms.Resize(target_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])image_tensor = preprocess(image)# 增加batch维度(模型要求输入为[batch_size, C, H, W],单张图batch_size=1)image_tensor = image_tensor.unsqueeze(0)  # 等价于image_tensor[None]# 若启用GPU,将Tensor移至CUDAif use_gpu and torch.cuda.is_available():image_tensor = image_tensor.cuda()return image_tensor# 4. 定义预测接口(POST方法)
@app.route("/predict", methods=["POST"])
def predict():"""接收客户端POST请求:- 请求体包含image字段(二进制图像)- 返回JSON格式结果:success(请求状态)、predictions(Top3预测结果)"""# 初始化返回结果字典result = {"success": False, "predictions": []}# 检查请求方法是否为POST,且包含image文件if flask.request.method == "POST" and flask.request.files.get("image"):# 步骤1:读取客户端传入的二进制图像image_bytes = flask.request.files["image"].read()# 将二进制数据转为PIL图像对象image = Image.open(io.BytesIO(image_bytes))# 步骤2:图像预处理image_tensor = prepare_image(image, target_size=(224, 224))# 步骤3:模型推理(禁用梯度计算,加速推理)with torch.no_grad():# 计算各类别概率(softmax归一化)preds = F.softmax(model(image_tensor), dim=1)# 获取概率Top3的类别和概率值(cpu()转为CPU张量,避免GPU与CPU数据冲突)top3_probs, top3_labels = torch.topk(preds.cpu(), k=3, dim=1)# 步骤4:处理结果(转为numpy数组→构造JSON格式)top3_probs = top3_probs.numpy()[0]  # 取第1个batch(仅1张图)top3_labels = top3_labels.numpy()[0]# 遍历Top3结果,添加到返回字典for prob, label in zip(top3_probs, top3_labels):result["predictions"].append({"label": str(label),  # 类别标签(若有类别名映射,可此处替换为中文)"probability": round(float(prob), 4)  # 概率值(保留4位小数)})# 标记请求成功result["success"] = True# 以JSON格式返回结果(Flask自动设置Content-Type为application/json)return flask.jsonify(result)# 5. 启动服务
if __name__ == "__main__":# 先加载模型(确保模型加载成功后再启动服务)try:load_model()except Exception as e:print(f"模型加载失败:{str(e)}")exit(1)  # 模型加载失败则退出程序# 启动Flask服务(本机部署关键参数)# host='127.0.0.1':仅本机可访问(推荐本机测试用)# port=5012:端口号(避免与其他服务冲突,如8080、5000)app.run(host='127.0.0.1', port=5012, debug=False)

2.2 服务端启动与验证


1. 运行server.py,若控制台输出以下内容,说明服务启动成功:

2.初步验证:打开浏览器访问http://127.0.0.1:5012/predict,若返回{"success":false,"predictions":[]},证明接口正常监听(无图像请求时返回默认状态)。

三、客户端开发:向服务端 “发送请求”


客户端的作用是:读取本地图像、以 POST 方式向服务端接口发送请求、解析返回的 JSON 结果并打印。

3.1 客户端完整代码(client.py)

import requests# 1. 配置服务端地址(需与服务端host和port一致)
# 本机部署用http://127.0.0.1:5012/predict,跨设备需替换为服务端局域网IP
FLASK_URL = "http://127.0.0.1:5012/predict"def predict_image(image_path):"""向服务端发送图像预测请求:param image_path: 本地图像路径(如"./test_img/image_06975.jpg")return: 打印预测结果"""try:# 步骤1:以二进制形式读取本地图像(保持图像原始格式)with open(image_path, 'rb') as f:image_bytes = f.read()# 步骤2:构造请求体(key为"image",与服务端flask.request.files.get("image")对应)payload = {"image": image_bytes}# 步骤3:发送POST请求(timeout设为10秒,避免请求超时)response = requests.post(FLASK_URL, files=payload, timeout=10)# 步骤4:解析响应结果(JSON→字典)result = response.json()# 步骤5:判断请求是否成功并打印结果if response.status_code == 200 and result["success"]:print("请求成功(状态码:200)")print("Top3预测结果:")for i, pred in enumerate(result["predictions"], 1):print(f"  {i}. 类别:{pred['label']},概率:{pred['probability']}")else:print(f"请求失败:{result}")except Exception as e:print(f"客户端报错:{str(e)}")# 6. 运行客户端(测试单张图像)
if __name__ == "__main__":# 替换为你的测试图像路径(相对路径/绝对路径均可)test_image_path = "./test_img/image_06975.jpg"predict_image(test_image_path)

3.2 客户端运行与结果示例


1. 确保服务端已启动,运行client.py,若请求成功,控制台输出如下:

状态码说明:
◦ 200:请求成功(服务端正常处理)
◦ 404:接口地址错误(如 URL 写错)
◦ 500:服务端内部错误(如模型加载失败、代码报错)

四、常见问题排查

在部署过程中,新手容易遇到连接失败、模型报错等问题,以下是高频问题的解决方案:

4.1 服务端启动失败

  • 报错 1:FileNotFoundError: [Errno 2] No such file or directory: 'best.pth'原因:模型权重文件路径错误。解决:确认best.pthserver.py在同一目录,或使用绝对路径(如torch.load("C:/model_deployment/best.pth"))。

  • 报错 2:WinError 10049 请求的地址无效原因:app.run(host=...)中 IP 地址错误(非本机 IP)。解决:本机部署用host='127.0.0.1',或通过ipconfig(Windows)/ifconfig(Linux)查看本机正确 IP。

4.2 客户端连接失败

  • 报错 1:requests.exceptions.ConnectionError: HTTPConnectionPool原因:服务端未启动,或 IP / 端口不匹配。解决:先启动服务端,确认客户端FLASK_URL与服务端host:port完全一致(如均为127.0.0.1:5012)。

  • 报错 2:PIL.UnidentifiedImageError: cannot identify image file原因:图像路径错误,或文件不是有效图像(如后缀为.jpg 但实际是.txt)。解决:检查图像路径,用画图工具打开图像确认是否正常。

4.3 模型推理报错

  • 报错:RuntimeError: Expected 4-dimensional input for 4-dimensional weight原因:图像未增加 batch 维度(模型要求输入为[batch_size, C, H, W])。解决:确保prepare_image函数中调用image_tensor.unsqueeze(0)

五、扩展与优化建议

  1. 类别名映射:当前返回的是类别编号(如 35),可添加字典映射为中文(如label_map = {35: "玫瑰", 36: "百合"}),在服务端predict函数中替换"label": str(label)"label": label_map[label]

  2. 多图批量预测:修改客户端代码,支持遍历文件夹下所有图像,批量发送请求。

  3. 生产环境优化:Flask 开发服务器不适合生产环境,可改用Gunicorn(Linux)或Waitress(Windows)作为 WSGI 服务器,搭配Nginx反向代理,提升并发能力。

总结

本文通过 Flask 框架实现了本机 PyTorch 模型的完整部署流程:服务端负责加载模型和提供接口,客户端负责发送请求和解析结果,整个流程轻量、易上手,适合小规模测试或个人项目使用。

核心要点可总结为 3 点:

  1. 服务端与客户端的host:port必须一致;
  2. 图像预处理需与训练时保持一致(如 Resize 尺寸、Normalize 参数);
  3. 先启动服务端,再运行客户端,避免连接失败。

按照本文步骤操作,即可快速将自己的 PyTorch 模型部署到本机,实现 “训练→部署→调用” 的闭环。

http://www.dtcms.com/a/461684.html

相关文章:

  • sql题目练习——多表查询
  • c 做网站加载多个图片网站开发实战第二章
  • 精通C语言(3. 自定义类型:联合体和枚举)
  • 认知事物的三个层次
  • 做数学题目在哪个网站好设计好的装修公司
  • 09.Linux环境变量
  • 11、规划过程组(4):风险
  • HT8698 立体声 D 类音频功率放大器:性能参数介绍
  • 做亚克力在那个网站上好上海建工一建集团有限公司
  • DOM与BOM核心用法解析
  • 如何在网站上做跳转代码最好的科技资讯网站
  • 【项目】自然语言处理——情感分析 <下>
  • 网站页面制作公司外部网站 同意加载
  • 高通平台WiFi学习--IPv6 邻居发现卸载:Wi-Fi 固件助力移动设备功耗优化
  • 网站备案 工信部如何做建材团购网站
  • 知名的咨询行业网站制作茂名网站开发
  • 网络管理部分
  • 小白逆袭----2025了,彻底弄懂react-test单元测试 基础使用(一)
  • 代做标书网站政务网站建设模块
  • (2)100天python从入门到拿捏
  • 我的云函数向 unicloud 数据库存储数据问什么 grade 字段无法存储?
  • 免费下载代码的网站做网站国外网站
  • npm install --legacy-peer-deps:它到底做了什么,什么时候该用?
  • [Tongyi] 工具集成 | run_react_infer
  • 做课题查新网站茶叶网站开发目的和意义
  • 第5章 高效的多线程日志
  • 平安建设 十户长网站地址织梦网站制作教程
  • 无人机图传系统解析:模拟与数字的应用及未来趋势,无人机图传的作用
  • Agentic AI 与 AI 编程入门:让 AI 成为学习与创作的最佳伴侣
  • CF45C Dancing Lessons 题解