FastAPI WebSocket 由浅入深的开发范例
引言
在AI的开发中,WebSocket的开发尤其重要,比如与大模型的对接,一般都是使用WebSocket通讯,达到全双工与实时响应的效果。FastAPI 作为现代化的 Python Web 框架,提供了强大而简洁的 WebSocket 支持。本文将由浅入深,通过几个范例讲解,逐步掌握FastAPI的WebSocket的开发技巧。
1 什么是 WebSocket?
WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议。与传统的 HTTP 请求-响应模式不同,WebSocket 允许服务器和客户端之间建立持久连接,实现实时数据交换。
2 FastAPI的WebSocket基础用法。
from fastapi import FastAPI, WebSocketapp = FastAPI()@app.websocket("/ws")
async def simple_websocket(websocket: WebSocket):# 接受连接await websocket.accept()try:while True:# 接收消息data = await websocket.receive_text()# 发送回应await websocket.send_text(f"Echo: {data}")except Exception as e:print(f"Connection closed: {e}")
这个基础的范例展示了 WebSocket 的基本工作流程:建立连接、持续通信、处理断开。
3 多消息处理用法
from fastapi import WebSocket, WebSocketDisconnect@app.websocket("/ws/advanced")
async def advanced_websocket(websocket: WebSocket):await websocket.accept()try:while True:# 接收多种类型的消息message = await websocket.receive()if "text" in message:await websocket.send_text(f"Text received: {message['text']}")elif "bytes" in message:await websocket.send_bytes(message["bytes"])elif "json" in message:await websocket.send_json({"echo": message["json"]})except WebSocketDisconnect:print("Client disconnected gracefully")
4 使用连接管理器实现多连接管理
4.1 连接管理器的实现
在实际开发中,需要管理多个客户端的连接,因此有必要通过一个连接管理器来管理这些连接。
from typing import Dict, List
import jsonclass ConnectionManager:def __init__(self):self.active_connections: Dict[str, WebSocket] = {}self.connection_groups: Dict[str, List[str]] = {}async def connect(self, websocket: WebSocket, client_id: str):await websocket.accept()self.active_connections[client_id] = websocketdef disconnect(self, client_id: str):self.active_connections.pop(client_id, None)# 从所有群组中移除for group in self.connection_groups.values():if client_id in group:group.remove(client_id)async def send_personal_message(self, message: dict, client_id: str):if client_id in self.active_connections:await self.active_connections[client_id].send_json(message)async def broadcast(self, message: dict):disconnected_clients = []for client_id, websocket in self.active_connections.items():try:await websocket.send_json(message)except Exception:disconnected_clients.append(client_id)# 清理断开的连接for client_id in disconnected_clients:self.disconnect(client_id)async def add_to_group(self, group_name: str, client_id: str):if group_name not in self.connection_groups:self.connection_groups[group_name] = []if client_id not in self.connection_groups[group_name]:self.connection_groups[group_name].append(client_id)async def send_to_group(self, group_name: str, message: dict):if group_name in self.connection_groups:disconnected_clients = []for client_id in self.connection_groups[group_name]:if client_id in self.active_connections:try:await self.active_connections[client_id].send_json(message)except Exception:disconnected_clients.append(client_id)# 清理断开的连接for client_id in disconnected_clients:self.disconnect(client_id)manager = ConnectionManager()
4.2 连接管理器的使用
@app.websocket("/ws/chat/{client_id}")
async def chat_websocket(websocket: WebSocket, client_id: str):await manager.connect(websocket, client_id)try:# 通知所有用户新用户加入await manager.broadcast({"type": "user_joined","client_id": client_id,"message": f"User {client_id} joined the chat"})while True:data = await websocket.receive_text()message_data = json.loads(data)if message_data.get("type") == "join_room":# 加入聊天室room_name = message_data["room_name"]await manager.add_to_group(room_name, client_id)await manager.send_to_group(room_name, {"type": "room_join","client_id": client_id,"room_name": room_name})else:# 广播消息await manager.broadcast({"type": "message","client_id": client_id,"content": message_data.get("content", ""),"timestamp": message_data.get("timestamp")})except WebSocketDisconnect:manager.disconnect(client_id)await manager.broadcast({"type": "user_left","client_id": client_id,"message": f"User {client_id} left the chat"})
5 多线线程安全与异步处理
在真实的生成环境中,仅仅处理多连接还不行,还需要考虑多线程与异步的处理。以下范例的连接管理器增加了安全锁与发送队列,其他线程发送消息先放到发送队列即可,由单独的消息分发任务进行发送。
5.1 线程安全与异步处理的连接管理器
import asyncio
import threading
from typing import Dict
from queue import Queue, Emptyclass ThreadSafeWebSocketManager:def __init__(self):self.connections: Dict[str, WebSocket] = {}self.message_queues: Dict[str, Queue] = {}self.lock = threading.RLock()self.dispatcher_tasks: Dict[str, asyncio.Task] = {}async def add_connection(self, client_id: str, websocket: WebSocket):"""添加连接并启动消息分发器"""with self.lock:self.connections[client_id] = websocketself.message_queues[client_id] = Queue()# 启动消息分发任务task = asyncio.create_task(self._message_dispatcher(client_id))self.dispatcher_tasks[client_id] = taskdef remove_connection(self, client_id: str):"""移除连接"""with self.lock:websocket = self.connections.pop(client_id, None)queue = self.message_queues.pop(client_id, None)task = self.dispatcher_tasks.pop(client_id, None)# 取消任务if task:task.cancel()# 关闭 WebSocketif websocket:try:asyncio.create_task(websocket.close())except:passdef send_message(self, client_id: str, message: dict):"""从任何线程安全发送消息"""with self.lock:if client_id in self.message_queues:self.message_queues[client_id].put(message)def broadcast(self, message: dict, exclude_clients: set = None):"""广播消息到所有连接"""exclude_clients = exclude_clients or set()with self.lock:for client_id in self.connections:if client_id not in exclude_clients:self.send_message(client_id, message)async def _message_dispatcher(self, client_id: str):"""异步消息分发器"""while client_id in self.connections:try:# 使用异步方式等待消息message = await asyncio.get_event_loop().run_in_executor(None, self._get_message_safe, client_id)if message and client_id in self.connections:websocket = self.connections[client_id]await websocket.send_json(message)except Exception as e:print(f"Error in dispatcher for {client_id}: {e}")breakdef _get_message_safe(self, client_id: str):"""安全地从队列获取消息"""try:return self.message_queues[client_id].get(timeout=0.1)except Empty:return Nonethread_safe_manager = ThreadSafeWebSocketManager()
此连接管理器中,每个WebSocket连接成功后都启动了一个消息分发任务,专门发送该连接的消息。
5.2 连接管理器的使用
import time
import threading
from datetime import datetimedef start_background_notifications(manager: ThreadSafeWebSocketManager):"""启动后台通知任务"""def notification_generator():"""生成系统通知"""count = 0while True:try:notification = {"type": "system_notification","message": f"System update #{count}","timestamp": datetime.now().isoformat(),"priority": "info"}# 安全地广播通知manager.broadcast(notification)count += 1time.sleep(30) # 每30秒发送一次except Exception as e:print(f"Notification generator error: {e}")time.sleep(5) # 错误后等待5秒重试# 启动后台线程thread = threading.Thread(target=notification_generator, daemon=True)thread.start()@app.on_event("startup")
async def startup_event():start_background_notifications(thread_safe_manager)@app.websocket("/ws/thread-safe/{client_id}")
async def thread_safe_websocket(websocket: WebSocket, client_id: str):await thread_safe_manager.add_connection(client_id, websocket)try:# 发送欢迎消息thread_safe_manager.send_message(client_id, {"type": "welcome","message": "Connected to thread-safe WebSocket","timestamp": datetime.now().isoformat()})# 处理客户端消息while True:data = await websocket.receive_text()print(f"Received from {client_id}: {data}")except Exception as e:print(f"WebSocket error for {client_id}: {e}")finally:thread_safe_manager.remove_connection(client_id)
此范例中,启动一个定时广播的任务,每隔一段时间发送广播消息给每个WebSocket连接,最终也是调用send_message把消息放到发送队列里,由消息分发任务来发送消息。
6 双队列异步处理
当然我们也可以增加接收队列,由单独的接收任务来接收与处理消息,WebSocket的主线程仅仅是建立连接与定时发送心跳消息。
6.1 双队列异步连接管理器
import asyncio
import json
from typing import Dict
from fastapi import FastAPI, WebSocket, WebSocketDisconnectapp = FastAPI()class SimpleDualQueueManager:"""简单的双队列 WebSocket 管理器"""def __init__(self):# 连接存储self.connections: Dict[str, WebSocket] = {}# 双队列系统:接收队列和发送队列self.receive_queues: Dict[str, asyncio.Queue] = {}self.send_queues: Dict[str, asyncio.Queue] = {}# 任务存储self.receive_tasks: Dict[str, asyncio.Task] = {}self.process_tasks: Dict[str, asyncio.Task] = {}self.send_tasks: Dict[str, asyncio.Task] = {}async def add_connection(self, client_id: str, websocket: WebSocket):"""添加连接并启动三个核心任务"""self.connections[client_id] = websocketself.receive_queues[client_id] = asyncio.Queue()self.send_queues[client_id] = asyncio.Queue()# 启动三个异步任务self.receive_tasks[client_id] = asyncio.create_task(self._receive_messages(client_id, websocket))self.process_tasks[client_id] = asyncio.create_task(self._process_messages(client_id))self.send_tasks[client_id] = asyncio.create_task(self._send_messages(client_id, websocket))# 发送欢迎消息(通过发送队列)await self.send_queues[client_id].put({"type": "welcome","message": f"Client {client_id} connected"})async def remove_connection(self, client_id: str):"""移除连接并清理资源"""# 取消所有任务for task in [self.receive_tasks.get(client_id), self.process_tasks.get(client_id), self.send_tasks.get(client_id)]:if task:task.cancel()# 清理资源self.connections.pop(client_id, None)self.receive_queues.pop(client_id, None)self.send_queues.pop(client_id, None)self.receive_tasks.pop(client_id, None)self.process_tasks.pop(client_id, None)self.send_tasks.pop(client_id, None)async def _receive_messages(self, client_id: str, websocket: WebSocket):"""任务1: 接收消息并放入接收队列"""try:while True:# 从WebSocket接收消息data = await websocket.receive_text()message = json.loads(data)# 放入接收队列if client_id in self.receive_queues:await self.receive_queues[client_id].put(message)except WebSocketDisconnect:print(f"Client {client_id} disconnected")except Exception as e:print(f"Receive error for {client_id}: {e}")finally:await self.remove_connection(client_id)async def _process_messages(self, client_id: str):"""任务2: 从接收队列处理消息,结果放入发送队列"""try:while client_id in self.receive_queues:# 从接收队列获取消息message = await self.receive_queues[client_id].get()# 处理消息(这里简单回声)response = {"type": "echo","original": message,"timestamp": "now"}# 将响应放入发送队列if client_id in self.send_queues:await self.send_queues[client_id].put(response)except Exception as e:print(f"Process error for {client_id}: {e}")async def _send_messages(self, client_id: str, websocket: WebSocket):"""任务3: 从发送队列取出消息并发送"""try:while client_id in self.send_queues:# 从发送队列获取消息message = await self.send_queues[client_id].get()# 通过WebSocket发送await websocket.send_json(message)except Exception as e:print(f"Send error for {client_id}: {e}")def send_message(self, client_id: str, message: dict):"""从外部线程安全发送消息"""if client_id in self.send_queues:# 使用线程安全的方式将消息放入队列asyncio.run_coroutine_threadsafe(self.send_queues[client_id].put(message),asyncio.get_event_loop()`在这里插入代码片`)# 创建管理器实例
manager = SimpleDualQueueManager()
6.2 连接管理器的使用
@app.websocket("/ws/simple/{client_id}")
async def simple_websocket(websocket: WebSocket, client_id: str):await websocket.accept()# 注册连接到管理器await manager.add_connection(client_id, websocket)try:# 主循环只负责保持连接# 实际的消息处理已经在后台任务中运行while True:# 简单的心跳检查await asyncio.sleep(30)# 检查连接是否仍然有效if client_id not in manager.connections:breakexcept Exception as e:print(f"WebSocket error for {client_id}: {e}")finally:await manager.remove_connection(client_id)# 后台任务示例
import threading
import timedef background_task():"""模拟后台任务发送消息"""count = 0while True:try:count += 1message = {"type": "background","count": count,"timestamp": time.time()}# 向所有连接的客户端广播消息for client_id in list(manager.connections.keys()):manager.send_message(client_id, message)time.sleep(5) # 每5秒发送一次except Exception as e:print(f"Background task error: {e}")time.sleep(1)# 启动后台任务
@app.on_event("startup")
async def startup():thread = threading.Thread(target=background_task, daemon=True)thread.start()if __name__ == "__main__":import uvicornuvicorn.run(app, host="0.0.0.0", port=8000)
说明:
三个独立任务:
· 接收任务:从 WebSocket 接收消息 → 放入接收队列
· 处理任务:从接收队列取出消息 → 处理 → 放入发送队列
· 发送任务:从发送队列取出消息 → 通过 WebSocket 发送
队列优势:
· 解耦:接收、处理、发送相互独立,互不阻塞
· 缓冲:处理速度不一致时,队列起到缓冲作用
· 线程安全:外部线程可以通过队列安全地发送消息
异步处理:
· 所有操作都是异步的,不会阻塞事件循环
· 使用 asyncio.Queue 实现异步队列
· 每个客户端有自己独立的队列和任务
这个简单范例包含了双队列异步处理的核心思想,可以根据需要扩展更复杂的功能如错误处理、优先级队列、批处理等。