1、添加 WebSocket 配置
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /**
     * Spring-managed handler bean. It is injected rather than created with {@code new} so that its
     * {@code @Resource} fields are populated and only one handler instance exists.
     */
    private final WebSocketHandler webSocketHandler;

    /** Spring-managed handshake interceptor that validates the caller's token. */
    private final WebSocketSecurityTokenInterceptor webSocketSecurityTokenInterceptor;

    /**
     * Creates the configuration with the container-managed handler and interceptor.
     *
     * @param webSocketHandler                  handler that processes messages on the endpoint
     * @param webSocketSecurityTokenInterceptor interceptor run before the handshake completes
     */
    public WebSocketConfig(WebSocketHandler webSocketHandler,
                           WebSocketSecurityTokenInterceptor webSocketSecurityTokenInterceptor) {
        this.webSocketHandler = webSocketHandler;
        this.webSocketSecurityTokenInterceptor = webSocketSecurityTokenInterceptor;
    }

    /**
     * Registers the WebSocket handler on its endpoint together with the token-checking interceptor.
     *
     * @param registry Spring's handler registry
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // "[请求的地址]" is a placeholder for the real endpoint path.
        // setAllowedOrigins("*") accepts connections from any origin; restrict it in production.
        registry.addHandler(webSocketHandler, "/ws/[请求的地址]")
                .setAllowedOrigins("*")
                .addInterceptors(webSocketSecurityTokenInterceptor);
    }
}
2、添加 Handler，对 WebSocket 消息进行处理
@Component
@Slf4j
public class WebSocketHandler extends TextWebSocketHandler {
private static final CopyOnWriteArrayList<WebSocketSession> sessions = new CopyOnWriteArrayList<>();
private static final ScheduledExecutorService scheduledThreadPool = Executors.newScheduledThreadPool(Runtime.getRuntime().availableProcessors());
private static final Map<String,Long> lastPongTimes = new ConcurrentHashMap<>();
private static final String PING = "ping";
private static final String PONG = "pong";
private static final String GET_DATA = "getData";
/**
* 检测心跳是否正常的周期时间
*/
private static final Integer heartbeatInterval = 30_000;
/**
* 检测客户端连接心跳保持时间是否超时的时间
*/
private static final Integer heartbeatTimeout = 60_000;
@Resource
private ReportDashboardService reportDashboardService;
// 使用Guava 弱引用缓存数据
private static final Cache<String, Object> CACHE = CacheBuilder.newBuilder().softValues().expireAfterWrite(3, TimeUnit.SECONDS).build();
@PostConstruct
public void init() {
scheduledThreadPool.scheduleWithFixedDelay(() -> {
try {
checkHeartbeats();
} catch (IOException e) {
throw new RuntimeException(e);
}
},heartbeatInterval,heartbeatInterval,TimeUnit.SECONDS);
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
// 广播消息给所有已连接的客户端
String payload = message.getPayload();
if (PING.equals(payload)) {
session.sendMessage(new TextMessage(PONG));
} else if (GET_DATA.equals(payload)) {
sendData(session);
} else {
sendDataByPayload(session,payload);
//broadcast(payload);
}
recordPong(session.getId());
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
// 当新连接建立时添加到列表
sessions.add(session);
//session.sendMessage(new TextMessage(PONG));
recordPong(session.getId());
sendData(session);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
// 当连接关闭时从列表中移除
sessions.remove(session);
log.info("Connection closed.sessionId={},status={}",session.getId(),status);
}
private void sendData(WebSocketSession sess) {
try {
if (sess.isOpen()) {
String query = sess.getUri().getQuery();
String[] split = query.split("&");
ReqObj req = new ReqObj();
String path = "";
for (String s : split) {
String[] arr = s.split("=");
if ("path".equals(arr[0])) {
path = arr[1];
} else if ("deviceCode".equals(arr[0])) {
req.setDeviceCode(arr[1]);
} else if ("pointType".equals(arr[0])) {
req.setPointType(arr[1]);
} else if ("gap".equals(arr[0])) {
if (arr[1] != null) {
req.setGap(Integer.parseInt(arr[1]));
}
}
}
sess.sendMessage(new TextMessage(sendDataByPath(path,req)));
}
} catch (Exception e) {
System.err.println("Failed to send message: " + e.getMessage());
}
}
private boolean closeTimeoutSession(String sessionId) throws IOException {
WebSocketSession s = null;
for (WebSocketSession sess : sessions) {
if (sess.isOpen() && sess.getId().equals(sessionId)) {
sess.sendMessage(new TextMessage("当前连接1分钟内未发送心跳消息,即将关闭"));
s = sess;
}
}
log.info("关闭心跳超过的连接,sessionId={}",sessionId);
return s != null && sessions.remove(s);
}
private void recordPong(String sessionId) {
lastPongTimes.put(sessionId,System.currentTimeMillis());
}
private boolean isClientAlive(String sessionId) {
Long lastPongTime = lastPongTimes.get(sessionId);
if (lastPongTime == null){
return false;
}
return System.currentTimeMillis() - lastPongTime <= heartbeatTimeout;
}
private void checkHeartbeats() throws IOException {
log.info("开始检查连接的心跳是否超时......");
Set<Map.Entry<String, Long>> entries = lastPongTimes.entrySet();
for (Map.Entry<String, Long> entry : entries) {
String sessionId = entry.getKey();
log.info("sessionId = {}",sessionId);
if (!isClientAlive(sessionId)) {
closeTimeoutSession(sessionId);
}
}
}
private void sendDataByPayload(WebSocketSession sess,String payload){
try {
if (sess.isOpen()) {
ChartDto dto = null;
String cacheKey = null;
try {
cacheKey = MD5Util.encrypt(payload);
dto = JSON.parseObject(payload, ChartDto.class);
} catch (Exception e) {
log.error("将payload转为ChartDto对象失败");
}
if (dto == null) {
JSONObject jsonObject = JSON.parseObject(payload);
String path = jsonObject.getString("path");
ReqObj req = new ReqObj();
req.setGap(jsonObject.getIntValue("gap",0));
req.setDeviceCode(jsonObject.getString("deviceCode"));
req.setPointType(jsonObject.getString("pointType"));
String s = sendDataByPath(path, req);
sess.sendMessage(new TextMessage("{\"path\":\""+path+"\",\"data\":"+s+"}"));
} else {
if (reportDashboardService == null) {
reportDashboardService = SpringUtil.getBean(ReportDashboardService.class);
}
synchronized (Thread.currentThread()) {
Object data = CACHE.getIfPresent(cacheKey);
if (data == null) {
data = reportDashboardService.getChartData(dto);
if (data != null) {
CACHE.put(cacheKey,data);
}
}
String s = JSON.toJSONString(R.success(data,"success",dto.getId()));
sess.sendMessage(new TextMessage(s));
}
}
}
} catch (Exception e) {
System.err.println("Failed to send message: " + e.getMessage());
}
}
private String sendDataByPath(String path,ReqObj req) {
return "{}";
}
private void broadcast(String message) {
for (WebSocketSession sess : sessions) {
try {
if (sess.isOpen()) {
String data = "原始数据:"+message+",翻转后的数据:"+new StringBuilder(message).reverse();
sess.sendMessage(new TextMessage(data));
}
} catch (Exception e) {
System.err.println("Failed to send message: " + e.getMessage());
}
}
}
}
3、添加拦截器，在握手时校验 token
@Getter
@Slf4j
@Component
public class WebSocketSecurityTokenInterceptor implements HandshakeInterceptor {
    // Strategy objects looked up through SpringUtil, falling back to the built-in defaults.
    // NOTE(review): tokenAcquireHandler is not used inside this class (getToken() reads the
    // Authorization header/cookie itself); it is only exposed via @Getter.
    private final TokenAcquireHandler tokenAcquireHandler =
            SpringUtil.getOrDefault( TokenAcquireHandler.class, new DefaultTokenAcquireHandler() );
    private final TokenAnalysisHandler tokenAnalysisHandler =
            SpringUtil.getOrDefault( TokenAnalysisHandler.class, new DefaultTokenAnalysisHandler() );

    /**
     * Validates the caller's token before the WebSocket handshake is accepted.
     * Excluded paths pass straight through.
     *
     * @return true to continue the handshake
     * @throws TokenNotFoundException if no token is present in the header or cookie
     * @throws Exception from token analysis or user-detail validation
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        // Excluded (public) paths are allowed without a token.
        if ( FilterContextHandler.getContext().isExclude() ) {
            // Keep a context that was set manually; only fill in EMPTY when none exists.
            if ( SecurityContextHandler.getContext() == null ) {
                SecurityContextHandler.setContext( SecurityContext.EMPTY );
            }
            return true;
        }
        String token = getToken(request);
        if ( !StringUtils.hasText( token ) ) {
            throw new TokenNotFoundException( "token not found" );
        }
        UserDetails userDetails = tokenAnalysisHandler.analysisToken( token );
        checkUserDetails( token, userDetails );
        // NOTE(review): if SecurityContextHandler is ThreadLocal-based (TODO confirm), this context
        // exists only on the handshake thread. It is not visible to later message-handling threads
        // and is not cleared here. Consider also storing it in `attributes`.
        SecurityContextHandler.setContext( new SecurityContext( token, userDetails ) );
        return true;
    }

    /**
     * Rejects user details that are missing, disabled, expired, locked, or whose credentials have
     * expired. The "Non*" predicates are read per their names: true means the account is still valid.
     * They were previously tested with inverted sense, unlike the isEnabled() check.
     */
    private void checkUserDetails( String token, UserDetails userDetails ) {
        if ( userDetails == null ) {
            throw new TokenAnalysisException( "token analysis userDetails cannot be empty" );
        }
        if ( !userDetails.isEnabled() ) {
            throw new TokenAnalysisException();
        }
        if ( !userDetails.isAccountNonExpired() ) {
            throw new UserDetailsExpiredException();
        }
        if ( !userDetails.isAccountNonLocked() ) {
            throw new UserLockException();
        }
        if ( !userDetails.isCredentialsNonExpired( token ) ) {
            throw new TokenExpiredException();
        }
    }

    /** No post-handshake work is needed. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
    }

    /**
     * Reads the token from the first Authorization header, or else from a cookie named
     * "Authorization".
     *
     * @return the token, or null if none is found
     */
    private String getToken( ServerHttpRequest req ) {
        List< String > headerList = req.getHeaders().get( HttpHeaders.AUTHORIZATION );
        String token = CollectionUtils.isEmpty( headerList ) ? "" : headerList.get( 0 );
        if ( StrUtil.isNotBlank( token ) ) {
            return token;
        }
        List< String > cookies = req.getHeaders().get( HttpHeaders.COOKIE );
        for (String cookieStr : Optional.ofNullable(cookies).orElse(Collections.emptyList())) {
            HttpCookie cookie = parseAuthCookie(cookieStr);
            if ( cookie != null ){
                return cookie.getValue();
            }
        }
        return null;
    }

    /** Returns the "Authorization" cookie from one Cookie header value, or null if absent. */
    private HttpCookie parseAuthCookie(String cookieStr) {
        if (!StringUtils.hasText(cookieStr)){
            return null;
        }
        return Arrays.stream(cookieStr.split(";"))
                .map(this::parseCookie)
                .filter(Objects::nonNull)
                .filter(cookie -> HttpHeaders.AUTHORIZATION.equals(cookie.getName()))
                .findFirst()
                .orElse(null);
    }

    /** Parses a single "name=value" pair; returns null when it is not a valid cookie. */
    private HttpCookie parseCookie(String cookieStr) {
        try {
            List<HttpCookie> cookies = HttpCookie.parse(cookieStr);
            return CollectionUtils.isEmpty(cookies) ? null : cookies.get(0);
        }catch (Exception e){
            return null;
        }
    }
}