[netty5: WebSocketFrameEncoder WebSocketFrameDecoder]-源码解析
WebSocketFrameMaskGenerator
WebSocketFrameMaskGenerator 是客户端用来生成 WebSocket 帧掩码的接口:nextMask() 返回一个 4 字节(32 位)整数作为掩码键,编码器用它对帧的 payload 逐字节做异或(XOR)掩码处理。注意掩码并不是加密,它只是 RFC 6455 要求客户端发送的帧必须做的一层混淆。
/**
 * Supplies the masking key used when encoding a WebSocket frame.
 *
 * <p>The key is handed out as a single {@code int}, i.e. the four mask bytes packed together.
 */
public interface WebSocketFrameMaskGenerator {

    /**
     * Produces the mask for the next frame.
     *
     * @return the 32-bit masking key
     */
    int nextMask();
}
RandomWebSocketFrameMaskGenerator
RandomWebSocketFrameMaskGenerator 实现了掩码生成接口,使用 ThreadLocalRandom 生成随机的 4 字节整数作为 WebSocket 帧的掩码。
public final class RandomWebSocketFrameMaskGenerator implements WebSocketFrameMaskGenerator {public static final RandomWebSocketFrameMaskGenerator INSTANCE = new RandomWebSocketFrameMaskGenerator();private RandomWebSocketFrameMaskGenerator() {}@Overridepublic int nextMask() {return ThreadLocalRandom.current().nextInt();}
}
WebSocketFrameEncoder
WebSocketFrameEncoder
负责将 WebSocketFrame 编码为符合 WebSocket 协议格式的二进制数据帧,处理帧头构造、负载长度扩展、掩码生成与数据异或。
/**
 * Marker interface for {@link ChannelHandler}s that encode WebSocket frames.
 *
 * <p>It declares no methods of its own.
 */
public interface WebSocketFrameEncoder extends ChannelHandler {
}
WebSocket13FrameEncoder
WebSocket13FrameEncoder
将 WebSocket 帧编码成符合 RFC 6455 的二进制格式,支持负载长度扩展(7 位 / 16 位 / 64 位)和可选的掩码处理;当负载较大且无需掩码时,帧头与负载会作为两个独立的 Buffer 输出(gathering write),避免额外拷贝。配置了掩码生成器时(客户端场景),负载会按规范与 4 字节掩码逐字节异或。
public class WebSocket13FrameEncoder extends MessageToMessageEncoder<WebSocketFrame> implements WebSocketFrameEncoder {private static final Logger logger = LoggerFactory.getLogger(WebSocket13FrameEncoder.class);private static final byte OPCODE_CONT = 0x0;private static final byte OPCODE_TEXT = 0x1;private static final byte OPCODE_BINARY = 0x2;private static final byte OPCODE_CLOSE = 0x8;private static final byte OPCODE_PING = 0x9;private static final byte OPCODE_PONG = 0xA;private static final int GATHERING_WRITE_THRESHOLD = 1024;private final WebSocketFrameMaskGenerator maskGenerator;public WebSocket13FrameEncoder(boolean maskPayload) {this(maskPayload ? RandomWebSocketFrameMaskGenerator.INSTANCE : null);}public WebSocket13FrameEncoder(WebSocketFrameMaskGenerator maskGenerator) {this.maskGenerator = maskGenerator;}// 0 1 2 3 // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 // +-+-+-+-+-------+-+-------------+-------------------------------+// |F|R|R|R| opcode|M| Payload len | Extended payload length |// |I|S|S|S| (4) |A| (7) | (16/64 bits if needed) |// +-+-+-+-+-------+-+-------------+-------------------------------+// | Masking key (32 bits, only if MASK set to 1) |// +---------------------------------------------------------------+// | Masked/unmasked payload data |// +---------------------------------------------------------------+ @Overrideprotected void encode(ChannelHandlerContext ctx, WebSocketFrame msg, List<Object> out) throws Exception {final Buffer data = msg.binaryData();byte opcode = getOpCode(msg);int length = data.readableBytes();if (logger.isTraceEnabled()) {logger.trace("Encoding WebSocket Frame opCode={} length={}", opcode, length);}// 构造第一个字节(b0): FIN RSV1 RSV2 RSV3 opcodeint b0 = 0;if (msg.isFinalFragment()) {b0 |= 1 << 7; // FIN 位}b0 |= (msg.rsv() & 0x07) << 4; // RSV1、RSV2、RSV3b0 |= opcode & 0x7F; // 低 4 位 opcode// RFC 要求 PING 帧 payload 最大长度为 125 字节if (opcode == OPCODE_PING && length > 125) {throw new 
TooLongFrameException("invalid payload for PING (payload length must be <= 125, was " + length);}// 初始化输出 Buffer(header + payload + 掩码)Buffer buf = null;try {int maskLength = maskGenerator != null ? 4 : 0;if (length <= 125) {// 1 b0 + 1 (0x80-10000000 + length-0xxxxxxxx) + 4 mask + payload dataint size = 2 + maskLength + length;buf = ctx.bufferAllocator().allocate(size);buf.writeByte((byte) b0);byte b = (byte) (maskGenerator != null ? 0x80 | length : length);buf.writeByte(b);} else if (length <= 0xFFFF) {// 1 b0 + 1 0xFE-1111 1110 + 2 16bit Extended payload lengthint size = 4 + maskLength;if (maskGenerator != null || length <= GATHERING_WRITE_THRESHOLD) {size += length;}buf = ctx.bufferAllocator().allocate(size);buf.writeByte((byte) b0);buf.writeByte((byte) (maskGenerator != null ? 0xFE : 126));buf.writeByte((byte) (length >>> 8 & 0xFF));buf.writeByte((byte) (length & 0xFF));} else {// 1 b0 + 1 0xFE-1111 1110 + 8 64bit Extended payload lengthint size = 10 + maskLength;if (maskGenerator != null || length <= GATHERING_WRITE_THRESHOLD) {size += length;}buf = ctx.bufferAllocator().allocate(size);buf.writeByte((byte) b0);buf.writeByte((byte) (maskGenerator != null ? 
0xFF : 127));buf.writeLong(length);}// Write payload// 掩码是 4 字节,RFC 规定客户端 必须 对每个字节 payload[i] ^= mask[i % 4]。if (maskGenerator != null) {int mask = maskGenerator.nextMask();buf.writeInt(mask);if (mask != 0) {int i = data.readerOffset();int end = data.writerOffset();int maskOffset = 0;for (; i < end; i++) {byte byteData = data.getByte(i);buf.writeByte((byte) (byteData ^ WebSocketUtil.byteAtIndex(mask, maskOffset++ & 3)));}out.add(buf);} else {addBuffers(buf, data, out);}} else {addBuffers(buf, data, out);}} catch (Throwable t) {if (buf != null) {buf.close();}throw t;}}private static byte getOpCode(WebSocketFrame msg) {if (msg instanceof TextWebSocketFrame) {return OPCODE_TEXT;}if (msg instanceof BinaryWebSocketFrame) {return OPCODE_BINARY;}if (msg instanceof PingWebSocketFrame) {return OPCODE_PING;}if (msg instanceof PongWebSocketFrame) {return OPCODE_PONG;}if (msg instanceof CloseWebSocketFrame) {return OPCODE_CLOSE;}if (msg instanceof ContinuationWebSocketFrame) {return OPCODE_CONT;}throw new UnsupportedOperationException("Cannot encode frame of type: " + msg.getClass().getName());}private static void addBuffers(Buffer buf, Buffer data, List<Object> out) {int readableBytes = data.readableBytes();if (buf.writableBytes() >= readableBytes) {// merge buffers as this is cheaper then a gathering write if the payload is small enoughbuf.writeBytes(data);out.add(buf);} else {out.add(buf);if (readableBytes > 0) {out.add(data.split());}}}
}
WebSocketFrameDecoder
WebSocketFrameDecoder
负责将符合 WebSocket 协议格式的二进制数据帧解码成 WebSocketFrame,处理帧头解析、负载长度读取、掩码还原以及分片状态校验。
/**
 * Marker interface for {@link ChannelHandler}s that decode WebSocket frames.
 *
 * <p>It declares no methods of its own.
 */
public interface WebSocketFrameDecoder extends ChannelHandler {
}
WebSocket13FrameDecoder
WebSocket13FrameDecoder 负责将接收到的二进制数据解析成符合 RFC 6455 规范的 WebSocket 帧对象,处理掩码还原、负载长度扩展、分片状态校验及控制帧校验。注意:解码器只校验分片顺序是否合法(通过 fragmentedFramesCount),每个分片仍作为独立的帧向后传递,并不在这里合并。
public class WebSocket13FrameDecoder extends ByteToMessageDecoder implements WebSocketFrameDecoder {private static final Logger logger = LoggerFactory.getLogger(WebSocket13FrameDecoder.class);private static final byte OPCODE_CONT = 0x0;private static final byte OPCODE_TEXT = 0x1;private static final byte OPCODE_BINARY = 0x2;private static final byte OPCODE_CLOSE = 0x8;private static final byte OPCODE_PING = 0x9;private static final byte OPCODE_PONG = 0xA;private final WebSocketDecoderConfig config;private int fragmentedFramesCount;private boolean frameFinalFlag;private boolean frameMasked;private int frameRsv;private int frameOpcode;private long framePayloadLength;private int mask;private int framePayloadLen1;private boolean receivedClosingHandshake;private State state = State.READING_FIRST;public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);}public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength, boolean allowMaskMismatch) {this(WebSocketDecoderConfig.newBuilder().expectMaskedFrames(expectMaskedFrames).allowExtensions(allowExtensions).maxFramePayloadLength(maxFramePayloadLength).allowMaskMismatch(allowMaskMismatch).build());}public WebSocket13FrameDecoder(WebSocketDecoderConfig decoderConfig) {config = Objects.requireNonNull(decoderConfig, "decoderConfig");}private static int toFrameLength(long length) {if (length > Integer.MAX_VALUE) {throw new TooLongFrameException("frame length exceeds " + Integer.MAX_VALUE + ": " + length);} else {return (int) length;}}@Overrideprotected void decode(ChannelHandlerContext ctx, Buffer in) throws Exception {// Discard all data received if closing handshake was received before.if (receivedClosingHandshake) {in.skipReadableBytes(actualReadableBytes());return;}switch (state) {case READING_FIRST: {if (in.readableBytes() == 0) 
{return;}framePayloadLength = 0;// FIN, RSV, OPCODEbyte b = in.readByte();frameFinalFlag = (b & 0x80) != 0;frameRsv = (b & 0x70) >> 4;frameOpcode = b & 0x0F;if (logger.isTraceEnabled()) {logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);}state = State.READING_SECOND;}case READING_SECOND: {if (in.readableBytes() == 0) {return;}// MASK, PAYLOAD LEN 1byte b = in.readByte();frameMasked = (b & 0x80) != 0;framePayloadLen1 = b & 0x7F;if (frameRsv != 0 && !config.allowExtensions()) {protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);return;}if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {protocolViolation(ctx, in, "received a frame that is not masked as expected");return;}if (frameOpcode > 7) { // control frame (have MSB in opcode set)// control frames MUST NOT be fragmentedif (!frameFinalFlag) {protocolViolation(ctx, in, "fragmented control frame");return;}// control frames MUST have payload 125 octets or lessif (framePayloadLen1 > 125) {protocolViolation(ctx, in, "control frame with payload length > 125 octets");return;}// check for reserved control frame opcodesif (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING|| frameOpcode == OPCODE_PONG)) {protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);return;}// close frame : if there is a body, the first two bytes of the// body MUST be a 2-byte unsigned integer (in network byte// order) representing a getStatus codeif (frameOpcode == 8 && framePayloadLen1 == 1) {protocolViolation(ctx, in, "received close control frame with payload len 1");return;}} else { // data frame// check for reserved data frame opcodesif (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT|| frameOpcode == OPCODE_BINARY)) {protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);return;}// check opcode vs message fragmentation state 1/2if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) 
{protocolViolation(ctx, in, "received continuation data frame outside fragmented message");return;}// check opcode vs message fragmentation state 2/2if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {protocolViolation(ctx, in,"received non-continuation data frame while inside fragmented message");return;}}state = State.READING_SIZE;}case READING_SIZE: {// Read frame payload lengthif (framePayloadLen1 == 126) {if (in.readableBytes() < 2) {return;}framePayloadLength = in.readUnsignedShort();if (framePayloadLength < 126) {protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");return;}} else if (framePayloadLen1 == 127) {if (in.readableBytes() < 8) {return;}framePayloadLength = in.readLong();// TODO: check if it's bigger than 0x7FFFFFFFFFFFFFFF, Maybe// just check if it's negative?if (framePayloadLength < 65536) {protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");return;}} else {framePayloadLength = framePayloadLen1;}if (framePayloadLength > config.maxFramePayloadLength()) {protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,"Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");return;}if (logger.isTraceEnabled()) {logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);}state = State.MASKING_KEY;}case MASKING_KEY: {if (frameMasked) {if (in.readableBytes() < 4) {return;}mask = in.readInt();}state = State.PAYLOAD;}case PAYLOAD: {if (in.readableBytes() < framePayloadLength) {return;}Buffer payloadBuffer = null;try {payloadBuffer = in.readSplit(toFrameLength(framePayloadLength));// Now we have all the data, the next checkpoint must be the next// framestate = State.READING_FIRST;// Unmask data if neededif (frameMasked) {unmask(payloadBuffer);}// Processing ping/pong/close frames because they cannot be// fragmentedif (frameOpcode == OPCODE_PING) {WebSocketFrame frame = new PingWebSocketFrame(frameFinalFlag, frameRsv, 
payloadBuffer);payloadBuffer = null;ctx.fireChannelRead(frame);return;}if (frameOpcode == OPCODE_PONG) {WebSocketFrame frame = new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);payloadBuffer = null;ctx.fireChannelRead(frame);return;}if (frameOpcode == OPCODE_CLOSE) {receivedClosingHandshake = true;checkCloseFrameBody(ctx, payloadBuffer);WebSocketFrame frame = new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);payloadBuffer = null;ctx.fireChannelRead(frame);return;}// Processing for possible fragmented messages for text and binary// framesif (frameFinalFlag) {// Final frame of the sequence. Apparently ping frames are// allowed in the middle of a fragmented messagefragmentedFramesCount = 0;} else {// Increment counterfragmentedFramesCount++;}// Return the frameif (frameOpcode == OPCODE_TEXT) {WebSocketFrame frame = new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);payloadBuffer = null;ctx.fireChannelRead(frame);return;} else if (frameOpcode == OPCODE_BINARY) {WebSocketFrame frame = new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);payloadBuffer = null;ctx.fireChannelRead(frame);return;} else if (frameOpcode == OPCODE_CONT) {WebSocketFrame frame = new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);payloadBuffer = null;ctx.fireChannelRead(frame);return;} else {throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "+ frameOpcode);}} finally {if (payloadBuffer != null) {payloadBuffer.close();}}}case CORRUPT: {if (in.readableBytes() > 0) {// If we don't keep reading Netty will throw an exception saying// we can't return null if no bytes read and state not changed.in.readByte();}return;}default:throw new Error("Shouldn't reach here.");}}private void unmask(Buffer frame) {int base = frame.readerOffset();int len = frame.readableBytes();int index = 0;int intMask = mask;if (intMask == 0) {// If the mask is 0 we can just return directly as the XOR operations will just 
produce the same value.return;}for (; index + 3 < len; index += Integer.BYTES) {int off = base + index;frame.setInt(off, frame.getInt(off) ^ intMask);}int maskOffset = 0;for (; index < len; index++) {int off = base + index;frame.setByte(off, (byte) (frame.getByte(off) ^ WebSocketUtil.byteAtIndex(intMask, maskOffset++ & 3)));}}private void protocolViolation(ChannelHandlerContext ctx, Buffer in, String reason) {protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);}private void protocolViolation(ChannelHandlerContext ctx, Buffer in, WebSocketCloseStatus status, String reason) {protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));}private void protocolViolation(ChannelHandlerContext ctx, Buffer in, CorruptedWebSocketFrameException ex) {state = State.CORRUPT;int readableBytes = in.readableBytes();if (readableBytes > 0) {// Fix for memory leak, caused by ByteToMessageDecoder#channelRead:// buffer 'cumulation' is released ONLY when no more readable bytes available.in.skipReadableBytes(readableBytes);}if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {Object closeMessage;if (receivedClosingHandshake) {closeMessage = ctx.bufferAllocator().allocate(0);} else {WebSocketCloseStatus closeStatus = ex.closeStatus();String reasonText = ex.getMessage();if (reasonText == null) {reasonText = closeStatus.reasonText();}closeMessage = new CloseWebSocketFrame(ctx.bufferAllocator(), closeStatus, reasonText);}ctx.writeAndFlush(closeMessage).addListener(ctx, ChannelFutureListeners.CLOSE);}throw ex;}/** */protected void checkCloseFrameBody(ChannelHandlerContext ctx, Buffer buffer) {if (buffer == null || buffer.readableBytes() <= 0) {return;}if (buffer.readableBytes() == 1) {protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");}// Save reader offset.int offset = buffer.readerOffset();try {// Must have 2 byte integer within the valid range.int statusCode = buffer.readShort();if 
(!WebSocketCloseStatus.isValidStatusCode(statusCode)) {protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);}// May have UTF-8 message.if (buffer.readableBytes() > 0) {try {new Utf8Validator().check(buffer);} catch (CorruptedWebSocketFrameException ex) {protocolViolation(ctx, buffer, ex);}}} finally {// Restore reader offset.buffer.readerOffset(offset);}}enum State {READING_FIRST,READING_SECOND,READING_SIZE,MASKING_KEY,PAYLOAD,CORRUPT}
}