当前位置: 首页 > news >正文

Springboot中添加原生websocket支持

1、添加配置

@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /**
     * The Spring-managed handler bean. It is injected rather than created with {@code new}
     * so that its {@code @Resource} fields and {@code @PostConstruct} hook are processed.
     */
    private final WebSocketHandler webSocketHandler;

    /** The Spring-managed handshake interceptor that validates the caller's token. */
    private final WebSocketSecurityTokenInterceptor tokenInterceptor;

    public WebSocketConfig(WebSocketHandler webSocketHandler,
                           WebSocketSecurityTokenInterceptor tokenInterceptor) {
        this.webSocketHandler = webSocketHandler;
        this.tokenInterceptor = tokenInterceptor;
    }

    /**
     * Registers the WebSocket endpoint and attaches the token-checking handshake interceptor.
     *
     * @param registry the registry supplied by Spring's {@code @EnableWebSocket} support
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // Connections are accepted from any origin. Restrict the allowed origins in production.
        registry.addHandler(webSocketHandler, "/ws/[请求的地址]")
                .setAllowedOrigins("*")
                .addInterceptors(tokenInterceptor);
    }
}

2、添加Handler对请求进行处理

@Component
@Slf4j
public class WebSocketHandler extends TextWebSocketHandler {

    private static final CopyOnWriteArrayList<WebSocketSession> sessions = new CopyOnWriteArrayList<>();

    private static final ScheduledExecutorService scheduledThreadPool = Executors.newScheduledThreadPool(Runtime.getRuntime().availableProcessors());

    private static final Map<String,Long> lastPongTimes = new ConcurrentHashMap<>();

    private static final String PING = "ping";

    private static final String PONG = "pong";

    private static final String GET_DATA = "getData";
    /**
     * 检测心跳是否正常的周期时间
     */
    private static final Integer heartbeatInterval = 30_000;

    /**
     * 检测客户端连接心跳保持时间是否超时的时间
     */
    private static final Integer heartbeatTimeout = 60_000;

    @Resource
    private ReportDashboardService reportDashboardService;
    // 使用Guava 弱引用缓存数据
    private static final Cache<String, Object> CACHE = CacheBuilder.newBuilder().softValues().expireAfterWrite(3, TimeUnit.SECONDS).build();

    @PostConstruct
    public void init() {
        scheduledThreadPool.scheduleWithFixedDelay(() -> {
            try {
                checkHeartbeats();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        },heartbeatInterval,heartbeatInterval,TimeUnit.SECONDS);
    }

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        // 广播消息给所有已连接的客户端
        String payload = message.getPayload();
        if (PING.equals(payload)) {
            session.sendMessage(new TextMessage(PONG));
        } else if (GET_DATA.equals(payload)) {
            sendData(session);
        } else {
            sendDataByPayload(session,payload);
            //broadcast(payload);
        }
        recordPong(session.getId());
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // 当新连接建立时添加到列表
        sessions.add(session);
        //session.sendMessage(new TextMessage(PONG));
        recordPong(session.getId());
        sendData(session);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        // 当连接关闭时从列表中移除
        sessions.remove(session);
        log.info("Connection closed.sessionId={},status={}",session.getId(),status);
    }

    private void sendData(WebSocketSession sess) {
        try {
            if (sess.isOpen()) {
                String query = sess.getUri().getQuery();
                String[] split = query.split("&");
                ReqObj req = new ReqObj();
                String path = "";
                for (String s : split) {
                    String[] arr = s.split("=");
                    if ("path".equals(arr[0])) {
                        path = arr[1];
                    } else if ("deviceCode".equals(arr[0])) {
                        req.setDeviceCode(arr[1]);
                    } else if ("pointType".equals(arr[0])) {
                        req.setPointType(arr[1]);
                    } else if ("gap".equals(arr[0])) {
                        if (arr[1] != null) {
                            req.setGap(Integer.parseInt(arr[1]));
                        }
                    }
                }
                sess.sendMessage(new TextMessage(sendDataByPath(path,req)));
            }
        } catch (Exception e) {
            System.err.println("Failed to send message: " + e.getMessage());
        }
    }

    private boolean closeTimeoutSession(String sessionId) throws IOException {
        WebSocketSession s = null;
        for (WebSocketSession sess : sessions) {
            if (sess.isOpen() && sess.getId().equals(sessionId)) {
                sess.sendMessage(new TextMessage("当前连接1分钟内未发送心跳消息,即将关闭"));
                s = sess;
            }
        }
        log.info("关闭心跳超过的连接,sessionId={}",sessionId);
        return s != null && sessions.remove(s);
    }

    private void recordPong(String sessionId) {
        lastPongTimes.put(sessionId,System.currentTimeMillis());
    }

    private boolean isClientAlive(String sessionId) {
        Long lastPongTime = lastPongTimes.get(sessionId);
        if (lastPongTime == null){
            return false;
        }
        return System.currentTimeMillis() - lastPongTime <= heartbeatTimeout;
    }

    private void checkHeartbeats() throws IOException {
        log.info("开始检查连接的心跳是否超时......");
        Set<Map.Entry<String, Long>> entries = lastPongTimes.entrySet();
        for (Map.Entry<String, Long> entry : entries) {
            String sessionId = entry.getKey();
            log.info("sessionId = {}",sessionId);
            if (!isClientAlive(sessionId)) {
                closeTimeoutSession(sessionId);
            }
        }
    }

    private void sendDataByPayload(WebSocketSession sess,String payload){
        try {
            if (sess.isOpen()) {
                ChartDto dto = null;
                String cacheKey = null;
                try {
                     cacheKey = MD5Util.encrypt(payload);
                    dto = JSON.parseObject(payload, ChartDto.class);
                } catch (Exception e) {
                    log.error("将payload转为ChartDto对象失败");
                }
                if (dto == null) {
                    JSONObject jsonObject = JSON.parseObject(payload);
                    String path = jsonObject.getString("path");
                    ReqObj req = new ReqObj();
                    req.setGap(jsonObject.getIntValue("gap",0));
                    req.setDeviceCode(jsonObject.getString("deviceCode"));
                    req.setPointType(jsonObject.getString("pointType"));
                    String s = sendDataByPath(path, req);
                    sess.sendMessage(new TextMessage("{\"path\":\""+path+"\",\"data\":"+s+"}"));
                } else {
                    if (reportDashboardService == null) {
                        reportDashboardService = SpringUtil.getBean(ReportDashboardService.class);
                    }
                    synchronized (Thread.currentThread()) {
                        Object data = CACHE.getIfPresent(cacheKey);
                        if (data == null) {
                            data = reportDashboardService.getChartData(dto);
                            if (data != null) {
                                CACHE.put(cacheKey,data);
                            }
                        }
                        String s = JSON.toJSONString(R.success(data,"success",dto.getId()));
                        sess.sendMessage(new TextMessage(s));
                    }
                }
            }
        } catch (Exception e) {
            System.err.println("Failed to send message: " + e.getMessage());
        }
    }

    private String sendDataByPath(String path,ReqObj req) {
        return "{}";
    }

    private void broadcast(String message) {
        for (WebSocketSession sess : sessions) {
            try {
                if (sess.isOpen()) {
                    String data = "原始数据:"+message+",翻转后的数据:"+new StringBuilder(message).reverse();
                    sess.sendMessage(new TextMessage(data));
                }
            } catch (Exception e) {
                System.err.println("Failed to send message: " + e.getMessage());
            }
        }
    }
}

3、在握手拦截器中校验token

@Getter
@Slf4j
@Component
public class WebSocketSecurityTokenInterceptor implements HandshakeInterceptor {

    /**
     * Looked up from the Spring context, or defaulted. This class never uses it:
     * {@link #getToken} reads the header and cookie directly.
     */
    private TokenAcquireHandler tokenAcquireHandler;

    /** Turns a raw token into {@code UserDetails}. */
    private TokenAnalysisHandler tokenAnalysisHandler;

    {
        tokenAcquireHandler = SpringUtil.getOrDefault( TokenAcquireHandler.class, new DefaultTokenAcquireHandler() );
        tokenAnalysisHandler = SpringUtil.getOrDefault( TokenAnalysisHandler.class, new DefaultTokenAnalysisHandler() );
    }

    /**
     * Validates the caller's token before the WebSocket upgrade.
     * <p>
     * Excluded paths pass straight through. Otherwise the token is read from the
     * Authorization header, or failing that from an "Authorization" cookie, then analysed
     * and checked.
     * <p>
     * NOTE(review): if SecurityContextHandler is thread-bound, the context set here lives only
     * on the handshake thread and is never cleared. Consider copying what the session needs
     * into {@code attributes} instead. Confirm this.
     *
     * @throws TokenNotFoundException if no token is present
     * @throws TokenAnalysisException and the other checkUserDetails exceptions if the user is invalid
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        // Excluded paths are let through without a token check.
        if ( FilterContextHandler.getContext().isExclude() ) {
            // Keep a context that was set manually. Only fall back to EMPTY when none exists.
            if ( SecurityContextHandler.getContext() == null ) {
                SecurityContextHandler.setContext( SecurityContext.EMPTY );
            }
            return true;
        }
        String token = getToken(request);
        if ( !StringUtils.hasText( token ) ) {
            throw new TokenNotFoundException( "token not found" );
        }
        UserDetails userDetails = tokenAnalysisHandler.analysisToken( token );
        checkUserDetails( token, userDetails );
        SecurityContextHandler.setContext( new SecurityContext( token, userDetails ) );
        return true;
    }

    /**
     * Rejects users that are missing, disabled, expired or locked, and tokens whose
     * credentials have expired.
     * <p>
     * Each {@code isXxxNon...} predicate returns true when the account is valid. An exception
     * is therefore thrown only when the predicate is false, matching the {@code isEnabled} check.
     */
    private void checkUserDetails( String token, UserDetails userDetails ) {
        // The analysed UserDetails must not be null.
        if ( userDetails == null ) {
            throw new TokenAnalysisException( "token analysis userDetails cannot be empty" );
        }
        // The user must be enabled.
        if ( !userDetails.isEnabled() ) {
            throw new TokenAnalysisException();
        }
        // The account must not be expired.
        if ( !userDetails.isAccountNonExpired() ) {
            throw new UserDetailsExpiredException();
        }
        // The account must not be locked.
        if ( !userDetails.isAccountNonLocked() ) {
            throw new UserLockException();
        }
        // The token must not be expired.
        if ( !userDetails.isCredentialsNonExpired( token ) ) {
            throw new TokenExpiredException();
        }
    }

    /** No post-handshake work is needed. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
    }

    /**
     * Returns the first Authorization header value if it is non-blank. Otherwise returns the
     * value of the first "Authorization" cookie, or null if there is none.
     */
    private String getToken( ServerHttpRequest req ) {
        List< String > headerList = req.getHeaders().get( HttpHeaders.AUTHORIZATION );
        String token = CollectionUtils.isEmpty( headerList ) ? "" : headerList.get( 0 );
        if ( StringUtils.hasText( token ) ) {
            return token;
        }
        List< String > cookies = req.getHeaders().get( HttpHeaders.COOKIE );
        for (String cookieStr : Optional.ofNullable(cookies).orElse(Collections.emptyList())) {
            HttpCookie cookie = parseAuthCookie(cookieStr);
            if ( cookie != null ){
                return cookie.getValue();
            }
        }
        return null;
    }

    /** Finds the "Authorization" cookie in one Cookie header value ("a=1; b=2"), or returns null. */
    private HttpCookie parseAuthCookie(String cookieStr) {
        if (!StringUtils.hasText(cookieStr)){
            return null;
        }
        return Arrays.stream(cookieStr.split(";"))
                .map(this::parseCookie)
                .filter(Objects::nonNull)
                .filter(cookie -> HttpHeaders.AUTHORIZATION.equals(cookie.getName()))
                .findFirst()
                .orElse(null);
    }

    /** Parses a single "name=value" fragment. Returns null if it is malformed. */
    private HttpCookie parseCookie(String cookieStr) {
        try {
            List<HttpCookie> cookies = HttpCookie.parse(cookieStr);
            return CollectionUtils.isEmpty(cookies) ? null : cookies.get(0);
        } catch (Exception ignored) {
            // Malformed fragments (or reserved names rejected by HttpCookie) are skipped.
            return null;
        }
    }
}

相关文章:

  • 考研操作系统----操作系统的概念定义功能和目标(仅仅作为王道哔站课程讲义作用)
  • 蓝桥杯之图
  • web前端第三次作业
  • mysql用户名怎么看
  • H5自适应响应式代理记账与财政咨询服务类PbootCMS网站模板 – HTML5财务会计类网站源码下载
  • 【设计模式】02-理解常见设计模式-结构型模式
  • 一种微波场刺激器系统介绍
  • Molecular Communication(分子通信)与 Molecular Semantic Communication(分子语义通信)
  • 跟着李沐老师学习深度学习(十一)
  • 【LLM强化学习】LLM 强化学习中 Critic 模型训练详解
  • 基于逻辑概率的语义信道容量(Semantic Channel Capacity)和语义压缩理论(Semantic Compression Theory)
  • HTTP 请求方式`application/x-www-form-urlencoded` 与 `application/json` 怎么用?有什么区别?
  • 轻量级在线ETL数据集成工具架构设计与技术实现深度剖析
  • 网页五子棋——通用模块
  • leetcode:627. 变更性别(SQL解法)
  • WEB安全--SQL注入--INTO OUTFILE
  • 学习星开源在线考试教育系统
  • 在项目中操作 MySQL
  • UE WebUI插件依赖插件JsonLibrary 插件使用笔记
  • 「软件设计模式」适配器模式
  • 保健品 东莞网站建设/b2b平台推广
  • 依靠百度云做视频网站/站长统计app
  • 公司网站制作费用多少/seo公司优化
  • 泰安软件公司 泰安网站建设/线下推广团队
  • 平面设计和网页设计哪个工资高/衡阳seo排名
  • 建站最便宜的平台/深圳百度关键词排名