基于 Flask的深度学习模型部署服务端详解
基于 Flask 的深度学习模型部署服务端详解
在深度学习领域,训练出一个高精度的模型只是第一步,将其部署到生产环境中,为实际业务提供服务才是最终目标。本文将详细解析一个基于 Flask 和 PyTorch 的深度学习模型部署服务端代码,帮助你理解如何将训练好的模型以 API 形式提供给客户端使用。
一、整体概述
这段代码的主要功能是搭建一个基于 Flask 的 Web 服务,用于接收客户端发送的图像数据,使用预训练的 PyTorch 模型对图像进行分类预测,并将预测结果以 JSON 格式返回给客户端。
二、代码详细解析
1. 导入必要的库
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
io
:用于处理二进制数据,这里主要用于将客户端发送的图像二进制数据转换为图像对象。flask
:一个轻量级的 Web 框架,用于搭建 Web 服务。torch
和torch.nn.functional
:PyTorch 的核心库,用于深度学习模型的构建和计算。PIL.Image
:Python Imaging Library(PIL)的一部分,用于处理图像文件。torch.nn
:用于定义神经网络的层和模块。torchvision.transforms
和torchvision.models
:transforms
用于图像预处理,models
提供了预训练的深度学习模型。
2. 初始化 Flask 应用和模型相关变量
app = flask.Flask(__name__)
model = None
use_gpu = False
app = flask.Flask(__name__)
:创建一个新的 Flask 应用实例,__name__
参数用于确定应用的根路径。model
:用于存储加载的深度学习模型,初始化为None
。use_gpu
:一个布尔变量,用于控制是否使用 GPU 进行模型推理,初始化为False
。
3. 加载模型
def load_model():global modelmodel = models.resnet18()num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 102))checkpoint = torch.load('best.pth')model.load_state_dict(checkpoint['state_dict'])model.eval()if use_gpu:model.cuda()
global model
:声明model
为全局变量,以便在函数内部修改它。model = models.resnet18()
:加载预训练的 ResNet-18 模型。num_ftrs = model.fc.in_features
:获取 ResNet-18 模型最后一层全连接层的输入特征数。model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
:修改最后一层全连接层,将输出维度改为 102,这里的 102 可以根据实际任务的类别数进行调整。checkpoint = torch.load('best.pth')
:从文件best.pth
中加载训练好的模型参数。model.load_state_dict(checkpoint['state_dict'])
:将加载的参数应用到模型中。model.eval()
:将模型设置为评估模式,关闭一些在训练时使用的特殊层(如 Dropout)。if use_gpu: model.cuda()
:如果use_gpu
为True
,将模型移动到 GPU 上。
4. 图像预处理
def prepare_image(image, target_size):if image.mode != 'RGB':image = image.convert('RGB')image = transforms.Resize(target_size)(image)image = transforms.ToTensor()(image)image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)image = image[None]if use_gpu:image = image.cuda()return torch.tensor(image)
if image.mode != 'RGB': image = image.convert('RGB')
:确保输入图像为 RGB 格式。image = transforms.Resize(target_size)(image)
:将图像调整为指定的大小。image = transforms.ToTensor()(image)
:将图像转换为 PyTorch 张量。image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
:对图像进行归一化处理,使用的均值和标准差是在 ImageNet 数据集上计算得到的。image = image[None]
:增加一个维度,将图像转换为批量输入的格式。if use_gpu: image = image.cuda()
:如果use_gpu
为True
,将图像移动到 GPU 上。
5. 定义预测接口
@app.route('/predict', methods=['POST'])
def predict():data = {'success': False}if flask.request.method == 'POST':if flask.request.files.get('image'):image = flask.request.files['image'].read()image = Image.open(io.BytesIO(image))image = prepare_image(image, target_size=(224, 224))preds = F.softmax(model(image), dim=1)results = torch.topk(preds.cpu().data, k=3, dim=1)results = (results[0].cpu().numpy(), results[1].cpu().numpy())data['prediction'] = list()for prob, label in zip(results[0][0], results[1][0]):r = {'label': str(label), 'probability': float(prob)}data['prediction'].append(r)data['success'] = Truereturn flask.jsonify(data)
@app.route('/predict', methods=['POST'])
:使用 Flask 的装饰器定义一个路由,当客户端向/predict
路径发送 POST 请求时,会调用predict
函数。data = {'success': False}
:初始化一个字典,用于存储预测结果和状态信息,初始状态为success = False
。if flask.request.method == 'POST'
:检查请求方法是否为 POST。if flask.request.files.get('image')
:检查请求中是否包含名为image
的文件。image = flask.request.files['image'].read()
:读取客户端发送的图像文件内容。image = Image.open(io.BytesIO(image))
:将二进制数据转换为图像对象。image = prepare_image(image, target_size=(224, 224))
:对图像进行预处理。preds = F.softmax(model(image), dim=1)
:使用模型进行预测,并通过softmax
函数将输出转换为概率分布。results = torch.topk(preds.cpu().data, k=3, dim=1)
:获取概率最大的前 3 个结果。results = (results[0].cpu().numpy(), results[1].cpu().numpy())
:将结果转换为 NumPy 数组。data['prediction'] = list()
:初始化一个列表,用于存储预测结果。for prob, label in zip(results[0][0], results[1][0])
:遍历前 3 个结果,将标签和概率封装成字典,并添加到data['prediction']
列表中。data['success'] = True
:将状态信息设置为success = True
,表示预测成功。return flask.jsonify(data)
:将结果以 JSON 格式返回给客户端。
6. 启动服务
if __name__ == '__main__':print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started')load_model()app.run(host='192.168.1.20', port=5012)
if __name__ == '__main__'
:确保代码作为主程序运行时才执行以下操作。print('Loading PyTorch model and Flask starting server ...')
和print('Please wait until server has fully started')
:打印启动信息。load_model()
:调用load_model
函数加载模型。app.run(host='192.168.1.20', port=5012)
:启动 Flask 服务,监听192.168.1.20
地址的 5012 端口。运行结果如下
三、总结
通过上述代码,我们成功搭建了一个基于 Flask 和 PyTorch 的深度学习模型部署服务端。客户端可以通过向 /predict
路径发送包含图像文件的 POST 请求,获取图像分类的预测结果。在实际应用中,可以根据需要对代码进行扩展,如增加更多的模型、优化图像预处理流程、添加错误处理机制等。希望本文能帮助你更好地理解深度学习模型的部署过程。