10 分钟用 FastAPI 将机器学习模型上线为 REST API
如果你喜欢搭建机器学习模型并尝试新东西,那很酷——但说实话,只有当你把它提供给别人使用时,它才真正对他人有用。为此,你需要把它“服务化”——通过一个 Web API 暴露出来,让其他程序(或人类)可以发送数据并获得预测结果。这正是 REST API 的用武之地。
在本文中,你将学到如何用不到 10 分钟的时间,从一个简单的机器学习模型,走到一个可用于生产的 API。我们使用 FastAPI——Python 里速度快、对开发者非常友好的 Web 框架。而且我们不会停留在“能跑就行”的演示,我们还会加入这些能力:
- 校验入参数据
- 记录每一次请求日志
- 使用后台任务避免请求变慢
- 优雅地处理错误
在写代码之前,先快速看看项目结构长什么样:
项目结构
ml-api/
│
├── model/
│ └── train_model.py # Script to train and save the model
│ └── iris_model.pkl # Trained model file
│
├── app/
│ └── main.py # FastAPI app
│ └── schema.py # Input data schema using Pydantic
│
├── requirements.txt # All dependencies
└── README.md # Optional documentation
步骤 1:安装依赖
- 安装 FastAPI、Uvicorn、Scikit-learn、joblib、pydantic
- 保存依赖到 requirements.txt
pip install fastapi uvicorn scikit-learn joblib pydantic
pip freeze > requirements.txt
步骤 2:训练并保存一个简单模型
- 使用 Iris 数据集
- 训练随机森林分类器
- 使用 joblib 保存模型
# model/train_model.py
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import joblib, osX, y = load_iris(return_X_y=True)
clf = RandomForestClassifier()
clf.fit(*train_test_split(X, y, test_size=0.2, random_state=42)[:2])os.makedirs("model", exist_ok=True)
joblib.dump(clf, "model/iris_model.pkl")
print("✅ Model saved to model/iris_model.pkl")
运行脚本以生成模型文件:
python model/train_model.py
步骤 3:定义 API 输入数据结构
- 使用 Pydantic 定义并校验输入
- 限制四个浮点数在 (0, 10) 之间
# app/schema.py
from pydantic import BaseModel, Fieldclass IrisInput(BaseModel):sepal_length: float = Field(..., gt=0, lt=10)sepal_width: float = Field(..., gt=0, lt=10)petal_length: float = Field(..., gt=0, lt=10)petal_width: float = Field(..., gt=0, lt=10)
步骤 4:创建 FastAPI 应用
- 启动时加载模型
- 定义 /predict 接口
- 返回预测类别与概率
- 使用后台任务记录日志
- 捕获异常并返回友好错误
# app/main.py
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from app.schema import IrisInput
import numpy as np, joblib, logging# Load the model
model = joblib.load("model/iris_model.pkl")# Set up logging
logging.basicConfig(filename="api.log", level=logging.INFO,format="%(asctime)s - %(message)s")# Create the FastAPI app
app = FastAPI()@app.post("/predict")
def predict(input_data: IrisInput, background_tasks: BackgroundTasks):try:# Format the input as a NumPy arraydata = np.array([[input_data.sepal_length,input_data.sepal_width,input_data.petal_length,input_data.petal_width]])# Run predictionpred = model.predict(data)[0]proba = model.predict_proba(data)[0]species = ["setosa", "versicolor", "virginica"][pred]# Log in the background so it doesn’t block responsebackground_tasks.add_task(log_request, input_data, species)# Return prediction and probabilitiesreturn {"prediction": species,"class_index": int(pred),"probabilities": {"setosa": float(proba[0]),"versicolor": float(proba[1]),"virginica": float(proba[2])}}except Exception as e:logging.exception("Prediction failed")raise HTTPException(status_code=500, detail="Internal error")# Background logging task
def log_request(data: IrisInput, prediction: str):logging.info(f"Input: {data.dict()} | Prediction: {prediction}")
步骤 5:运行 API 并在线测试
- 启动服务
- 打开交互式文档调试(Swagger UI)
uvicorn app.main:app --reload
在浏览器访问:
http://127.0.0.1:8000/docs
示例请求(Swagger UI 或任何 HTTP 客户端):
{"sepal_length": 6.1,"sepal_width": 2.8,"petal_length": 4.7,"petal_width": 1.2
}
或使用 curl:
curl -X POST "http://127.0.0.1:8000/predict" \-H "Content-Type: application/json" \-d '{"sepal_length": 6.1,"sepal_width": 2.8,"petal_length": 4.7,"petal_width": 1.2}'
示例响应:
{"prediction": "versicolor","class_index": 1,"probabilities": {"setosa": 0.0,"versicolor": 1.0,"virginica": 0.0}
}
可选:部署到线上
- Render.com(零配置)
- Railway.app(CI 友好)
- Heroku(Docker 部署)
进阶生产化建议
- 接口鉴权:API Key / OAuth
- 监控与告警:Prometheus + Grafana
- 异步任务队列:Redis / Celery
总结
- 自动输入校验
- 可读的预测结果与概率
- 每次请求都有日志
- 后台任务不阻塞主流程
- 健壮的异常处理
- 全部不足 100 行代码