当前位置: 首页 > news >正文

【django】模型部署过程

模型部署

  • 示例:保存 Scikit-learn 模型
  • myapp/views.py
  • 全局加载模型
  • tasks.py(Celery任务)
  • views.py 修改为异步调用
  • views.py

  1. 准备工作
    模型保存格式
    确保你的模型已保存为可加载的格式:
    ● TensorFlow/Keras:.h5 或 SavedModel 格式
    ● PyTorch:.pt 或 .pth 文件
    ● Scikit-learn:使用 joblib 或 pickle 保存(推荐 joblib)

示例:保存 Scikit-learn 模型

from sklearn.ensemble import RandomForestClassifier
import joblib

model = RandomForestClassifier()
model.fit(X_train, y_train)
joblib.dump(model, ‘my_model.joblib’)

  1. 项目结构规划
    建议的 Django 项目结构:
    myproject/
    ├── myapp/
    │ ├── models/ # 存放模型文件
    │ │ └── my_model.joblib
    │ ├── views.py # 处理请求和模型调用
    │ ├── urls.py # 定义API路由
    │ └── …
    ├── myproject/
    │ ├── settings.py
    │ └── urls.py # 主路由
    └── manage.py

  2. 模型加载与初始化
    在 Django 中全局加载模型
    在 myapp/apps.py 或 views.py 中初始化模型,避免每次请求重复加载。

myapp/views.py

from django.http import JsonResponse
from django.views.decorators.csrf import csrf_exempt
import joblib
import os

全局加载模型

model_path = os.path.join(os.path.dirname(file), ‘models/my_model.joblib’)
model = joblib.load(model_path)

@csrf_exempt # 若需跨域访问可临时禁用CSRF(生产环境需谨慎)
def predict(request):
if request.method == ‘POST’:
try:
# 获取输入数据(假设发送JSON)
data = json.loads(request.body)
features = data[‘features’]

        # 调用模型预测
        prediction = model.predict([features])[0]
        
        return JsonResponse({'prediction': prediction})
    except Exception as e:
        return JsonResponse({'error': str(e)}, status=400)
return JsonResponse({'error': '仅支持POST请求'}, status=405)
  1. 配置路由
    在 myapp/urls.py 中添加API路由
    from django.urls import path
    from . import views

urlpatterns = [
path(‘predict/’, views.predict, name=‘predict’),
]
在项目主路由 myproject/urls.py 中引入
from django.urls import include, path

urlpatterns = [
path(‘api/’, include(‘myapp.urls’)),
]

  1. 测试API
    使用 curl 或 Postman 发送POST请求测试:
    curl -X POST http://localhost:8000/api/predict/
    -H “Content-Type: application/json”
    -d ‘{“features”: [1.2, 3.4, 5.6]}’
    预期响应:
    {“prediction”: 0}

  2. 高级优化
    异步处理(Celery + Redis)
    如果模型推理耗时较长,可用 Celery 异步任务避免阻塞请求:

tasks.py(Celery任务)

from celery import shared_task
from myapp.views import model # 复用全局加载的模型

@shared_task
def async_predict(features):
return model.predict([features])[0]

views.py 修改为异步调用

@csrf_exempt
def predict(request):
if request.method == ‘POST’:
data = json.loads(request.body)
task = async_predict.delay(data[‘features’])
return JsonResponse({‘task_id’: task.id}, status=202)
缓存模型输出
使用 Django 缓存减少重复计算:
from django.core.cache import cache

def predict(request):
data = json.loads(request.body)
features = tuple(data[‘features’]) # 转换为可哈希类型

# 检查缓存
if cache.has_key(features):
    return JsonResponse({'prediction': cache.get(features)})

# 计算并缓存
prediction = model.predict([features])[0]
cache.set(features, prediction, timeout=3600)  # 缓存1小时
return JsonResponse({'prediction': prediction})
  1. 关键注意事项
  2. 线程安全:
    from threading import Lock
    model_lock = Lock()

def predict(request):
with model_lock:
prediction = model.predict(…)
○ 如果模型非线程安全(如某些 TensorFlow 旧版本),需加锁或使用单例模式。
2. 性能优化:
○ 使用 gunicorn 或 uvicorn 替代 Django 自带的开发服务器。
○ 启用 GPU 加速(如 TensorFlow/PyTorch 的 GPU 版本)。
3. 输入验证:
def validate_features(features):
if len(features) != 3:
raise ValueError(“必须提供3个特征”)
if not all(isinstance(x, (int, float)) for x in features):
raise ValueError(“特征必须为数字”)
○ 严格校验前端传入的数据格式和范围,防止恶意输入。
4. 依赖管理:
tensorflow2.12.0
scikit-learn
1.2.2
joblib==1.2.0
○ 在 requirements.txt 中明确指定模型库版本:

完整示例:图像分类模型集成
假设有一个图像分类模型(如 ResNet),可按以下方式处理文件上传:

views.py

from django.core.files.storage import default_storage
from tensorflow.keras.preprocessing import image
import numpy as np

def predict_image(request):
if request.method == ‘POST’:
file = request.FILES[‘image’]
file_path = default_storage.save(‘tmp/’ + file.name, file)

    # 预处理图像
    img = image.load_img(file_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0) / 255.0
    
    # 预测
    prediction = model.predict(img_array)
    class_idx = np.argmax(prediction)
    
    return JsonResponse({'class': class_idx})

通过以上步骤,你可以将训练好的模型无缝集成到 Django 中,并通过 RESTful API 提供服务。根据实际需求调整代码结构和优化策略。

相关文章:

  • SpringMVC学习(入门案例思路及实现、Web容器初始化与SpringMVC配置类)(2)
  • GIT工具学习【2】:分支
  • Java 入门 (超级详细)
  • Unity 适用Canvas 为任一渲染模式的UI 拖拽
  • 2.css简介
  • pytorch 模型测试
  • 刷题记录10
  • 下载谷歌浏览器(Chrome)
  • HttpServletRequest 和 HttpServletResponse 不同JDK版本的引入
  • 23种设计模式之单例模式(Singleton Pattern)【设计模式】
  • 【三.大模型实战应用篇】【4.智能学员辅导系统:docx转PDF的自动化流程】
  • 基于springboot的丢失儿童的基因比对系统(源码+lw+部署文档+讲解),源码可白嫖!
  • SFP28(25 Gigabit Small Form-factor Pluggable)详解
  • STM32-FOC-SDK包含以下关键知识点
  • 算法基础 -- 字符串哈希的基本概念和数学原理分析
  • Linux常用指令学习笔记
  • 以1.7K深圳小区房价为例,浙大GIS实验室使用注意力机制挖掘地理情景特征,提升空间非平稳回归精度
  • 蓝桥与力扣刷题(蓝桥 k倍区间)
  • JavaScript 系列之:事件
  • 使用Docker搭建Oracle Database 23ai Free并扩展MAX_STRING_SIZE的完整指南
  • 国家统计局公布2024年城镇单位就业人员年平均工资情况
  • 韶关一企业将消防安装工程肢解发包,广东住建厅:罚款逾五万
  • 巴菲特最新调仓:一季度大幅抛售银行股,再现保密仓位
  • 淄博一酒店房间内被曝发现摄像头,当地警方已立案调查
  • 由我国牵头制定,适老化数字经济国际标准发布
  • 《歌手2025》公布首发阵容,第一期就要淘汰一人