当前位置: 首页 > news >正文

PyTorch 模型部署实战:用 Flask 搭图像分类 API

在 AI 项目开发中,训练好的模型只有部署成可调用的服务,才能真正落地产生价值。本文将手把手教你用 ResNet18 预训练模型,结合 Flask 框架搭建图像分类 API,并编写客户端程序实现图片上传与结果接收,全程代码可直接复用。

一、项目整体架构

整个项目分为 “服务端” 和 “客户端” 两部分,核心是通过 HTTP 请求实现数据交互。

  • 服务端:基于 Flask 搭建 API 接口,加载 ResNet18 图像分类模型,接收客户端上传的图片,完成预测后返回结果。
  • 客户端:读取本地图片,通过 POST 请求将图片发送到服务端,接收并解析预测结果,最终打印展示。

两者的交互流程非常简洁:客户端传图→服务端预测→服务端返结果→客户端显结果。

二、服务端开发:搭建图像分类 API

服务端是整个系统的核心,需要完成模型加载、图片预处理、预测逻辑和 API 接口定义四个关键步骤。

1. 依赖库安装

首先确保安装所需的 Python 库,直接用 pip 安装即可:

pip install flask torch torchvision pillow requests

2. 核心代码解析

服务端代码(命名为image_classification_server.py)分为 5 个模块,每个模块功能清晰,可直接复制使用。

(1)导入依赖库

先引入所有需要的工具包,涵盖 Web 服务、模型框架、图像处理等领域:

import io
import flask
import torch
import torch.nn.functional as F
from torch import nn
from PIL import Image
from torchvision import transforms, models
(2)初始化 Flask 应用与模型变量

创建 Flask 实例,并定义模型和 GPU 使用标志(默认用 CPU,避免环境依赖):

app = flask.Flask(__name__)
model = None  # 全局模型变量,加载后赋值
use_gpu = False  # 可根据硬件情况改为True
(3)模型加载函数

加载预训练的 ResNet18 模型,替换全连接层适配 102 类分类任务(如花卉分类),并加载训练好的权重文件(best.pth):

def load_model():global model# 1. 加载ResNet18基础模型model = models.resnet18()# 2. 替换全连接层:ResNet18默认输出1000类,改为102类num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 102))# 3. 加载训练好的权重(需确保best.pth在当前目录)checkpoint = torch.load('best.pth', map_location=torch.device('cpu'))model.load_state_dict(checkpoint['state_dict'])# 4. 设置为评估模式(禁用 dropout 等训练特有的层)model.eval()# 5. 若使用GPU且设备可用,将模型移到GPUif use_gpu and torch.cuda.is_available():model.cuda()
(4)图像预处理函数

将客户端上传的图片转换成模型要求的输入格式(ResNet 系列默认输入为 224×224,且需用 ImageNet 数据集的均值和标准差归一化):

def prepare_image(image, target_size=(224, 224)):# 定义图像转换流程transform = transforms.Compose([transforms.Resize(target_size),  # 缩放图片transforms.ToTensor(),  # 转为Tensor(维度:C×H×W)transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet均值std=[0.229, 0.224, 0.225]    # ImageNet标准差)])# 应用转换并添加批次维度(模型要求输入为[batch_size, C, H, W])image = transform(image).unsqueeze(0)# 若使用GPU,将图片移到GPUif use_gpu and torch.cuda.is_available():image = image.cuda()return image
(5)API 接口定义

定义两个核心接口:/predict(处理图片预测)和/health(服务健康检查)。

  • /predict 接口:接收 POST 请求,处理图片上传、预测和结果返回:
@app.route('/predict', methods=['POST'])
def predict():# 初始化响应字典(默认失败)data = {"success": False}# 检查请求方法和是否包含图片if flask.request.method == 'POST' and flask.request.files.get('image'):# 1. 读取图片:从请求中获取二进制图片数据,转为PIL图像image_bytes = flask.request.files['image'].read()image = Image.open(io.BytesIO(image_bytes)).convert('RGB')  # 确保为RGB格式# 2. 预处理图片image = prepare_image(image)# 3. 模型预测(禁用梯度计算,提升速度)with torch.no_grad():outputs = model(image)  # 模型输出(logits)probabilities = F.softmax(outputs, dim=1)  # 转为概率top_k_prob, top_k_indices = torch.topk(probabilities, 5)  # 获取前5个预测结果# 4. 处理结果:转为Python原生类型,构造返回格式results = []for i in range(top_k_prob.size(1)):results.append({"label": int(top_k_indices[0][i]),  # 类别编号"probability": float(top_k_prob[0][i])  # 对应概率})# 5. 更新响应数据(标记成功,添加预测结果)data["predictions"] = resultsdata["success"] = True# 返回JSON格式响应return flask.jsonify(data)
  • /health 接口:用于检查服务是否正常运行、模型是否加载成功:
@app.route('/health', methods=['GET'])
def health_check():return flask.jsonify({"status": "healthy", "model_loaded": model is not None})
(6)启动服务

在主函数中加载模型,并启动 Flask 服务(默认端口 5012,允许外部访问):

if __name__ == '__main__':print('Loading model... Please wait...')load_model()  # 启动前加载模型print('Model loaded successfully!')# 启动服务:host设为0.0.0.0可允许同一局域网内其他设备访问app.run(port=5012, host='0.0.0.0')

三、客户端开发:实现图片上传与结果接收

客户端代码(命名为image_classification_client.py)功能简单:读取本地图片,发送到服务端,解析并打印预测结果。

1. 核心代码解析

import requests# 服务端API地址(本地测试用127.0.0.1,局域网访问需替换为服务端IP)
flask_url = 'http://127.0.0.1:5012/predict'def predict_result(image_path):# 1. 读取本地图片(二进制模式)with open(image_path, 'rb') as f:image_data = f.read()# 2. 构造请求参数(key为'image',与服务端接收的字段一致)payload = {'image': image_data}# 3. 发送POST请求,解析JSON响应response = requests.post(flask_url, files=payload).json()# 4. 处理响应结果if response['success']:# 打印前5个预测结果(类别编号+概率)print("预测结果:")for i, result in enumerate(response['predictions']):print(f"{i+1}. 类别编号:{result['label']},概率:{result['probability']:.4f}")else:print("请求失败,请检查服务端或图片路径!")# 主函数:调用预测函数(替换为你的图片路径)
if __name__ == '__main__':predict_result('./flower_data/val_filelist/image_00059.jpg')

四、项目运行步骤

按照以下步骤操作,即可快速跑通整个流程:

  1. 准备模型权重:将训练好的best.pth文件放到服务端代码同一目录(若没有,可先训练一个 102 类分类模型,或修改代码适配你的类别数)。
  2. 启动服务端:运行服务端代码,看到 “Model loaded successfully!” 表示启动成功:

    bash

    python image_classification_server.py
    
  3. 运行客户端:在另一个终端运行客户端代码,确保图片路径正确,即可看到预测结果:

    bash

    python image_classification_client.py
    

五、项目扩展方向

本项目是一个基础的图像分类 API 框架,可根据需求进行扩展:

  • 添加类别名称映射:目前返回的是类别编号,可添加一个字典(如label2name = {0: '玫瑰', 1: '百合'...}),将编号转为具体名称。
  • 增加图片格式校验:在服务端添加对图片格式(如 JPG、PNG)和大小的校验,提升鲁棒性。
  • 部署到云服务器:将服务端部署到阿里云、腾讯云等平台,配置域名和 HTTPS,实现公网访问。
  • 添加请求限流:使用 Flask-Limiter 等库限制接口调用频率,防止恶意请求。

总结

本文通过 ResNet18+Flask 实现了图像分类 API 的快速搭建,从代码解析到运行步骤都做了详细说明,即使是初学者也能轻松上手。这个框架不仅适用于图像分类,稍作修改后还可用于目标检测、图像分割等其他计算机视觉任务,具有很强的通用性

http://www.dtcms.com/a/462638.html

相关文章:

  • 如何进行目的地网站建设东莞厚街创新科技职业学院
  • 网站标题修改重庆网站设计排名
  • 做图表用的网站做网站有什么工具
  • 温州做网站seophp网站开发就业前景
  • 问题记录:一个简单的字符串正则匹配算法引发的 CPU 告警
  • 公共数据资源的“整体授权”是什么涵义?
  • 如何增加网站关键词密度网站建设与维护网课
  • 建立门户网站的程序漳州企业网站建设制作
  • 汕头房产网站建设公司网站界面设计
  • [7-01-02].第05节:环境搭建 - 基础环境
  • BLIP模型
  • 网站建设添加资料搜索引擎优化seo什么意思
  • Playwright与Python:从入门到精通的完整指南
  • maven本地仓库有相应的依赖,依旧会从远程仓库拉取问题的原因及解决
  • 如何修改wordpress站景区旅游网站平台建设方案
  • 网站建设拾金手指下拉十九济南天桥区做网站的
  • 甘肃水利工程建设管理网站东省住房和城乡建设厅网站
  • 10.9 换根dp
  • 上海做网站设计温州专业微网站制作多少钱
  • Trino:一个开源分布式大数据SQL查询引擎
  • 网站建设岗位职责做网站能致富吗
  • 网站优化方案设计wordpress删除用户头像
  • C# 弃元模式:从语法糖到性能利器的深度解析
  • 外国优秀网站欣赏广东茂名网站建设
  • 网站备案的用户名是什么广州比较好的网站建设
  • INT301 Bio-computation 生物计算(神经网络)Pt.1 导论与Hebb学习规则
  • 百度站长平台男女做暖暖的网站大全
  • 乌克兰集团网站建设wordpress 产品目录
  • C#基础16-C#6-C#9新特性
  • 两个RNA-蛋白以及蛋白间相互作用数据库