使用Flask部署PyTorch模型
1. Flask框架简介与优势
Flask是一个轻量级的Python Web应用框架,以其微核心架构著称。它采用Werkzeug WSGI工具包和Jinja2模板引擎作为基础,同时保持核心简单但可扩展的特点。主要优势包括:
- 灵活性强:Flask不强制使用特定的项目结构或数据库ORM,开发者可以自由选择组件
- 学习曲线平缓:基本API简单直观,新手可以在几小时内掌握核心功能
- 扩展性好:通过Flask扩展生态系统(如Flask-RESTful、Flask-SQLAlchemy)可以轻松添加功能
1.1 与其他框架的对比
特性 | Flask | Django | FastAPI |
---|---|---|---|
架构类型 | 微框架 | 全功能框架 | 异步框架 |
学习难度 | 低 | 中等 | 中等 |
内置功能 | 少 | 多(Admin,ORM等) | 中等 |
性能 | 中等 | 中等 | 高 |
适合场景 | 小型API,快速原型 | 企业级应用 | 高性能API |
1.2 机器学习模型部署中的优势
对于机器学习模型API服务,Flask特别适合因为:
-
快速原型开发示例:
from flask import Flask app = Flask(__name__)@app.route('/predict', methods=['POST']) def predict():return {'result': 'prediction'}if __name__ == '__main__':app.run()
只需5行代码即可创建功能API端点
-
易于集成:Flask的请求-响应循环与PyTorch/TensorFlow的推理流程天然契合
-
灵活性高:可以自由定制:
- 输入数据预处理管道
- 模型并行推理策略
- 复杂的结果后处理逻辑
2. 环境准备与依赖安装
2.1 基础环境配置
推荐使用Python 3.7+环境,通过virtualenv或conda创建隔离环境:
python -m venv flask-env
source flask-env/bin/activate # Linux/Mac
flask-env\Scripts\activate # Windows
2.2 核心依赖安装
pip install Flask torch torchvision Pillow
各包的作用:
Flask
: Web框架核心torch
&torchvision
: PyTorch深度学习框架及视觉工具Pillow
: 图像处理库
2.3 GPU支持验证
对于需要GPU加速的场景:
- 首先确认已安装对应版本的CUDA和cuDNN
- 使用以下命令验证PyTorch GPU支持:
import torch
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
print(f"GPU数量: {torch.cuda.device_count()}")
print(f"当前GPU: {torch.cuda.current_device()}")
print(f"GPU名称: {torch.cuda.get_device_name(0)}")
输出示例:
PyTorch版本: 1.12.1+cu113
CUDA可用: True
GPU数量: 1
当前GPU: 0
GPU名称: NVIDIA GeForce RTX 3090
3. 完整代码解析
3.1 应用初始化
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models# 初始化Flask应用
app = flask.Flask(__name__)# 全局模型变量
model = None
use_gpu = False # 根据实际情况调整
关键点说明:
__name__
参数帮助Flask定位资源文件路径- 模型设为全局变量避免重复加载
use_gpu
标志控制是否使用GPU加速
3.2 模型加载函数
def load_model():"""加载预训练模型"""global model# 1. 初始化模型结构model = models.resnet18(pretrained=False)# 2. 修改最后一层适配自定义任务num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 102), # 假设是102分类任务nn.LogSoftmax(dim=1))# 3. 加载训练好的权重checkpoint = torch.load('best.pth', map_location='cpu')model.load_state_dict(checkpoint['state_dict'])# 4. 设置为评估模式model.eval()# 5. 可选GPU加速if use_gpu and torch.cuda.is_available():model.cuda()
3.3 图像预处理
def prepare_image(image, target_size):"""标准化输入图像"""# 统一格式if image.mode != 'RGB':image = image.convert('RGB')# 预处理管道transform = transforms.Compose([transforms.Resize(target_size),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])# 应用转换并添加batch维度image = transform(image).unsqueeze(0)if use_gpu and torch.cuda.is_available():image = image.cuda()return image
3.4 API端点设计
@app.route("/predict", methods=["POST"])
def predict():# 初始化响应数据data = {"success": False, "error": None}try:# 验证请求方法if flask.request.method != 'POST':raise ValueError("Only POST method is supported")# 检查文件上传if not flask.request.files.get("image"):raise ValueError("No image file provided")# 读取并预处理图像image = flask.request.files["image"].read()image = Image.open(io.BytesIO(image))image = prepare_image(image, (224, 224))# 模型推理with torch.no_grad():outputs = model(image)probs = F.softmax(outputs, dim=1)top_probs, top_labels = torch.topk(probs, k=3)# 格式化结果data['predictions'] = [{"label": str(label.item()),"probability": round(prob.item(), 4)}for prob, label in zip(top_probs[0], top_labels[0])]data["success"] = Trueexcept Exception as e:data["error"] = str(e)return flask.jsonify(data)
3.5 服务启动
if __name__ == '__main__':print("Loading PyTorch model and starting Flask server...")load_model()app.run(host='0.0.0.0', port=5012, threaded=True)
启动参数说明:
host='0.0.0.0'
允许外部访问threaded=True
启用多线程处理请求- 生产环境应设置
debug=False
4. 关键组件详解
4.1 模型加载优化
生产环境建议的改进:
-
模型缓存:使用
lru_cache
装饰器缓存模型from functools import lru_cache@lru_cache(maxsize=1) def load_model():...
-
多模型支持:通过参数动态加载不同模型
def load_model(model_name='resnet18'):if model_name == 'resnet18':model = models.resnet18()elif model_name == 'efficientnet':model = models.efficientnet_b0()...
4.2 高级预处理流程
支持多种输入类型的预处理:
def prepare_input(data):if isinstance(data, str): # 文件路径image = Image.open(data)elif isinstance(data, bytes): # 二进制数据image = Image.open(io.BytesIO(data))elif hasattr(data, 'read'): # 文件对象image = Image.open(data)else:raise ValueError("Unsupported input type")# 后续预处理逻辑...
4.3 增强型API设计
RESTful API最佳实践:
-
版本控制:
@app.route("/api/v1/predict", methods=["POST"])
-
请求验证:
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}def allowed_file(filename):return '.' in filename and \filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
-
速率限制:使用Flask-Limiter扩展
from flask_limiter import Limiter from flask_limiter.util import get_remote_addresslimiter = Limiter(app,key_func=get_remote_address,default_limits=["200 per day", "50 per hour"] )@app.route("/predict") @limiter.limit("10/minute") def predict():...
5. 服务测试方案
5.1 Python客户端测试
增强版测试脚本:
import requests
from PIL import Image
import matplotlib.pyplot as pltdef test_api(image_path, url='http://localhost:5012/predict'):# 显示测试图像img = Image.open(image_path)plt.imshow(img)plt.axis('off')plt.show()# 发送请求files = {'image': open(image_path, 'rb')}response = requests.post(url, files=files)# 解析结果if response.status_code == 200:results = response.json()if results['success']:print("预测结果:")for i, pred in enumerate(results['predictions']):print(f"{i+1}. 类别: {pred['label']}, 概率: {pred['probability']:.4f}")else:print("请求失败:", results.get('error', 'Unknown error'))else:print("HTTP错误:", response.status_code)
5.2 使用Postman测试
- 创建新请求,方法设置为POST
- 在Body中选择form-data
- 添加key为"image"的文件字段
- 选择测试图片并发送
5.3 性能测试
使用Locust进行负载测试:
-
创建
locustfile.py
:from locust import HttpUser, task, betweenclass ModelUser(HttpUser):wait_time = between(1, 3)@taskdef predict(self):with open('test_image.jpg', 'rb') as f:self.client.post("/predict", files={"image": f})
-
启动测试:
locust -f locustfile.py
6. 生产环境部署
6.1 WSGI服务器配置
Gunicorn推荐配置:
gunicorn --bind 0.0.0.0:5012 --workers 4 --threads 2 --timeout 120 app:app
参数说明:
--workers
: CPU核心数的2-4倍--threads
: 每个worker的线程数--timeout
: 请求超时时间(秒)
6.2 Nginx反向代理
示例Nginx配置:
server {listen 80;server_name yourdomain.com;location / {proxy_pass http://localhost:5012;proxy_set_header Host $host;proxy_set_header X-Real-IP $remote_addr;proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;# 文件上传大小限制client_max_body_size 10M;}
}
6.3 Docker容器化
完整Dockerfile示例:
FROM python:3.9-slimWORKDIR /app# 安装系统依赖
RUN apt-get update && apt-get install -y \libgl1-mesa-glx \&& rm -rf /var/lib/apt/lists/*# 安装Python依赖
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt# 复制应用代码
COPY . .# 环境变量
ENV FLASK_ENV=production
ENV PORT=5012# 暴露端口
EXPOSE $PORT# 启动命令
CMD ["gunicorn", "--bind", "0.0.0.0:5012", "--workers", "4", "--threads", "2", "app:app"]
构建并运行:
docker build -t model-api .
docker run -d -p 5012:5012 --name my-model-api model-api
6.4 Kubernetes部署
基本部署示例:
apiVersion: apps/v1
kind: Deployment
metadata:name: model-api
spec:replicas: 3selector:matchLabels:app: model-apitemplate:metadata:labels:app: model-apispec:containers:- name: model-apiimage: your-registry/model-api:latestports:- containerPort: 5012resources:limits:cpu: "1"memory: "1Gi"requests:cpu: "500m"memory: "512Mi"
---
apiVersion: v1
kind: Service
metadata:name: model-api-service
spec:selector:app: model-apiports:- protocol: TCPport: 80targetPort: 5012type: LoadBalancer
7. 高级主题
7.1 模型热更新
实现不重启服务的模型更新:
from werkzeug.utils import secure_filename
import osUPLOAD_FOLDER = 'models'
ALLOWED_EXTENSIONS = {'pth'}@app.route('/update_model', methods=['POST'])
def update_model():if 'file' not in request.files:return jsonify({'success': False, 'error': 'No file part'})file = request.files['file']if file.filename == '':return jsonify({'success': False, 'error': 'No selected file'})if file and allowed_file(file.filename):filename = secure_filename(file.filename)filepath = os.path.join(UPLOAD_FOLDER, filename)file.save(filepath)# 加载新模型try:load_new_model(filepath)return jsonify({'success': True})except Exception as e:return jsonify({'success': False, 'error': str(e)})
7.2 批处理支持
@app.route('/batch_predict', methods=['POST'])
def batch_predict():if 'images[]' not in request.files:return jsonify({'success': False, 'error': 'No images provided'})files = request.files.getlist('images[]')results = []for file in files:try:image = Image.open(io.BytesIO(file.read()))image = prepare_image(image, (224, 224))with torch.no_grad():output = model(image)prob = F.softmax(output, dim=1)_, pred = torch.max(prob, 1)results.append({'filename': file.filename,'prediction': int(pred.item()),'probability': float(prob[0][pred].item())})except Exception as e:results.append({'filename': file.filename,'error': str(e)})return jsonify({'results': results})
7.3 监控与日志
集成Prometheus监控:
from prometheus_client import start_http_server, Counter, Histogram# 定义指标
REQUEST_COUNT = Counter('request_count', 'API Request Count',['method', 'endpoint', 'http_status']
)
REQUEST_LATENCY = Histogram('request_latency_seconds', 'Request latency',['endpoint']
)@app.before_request
def before_request():request.start_time = time.time()@app.after_request
def after_request(response):latency = time.time() - request.start_timeREQUEST_LATENCY.labels(request.path).observe(latency)REQUEST_COUNT.labels(request.method, request.path, response.status_code).inc()return response# 在应用启动时启动Prometheus客户端
start_http_server(8000)