gRPC从0到1系列【16】
文章目录
- 双向流式RPC (Bidirectional Streaming RPC)
- ✅ 6.2 示例代码
- 6.2.1 服务器端代码
- 6.2.2 代码解析
- 6.2.3 服务端启动代码
双向流式RPC (Bidirectional Streaming RPC)
✅ 6.2 示例代码
6.2.1 服务器端代码
package cn.tcmeta.chat.grpc;import io.grpc.stub.StreamObserver;import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;/*** @author: laoren* @description: gRPC双向流式RPC服务实现* @version: 1.0.0* 双向流式RPC服务实现* 特点: 客户端和服务器都可以发送流式消息*/
public class BidirectionalChatServiceImpl extends ChatServiceGrpc.ChatServiceImplBase {private final ConcurrentHashMap<String, StreamObserver<ChatResponse>> activeConnections = new ConcurrentHashMap<>();private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(10);private final AtomicLong connectionCounter = new AtomicLong(0);/*** 双向流式RPC服务实现* 特点: 客户端和服务器都可以发送流式消息** @param responseObserver StreamObserver* @return StreamObserver*/@Overridepublic StreamObserver<ChatMessage> bidirectionalChat(StreamObserver<ChatResponse> responseObserver) {long connectionId = connectionCounter.incrementAndGet();String connectionKey = "conn_" + connectionId;System.out.println("=== 双向流式RPC连接建立, 连接ID: " + connectionId + " ===");// 保存连接以便后续发送消息activeConnections.put(connectionKey, responseObserver);// 1. 发送欢迎消息sendWelcomeMessage(responseObserver, connectionId);// 2. 心跳检测startHeartbeat(connectionKey, connectionId);return new StreamObserver<>() {private long messageCount = 0;private final long startTime = System.currentTimeMillis();private String currentUserId = "";@Overridepublic void onNext(ChatMessage message) {messageCount++;long currentTime = System.currentTimeMillis();long elapsed = currentTime - startTime;System.out.println("双向流连接 " + connectionId + " - 收到第 " + messageCount + " 条消息");System.out.println("用户: " + message.getUserId());System.out.println("消息: " + message.getMessage());System.out.println("类型: " + message.getType());System.out.println("时间: +" + elapsed + "ms");// 记录当前用户IDif (currentUserId.isEmpty()) {currentUserId = message.getUserId();}try {processAndRespond(message, responseObserver, connectionId, messageCount);} catch (Exception e) {System.err.println("处理消息时发生错误: " + e.getMessage());sendErrorMessage(responseObserver, "处理失败: " + e.getMessage());}}@Overridepublic void onError(Throwable t) {System.err.println("双向流连接 " + connectionId + " 错误: " + t.getMessage());cleanupConnection(connectionKey, connectionId);}@Overridepublic void onCompleted() {long totalTime = System.currentTimeMillis() - startTime;System.out.println("=== 双向流连接 " + connectionId + " 关闭 ===");System.out.println("总共处理 " + (messageCount) + " 条消息");System.out.println("连接持续时间: " + totalTime + "ms");String averageMessageInterval = messageCount > 0 ?String.format("%.2f", totalTime / (double) messageCount) : "N/A";System.out.println("平均消息间隔: " + averageMessageInterval + "ms");// 发送告别消息sendGoodbyeMessage(responseObserver, connectionId, messageCount, totalTime);// 清理资源cleanupConnection(connectionKey, connectionId);}};}/*** 处理消息** @param message 消息内容* @param responseObserver responseObserver* @param connectionId 链接ID* @param messageCount 消息数量*/private void processAndRespond(ChatMessage message,StreamObserver<ChatResponse> responseObserver,long connectionId,long messageCount) {switch (message.getType()) {case TEXT:handlerTextMessage(message, responseObserver, connectionId, messageCount);break;case IMAGE:handleImageMessage(message, responseObserver, connectionId, messageCount);break;case FILE:handleFileMessage(message, responseObserver, connectionId, messageCount);break;case SYSTEM:handleSystemMessage(message, responseObserver, connectionId, messageCount);break;default:handleUnknownMessage(message, responseObserver, connectionId, messageCount);break;}}/*** 处理文本消息** @param message 消息内容* @param responseObserver responseObserver* @param connectionId 连接ID* @param messageCount 消息数量*/private void handlerTextMessage(ChatMessage message,StreamObserver<ChatResponse> responseObserver,long connectionId,long messageCount) {String responseText;if (message.getMessage().toLowerCase().contains("您好")) {responseText = "您好! 很高兴为您服务";} else if (message.getMessage().toLowerCase().contains("时间")) {responseText = "当前时间是: " + System.currentTimeMillis();} else if (message.getMessage().toLowerCase().contains("帮助")) {responseText = """以下是可用功能:1. 时间: 获取当前时间2. 帮助: 显示此帮助信息3. 回声: 输入任何内容,服务端将返回相同的内容4. 时间: 获取当前时间5. 帮助: 显示此帮助信息6. 回声: 输入任何内容,服务端将返回相同的内容""";} else {responseText = "已收到您的消息: \"" + message.getMessage() + "\"";}ChatResponse response = ChatResponse.newBuilder().setMessageId("TEXT_" + connectionId + "_" + messageCount).setStatus("PROCESSED").setResponseMessage(responseText).setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.SUCCESS).build();responseObserver.onNext(response);}/*** 处理图片消息** @param message 图片消息* @param responseObserver responseObserver* @param connectionId 连接ID* @param messageCount 消息数量*/private void handleImageMessage(ChatMessage message,StreamObserver<ChatResponse> responseObserver,long connectionId, long messageCount) {// 模拟图片处理流程String[] processingSteps = {"开始分析图片内容","检测图片中的对象","生成图片描述","分析完成"};for (int i = 0; i < processingSteps.length; i++) {ChatResponse response = ChatResponse.newBuilder().setMessageId("IMG_" + connectionId + "_" + messageCount + "_" + (i + 1)).setStatus("PROCESSING").setResponseMessage(processingSteps[i] + " (" + ((i + 1) * 25) + "%)").setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.INFO).build();responseObserver.onNext(response);// 模拟处理延迟try {TimeUnit.MILLISECONDS.sleep(500);} catch (InterruptedException e) {e.printStackTrace();}// 上述流程处理完成之后,最终响应// 最终响应ChatResponse finalResponse = ChatResponse.newBuilder().setMessageId("IMG_" + connectionId + "_" + messageCount + "_COMPLETE").setStatus("COMPLETED").setResponseMessage("图片分析完成,检测到3个主要对象").setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.SUCCESS).build();responseObserver.onNext(finalResponse);}}/*** 处理文件消息** @param message 文件消息* @param responseObserver responseObserver* @param connectionId 连接ID* @param messageCount 消息数量*/private void handleFileMessage(ChatMessage message,StreamObserver<ChatResponse> responseObserver,long connectionId,long messageCount) {// 模拟文件处理for (int progress = 10; progress <= 100; progress += 10) {ChatResponse progressResponse = ChatResponse.newBuilder().setMessageId("FILE_" + connectionId + "_" + messageCount + "_" + progress).setStatus("UPLOADING").setResponseMessage("文件上传进度: " + progress + "%").setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.INFO).build();responseObserver.onNext(progressResponse);try {Thread.sleep(200);} catch (InterruptedException e) {Thread.currentThread().interrupt();break;}}}/*** 处理系统消息** @param message 系统消息* @param responseObserver responseObserver* @param connectionId 链接ID* @param messageCount 消息数量*/private void handleSystemMessage(ChatMessage message,StreamObserver<ChatResponse> responseObserver,long connectionId, long messageCount) {ChatResponse response = ChatResponse.newBuilder().setMessageId("SYS_" + connectionId + "_" + messageCount).setStatus("ACKNOWLEDGED").setResponseMessage("系统消息已确认: " + message.getMessage()).setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.SUCCESS).build();responseObserver.onNext(response);}/*** 发送欢迎消息** @param responseObserver StreamObserver<ChatResponse> responseObserver* @param connectionId connectionId*/private void sendWelcomeMessage(StreamObserver<ChatResponse> responseObserver, long connectionId) {ChatResponse welcomeResponse = ChatResponse.newBuilder().setMessageId("WELCOME_" + connectionId) // 设置消息ID.setStatus("CONNECTED").setResponseMessage("欢迎来到双向流式RPC聊天室! 连接ID: " + connectionId).setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.INFO).build();responseObserver.onNext(welcomeResponse);System.out.println("发送欢迎消息到连接` " + connectionId);}/*** 处理未知消息** @param message 未知消息* @param responseObserver responseObserver* @param connectionId 链接ID* @param messageCount 消息数量*/private void handleUnknownMessage(ChatMessage message,StreamObserver<ChatResponse> responseObserver,long connectionId, long messageCount) {ChatResponse response = ChatResponse.newBuilder().setMessageId("UNK_" + connectionId + "_" + messageCount).setStatus("WARNING").setResponseMessage("未知消息类型,已忽略").setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.WARNING).build();responseObserver.onNext(response);}/*** 发送随机通知** @param responseObserver StreamObserver* @param connectionId id*/private void sendRandomNotification(StreamObserver<ChatResponse> responseObserver,long connectionId) {String[] notifications = {"系统维护通知: 本周六凌晨2-4点进行系统维护","新功能提示: 已上线图片智能识别功能","活动通知: 周年庆活动即将开始,敬请期待","安全提醒: 请定期更新密码以保证账户安全","性能优化: 系统响应速度已提升20%"};String notification = notifications[(int) (Math.random() * notifications.length)];ChatResponse response = ChatResponse.newBuilder().setMessageId("NOTIFY_" + connectionId + "_" + System.currentTimeMillis()).setStatus("NOTIFICATION").setResponseMessage(notification).setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.INFO).build();responseObserver.onNext(response);System.out.println("向连接 " + connectionId + " 发送通知: " + notification);}/*** 心跳检测机制** @param connectionKey 连接key* @param connectionId id*/private void startHeartbeat(String connectionKey, long connectionId) {scheduler.scheduleAtFixedRate(() -> {StreamObserver<ChatResponse> observer = activeConnections.get(connectionKey);ChatResponse heartbeatResponse = ChatResponse.newBuilder().setMessageId("HEARTBEAT_" + connectionId + "_" + System.currentTimeMillis()).setStatus("HEARTBEAT").setResponseMessage("连接状态正常").setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.INFO).build();observer.onNext(heartbeatResponse);System.out.println("向连接 " + connectionId + " 发送心跳");}, 10, 10, TimeUnit.SECONDS);}/*** 发送错误消息** @param responseObserver StreamObserver* @param errorMessage 异常消息内容*/public void sendErrorMessage(StreamObserver<ChatResponse> responseObserver, String errorMessage) {ChatResponse errorResponse = ChatResponse.newBuilder().setMessageId("ERROR_" + System.currentTimeMillis()).setStatus("ERROR").setResponseMessage(errorMessage).setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.ERROR).build();responseObserver.onNext(errorResponse);}/*** 聊天结束, 发送统计信息** @param responseObserver responseObserver* @param connectionId 连接ID* @param messageCount 消息数量* @param duration 持续时间*/public void sendGoodbyeMessage(StreamObserver<ChatResponse> responseObserver,long connectionId,long messageCount,long duration) {System.out.println("=== 聊天结束, 统计信息 ===");String averageMessageInterval = messageCount > 0 ?String.format("%.2f", duration / (double) messageCount) : "N/A";ChatResponse goodbyeResponse = ChatResponse.newBuilder().setMessageId("GOODBYE_" + connectionId).setStatus("DISCONNECTED").setResponseMessage(String.format("感谢使用! 统计: %d条消息, %d秒时长, 平均%s秒/消息",messageCount, duration / 1000, averageMessageInterval)).setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.INFO).build();responseObserver.onNext(goodbyeResponse);responseObserver.onCompleted();}/*** 清理连接资源** @param connectionKey 删除连接的key* @param connectionId 删除连接的ID*/public void cleanupConnection(String connectionKey, long connectionId) {activeConnections.remove(connectionKey);System.out.println("=== 连接已关闭, 删除连接 " + connectionId + " ===" + connectionKey);}/*** 向特定的连接发送消息** @param connectionKey 连接key* @param message 消息内容*/public void sendMessageToConnection(String connectionKey, String message) {StreamObserver<ChatResponse> observer = activeConnections.get(connectionKey);if (observer != null) {ChatResponse response = ChatResponse.newBuilder().setMessageId("BROADCAST_" + System.currentTimeMillis()).setStatus("BROADCAST").setResponseMessage(message).setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.INFO).build();observer.onNext(response);System.out.println("向连接 " + connectionKey + " 发送广播消息: " + message);} else {System.out.println("连接 " + connectionKey + " 不存在");}}/*** 广播消息** @param message 消息内容*/public void broadcastMessage(String message) {System.out.println("开始广播消息到 " + activeConnections.size() + " 个连接: " + message);activeConnections.forEach((connectionKey, observer) -> {ChatResponse response = ChatResponse.newBuilder().setMessageId("BROADCAST_" + System.currentTimeMillis()).setStatus("BROADCAST").setResponseMessage(message).setTimestamp(System.currentTimeMillis()).setCode(ChatResponse.ResponseCode.INFO).build();observer.onNext(response);});System.out.println("广播消息结束");}/*** 获取连接统计信息*/public void printConnectionStatistics() {System.out.println("\n=== 双向流连接统计 ===");System.out.println("活跃连接数: " + activeConnections.size());activeConnections.keySet().forEach(key ->System.out.println("连接: " + key));}/*** 关闭资源*/public void shutdown() {scheduler.shutdown();try {if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) {scheduler.shutdownNow();}} catch (Exception e) {scheduler.shutdown();Thread.currentThread().interrupt();}}
}
6.2.2 代码解析
✅ 核心功能说明
- 基本架构与职责
- gRPC服务实现: 实现了 ChatServiceGrpc.ChatServiceImplBase,提供双向流式RPC通信服务
- 连接管理: 使用 ConcurrentHashMap 管理所有活跃的客户端连接,支持并发访问
- 连接标识: 通过原子计数器 connectionCounter 为每个连接分配唯一ID
- 核心功能模块
- 双向流式通信 (bidirectionalChat)
- 连接建立: 为每个客户端创建独立的流观察器 StreamObserver<ChatMessage>
- 欢迎消息: 连接建立后立即发送欢迎信息
- 心跳机制: 启动定时任务每10秒发送心跳包维持连接
- 消息处理: 接收并分类处理四种不同类型的消息
- 生命周期管理: 完整跟踪连接的建立、使用和关闭过程
-
消息处理系统
- 文本消息处理 (handlerTextMessage):
- 支持关键词响应(“您好”、“时间”、“帮助”)
- 提供回声功能和帮助文档
- 图片消息处理 (handleImageMessage):
- 模拟图片分析过程,分步骤返回处理进度
- 最终返回分析结果
- 文件消息处理 (handleFileMessage):
- 模拟文件上传进度,以10%递增显示上传状态
- 系统消息处理 (handleSystemMessage):
- 处理系统级别消息并返回确认
- 未知消息处理 (handleUnknownMessage):
- 对无法识别的消息类型进行统一处理
- 文本消息处理 (handlerTextMessage):
-
连接管理功能
- 资源清理 (cleanupConnection): 移除断开的连接,释放资源
- 统计信息 (sendGoodbyeMessage): 连接关闭时提供详细的会话统计数据
- 错误处理 (sendErrorMessage): 统一的错误消息发送机制
-
通知与广播系统
- 随机通知 (sendRandomNotification): 向客户端发送预定义的系统通知
- 定向消息 (sendMessageToConnection): 向指定连接发送消息
- 全局广播 (broadcastMessage): 向所有活跃连接广播消息
- 统计查询 (printConnectionStatistics): 显示当前连接状态统计
6.2.3 服务端启动代码
package cn.tcmeta.chat.grpc;import io.grpc.Server;
import io.grpc.ServerBuilder;import java.util.concurrent.TimeUnit;/*** @author: laoren* @description: 服务器端* @version: 1.0.0*/
public class SimpleChatServer {public static void main(String[] args) {ServerBuilder<?> builder = ServerBuilder.forPort(8080).addService(new BidirectionalChatServiceImpl());Server server = builder.build();try {server.start();System.out.println("✅✅✅✅ 服务启动成功, 监听端口号 8080~~~~");// 添加关闭钩子Runtime.getRuntime().addShutdownHook(new Thread(() -> {System.err.println("正在关闭gRPC服务器...");try {if (server != null) {server.shutdown().awaitTermination(30, TimeUnit.SECONDS);}} catch (InterruptedException e) {e.printStackTrace();}System.err.println("服务器已关闭");}));// 3. 等待服务server.awaitTermination();} catch (Exception e) {e.printStackTrace();}}
}