skynet-socket.lua源码分析
- 源码
- 模块初始化和核心数据结构
- 引入依赖和常量
- 核心数据结构
- 协程管理和挂起机制
- 协程控制函数
- Socket消息类型处理
- 类型1: 数据到达 (SKYNET_SOCKET_TYPE_DATA)
- 类型2: 连接建立 (SKYNET_SOCKET_TYPE_CONNECT)
- 类型3: 连接关闭 (SKYNET_SOCKET_TYPE_CLOSE)
- 类型4: 接受连接 (SKYNET_SOCKET_TYPE_ACCEPT)
- 其他类型处理
- 协议注册和消息分发
- 核心API方法
- 连接建立相关
- 数据读取相关
- 数据写入和连接管理
- UDP相关功能
- 业务调用链路分析
- 服务启动监听流程
- 客户端连接处理流程
- 数据发送流程
源码
local driver = require "skynet.socketdriver"
local skynet = require "skynet"
local skynet_core = require "skynet.core"
local assert = assert

-- Buffered-byte threshold above which a socket is paused (see pause_socket).
local BUFFER_LIMIT = 128 * 1024

local socket = {} -- api

-- store all socket object, keyed by socket id.
-- When the pool table itself is collected, close every socket still in it.
local socket_pool = setmetatable(
	{},
	{ __gc = function(p)
		for id in pairs(p) do
			driver.close(id)
			p[id] = nil
		end
	end }
)

-- Close callbacks registered with socket.onclose(), keyed by socket id.
local socket_onclose = {}
-- Handlers for socket messages, indexed by SKYNET_SOCKET_TYPE_* value.
local socket_message = {}

-- Resume the coroutine parked on s by suspend(), if there is one.
local function wakeup(s)
	local co = s.co
	if co then
		s.co = nil
		skynet.wakeup(co)
	end
end

-- Ask the driver to stop reading s, then yield.
-- No-op when s.pause is already set (true while paused, false after socket.close()).
-- size is only used for the log message.
local function pause_socket(s, size)
	if s.pause ~= nil then
		return
	end
	if size then
		skynet.error(string.format("Pause socket (%d) size : %d" , s.id, size))
	else
		skynet.error(string.format("Pause socket (%d)" , s.id))
	end
	driver.pause(s.id)
	s.pause = true
	skynet.yield()	-- there are subsequent socket messages in mqueue, maybe.
end

-- Park the running coroutine on s until wakeup(s) is called.
-- If s was paused, restart reading (driver.start) before waiting and clear the flag afterwards.
local function suspend(s)
	assert(not s.co)
	s.co = coroutine.running()
	if s.pause then
		skynet.error(string.format("Resume socket (%d)", s.id))
		driver.start(s.id)
		skynet.wait(s.co)
		s.pause = nil
	else
		skynet.wait(s.co)
	end
	-- wakeup closing coroutine every time suspend,
	-- because socket.close() will wait last socket buffer operation before clear the buffer.
	if s.closing then
		skynet.wakeup(s.closing)
	end
end

-- read skynet_socket.h for these macro
-- SKYNET_SOCKET_TYPE_DATA = 1
-- Push incoming data into the socket buffer and wake the reader when its request is met.
-- s.read_required: number -> wait for that many bytes (0 = any data);
-- string -> wait for a line ending with that separator; anything else -> nothing to
-- satisfy here, only enforce buffer_limit and pause on a large buffer.
socket_message[1] = function(id, size, data)
	local s = socket_pool[id]
	if s == nil then
		skynet.error("socket: drop package from " .. id)
		driver.drop(data, size)
		return
	end

	local sz = driver.push(s.buffer, s.pool, data, size)
	local rr = s.read_required
	local rrt = type(rr)
	if rrt == "number" then
		-- read size
		if sz >= rr then
			s.read_required = nil
			if sz > BUFFER_LIMIT then
				pause_socket(s, sz)
			end
			wakeup(s)
		end
	else
		if s.buffer_limit and sz > s.buffer_limit then
			skynet.error(string.format("socket buffer overflow: fd=%d size=%d", id , sz))
			driver.close(id)
			return
		end
		if rrt == "string" then
			-- read line
			if driver.readline(s.buffer,nil,rr) then
				s.read_required = nil
				if sz > BUFFER_LIMIT then
					pause_socket(s, sz)
				end
				wakeup(s)
			end
		elseif sz > BUFFER_LIMIT and not s.pause then
			pause_socket(s, sz)
		end
	end
end

-- SKYNET_SOCKET_TYPE_CONNECT = 2
-- Connection established (or listen confirmed): mark connected and wake the waiter.
-- For a listening socket, addr/ud carry the bound address and port.
socket_message[2] = function(id, ud , addr)
	local s = socket_pool[id]
	if s == nil then
		return
	end
	-- log remote addr
	if not s.connected then	-- resume may also post connect message
		if s.listen then
			s.addr = addr
			s.port = ud
		end
		s.connected = true
		wakeup(s)
	end
end

-- SKYNET_SOCKET_TYPE_CLOSE = 3
-- Connection closed: mark disconnected and wake the waiter; unknown ids are just closed.
-- Runs (and then clears) any callback registered with socket.onclose().
socket_message[3] = function(id)
	local s = socket_pool[id]
	if s then
		s.connected = false
		wakeup(s)
	else
		driver.close(id)
	end
	local cb = socket_onclose[id]
	if cb then
		cb(id)
		socket_onclose[id] = nil
	end
end

-- SKYNET_SOCKET_TYPE_ACCEPT = 4
-- New connection accepted on listening socket id: hand newid to its accept callback,
-- or close newid if the listening socket is no longer known.
socket_message[4] = function(id, newid, addr)
	local s = socket_pool[id]
	if s == nil then
		driver.close(newid)
		return
	end
	s.callback(newid, addr)
end

-- SKYNET_SOCKET_TYPE_ERROR = 5
-- Socket error. A pending connect records err in s.connecting so connect() can return it.
socket_message[5] = function(id, _, err)
	local s = socket_pool[id]
	if s == nil then
		driver.shutdown(id)
		skynet.error("socket: error on unknown", id, err)
		return
	end
	if s.callback then
		skynet.error("socket: accept error:", err)
		return
	end
	if s.connected then
		skynet.error("socket: error on", id, err)
	elseif s.connecting then
		s.connecting = err
	end
	s.connected = false
	driver.shutdown(id)

	wakeup(s)
end

-- SKYNET_SOCKET_TYPE_UDP = 6
-- UDP datagram: copy it into a Lua string, free the raw buffer, and pass it to the callback.
socket_message[6] = function(id, size, data, address)
	local s = socket_pool[id]
	if s == nil or s.callback == nil then
		skynet.error("socket: drop udp package from " .. id)
		driver.drop(data, size)
		return
	end
	local str = skynet.tostring(data, size)
	skynet_core.trash(data, size)
	s.callback(str, address)
end

-- Default send-backlog warning handler: log the pending size for a known socket.
local function default_warning(id, size)
	local s = socket_pool[id]
	if not s then
		return
	end
	skynet.error(string.format("WARNING: %d K bytes need to send out (fd = %d)", size, id))
end

-- SKYNET_SOCKET_TYPE_WARNING
-- Send-buffer warning: call the socket's on_warning handler, or default_warning.
socket_message[7] = function(id, size)
	local s = socket_pool[id]
	if s then
		local warning = s.on_warning or default_warning
		warning(id, size)
	end
end

-- Route PTYPE_SOCKET messages to socket_message[t], where t is the first unpacked value.
skynet.register_protocol {
	name = "socket",
	id = skynet.PTYPE_SOCKET,	-- PTYPE_SOCKET = 6
	unpack = driver.unpack,
	dispatch = function (_, _, t, ...)
		socket_message[t](...)
	end
}

-- Register id in socket_pool and block until the driver reports it connected
-- (socket_message[2]) or it fails/closes (socket_message[5] / [3]).
-- func == nil: a data socket, which gets a read buffer.
-- func ~= nil: stored as s.callback and invoked as func(newid, addr) on accept.
-- Returns id on success; otherwise nil plus s.connecting (the error text set by
-- socket_message[5], when that is what woke us).
local function connect(id, func)
	local newbuffer
	if func == nil then
		newbuffer = driver.buffer()
	end
	local s = {
		id = id,
		buffer = newbuffer,
		pool = newbuffer and {},
		connected = false,
		connecting = true,
		read_required = false,
		co = false,
		callback = func,
		protocol = "TCP",
	}
	assert(not socket_onclose[id], "socket has onclose callback")
	local s2 = socket_pool[id]
	if s2 and not s2.listen then
		error("socket is not closed")
	end
	socket_pool[id] = s
	suspend(s)
	local err = s.connecting
	s.connecting = nil
	if s.connected then
		return id
	else
		socket_pool[id] = nil
		return nil, err
	end
end

-- Connect to addr:port and block until done. Returns id, or nil, err.
function socket.open(addr, port)
	local id = driver.connect(addr,port)
	return connect(id)
end

-- Wrap an existing OS file descriptor as a skynet socket.
function socket.bind(os_fd)
	local id = driver.bind(os_fd)
	return connect(id)
end

function socket.stdin()
	return socket.bind(0)
end

-- Start receiving on id and block until the driver confirms.
-- Pass func as the accept callback when id is a listening socket.
function socket.start(id, func)
	driver.start(id)
	return connect(id, func)
end

function socket.pause(id)
	local s = socket_pool[id]
	if s == nil then
		return
	end
	pause_socket(s)
end

function socket.shutdown(id)
	local s = socket_pool[id]
	if s then
		-- the framework would send SKYNET_SOCKET_TYPE_CLOSE , need close(id) later
		driver.shutdown(id)
	end
end

-- Close an id that is NOT managed by socket_pool.
function socket.close_fd(id)
	assert(socket_pool[id] == nil,"Use socket.close instead")
	driver.close(id)
end

-- Close id. If connected, wait until any in-progress read (or the close message)
-- has been handled before dropping the socket object.
function socket.close(id)
	local s = socket_pool[id]
	if s == nil then
		return
	end
	driver.close(id)
	if s.connected then
		s.pause = false -- Do not resume this fd if it paused.
		if s.co then
			-- reading this socket on another coroutine, so don't shutdown (clear the buffer) immediately
			-- wait reading coroutine read the buffer.
			assert(not s.closing)
			s.closing = coroutine.running()
			skynet.wait(s.closing)
		else
			suspend(s)
		end
		s.connected = false
	end
	socket_pool[id] = nil
end

-- sz == nil: return whatever is buffered (waiting for some data if empty).
-- Otherwise return exactly sz bytes, waiting if needed.
-- On failure returns false plus whatever data remains.
function socket.read(id, sz)
	local s = socket_pool[id]
	assert(s)
	if sz == nil then
		-- read some bytes
		local ret = driver.readall(s.buffer, s.pool)
		if ret ~= "" then
			return ret
		end

		if not s.connected then
			return false, ret
		end
		assert(not s.read_required)
		s.read_required = 0 -- 0: any amount of data satisfies the request
		suspend(s)
		ret = driver.readall(s.buffer, s.pool)
		if ret ~= "" then
			return ret
		else
			return false, ret
		end
	end

	local ret = driver.pop(s.buffer, s.pool, sz)
	if ret then
		return ret
	end
	if s.closing or not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end

	assert(not s.read_required)
	s.read_required = sz
	suspend(s)
	ret = driver.pop(s.buffer, s.pool, sz)
	if ret then
		return ret
	else
		return false, driver.readall(s.buffer, s.pool)
	end
end

-- Wait until the socket closes, then return everything buffered.
function socket.readall(id)
	local s = socket_pool[id]
	assert(s)
	if not s.connected then
		local r = driver.readall(s.buffer, s.pool)
		return r ~= "" and r
	end
	assert(not s.read_required)
	s.read_required = true -- never satisfied by data; only a close wakes us
	suspend(s)
	assert(s.connected == false)
	return driver.readall(s.buffer, s.pool)
end

-- Return the next line ending in sep (default "\n"), waiting if needed.
-- On disconnect returns false plus the remaining data.
function socket.readline(id, sep)
	sep = sep or "\n"
	local s = socket_pool[id]
	assert(s)
	local ret = driver.readline(s.buffer, s.pool, sep)
	if ret then
		return ret
	end
	if not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end
	assert(not s.read_required)
	s.read_required = sep
	suspend(s)
	if s.connected then
		return driver.readline(s.buffer, s.pool, sep)
	else
		return false, driver.readall(s.buffer, s.pool)
	end
end

-- Wait for any data (or disconnect). Returns whether the socket is still connected.
function socket.block(id)
	local s = socket_pool[id]
	if not s or not s.connected then
		return false
	end
	assert(not s.read_required)
	s.read_required = 0
	suspend(s)
	return s.connected
end

socket.write = assert(driver.send)
socket.lwrite = assert(driver.lsend) -- send via the low-priority queue
socket.header = assert(driver.header)

function socket.invalid(id)
	return socket_pool[id] == nil
end

-- True when id is known but neither connected nor connecting; nil for unknown ids.
function socket.disconnected(id)
	local s = socket_pool[id]
	if s then
		return not(s.connected or s.connecting)
	end
end

-- Listen on host:port (or a single "host:port" string) and block until the driver
-- confirms. Returns id plus the address/port reported by socket_message[2].
function socket.listen(host, port, backlog)
	if port == nil then
		host, port = string.match(host, "([^:]+):(.+)$")
		port = tonumber(port)
	end
	local id = driver.listen(host, port, backlog)
	local s = {
		id = id,
		connected = false,
		listen = true,
	}
	assert(socket_pool[id] == nil)
	socket_pool[id] = s
	suspend(s)
	return id, s.addr, s.port
end

-- abandon use to forward socket id to other service
-- you must call socket.start(id) later in other service
function socket.abandon(id)
	local s = socket_pool[id]
	if s then
		s.connected = false
		wakeup(s)
		socket_onclose[id] = nil
		socket_pool[id] = nil
	end
end

-- Close the socket when its receive buffer exceeds limit bytes (checked in socket_message[1]).
function socket.limit(id, limit)
	local s = assert(socket_pool[id])
	s.buffer_limit = limit
end

---------------------- UDP

local function create_udp_object(id, cb)
	assert(not socket_pool[id], "socket is not closed")
	socket_pool[id] = {
		id = id,
		connected = true,
		protocol = "UDP",
		callback = cb,
	}
end

function socket.udp(callback, host, port)
	local id = driver.udp(host, port)
	create_udp_object(id, callback)
	return id
end

-- Set the default peer of a UDP socket; registers the object if not yet known.
function socket.udp_connect(id, addr, port, callback)
	local obj = socket_pool[id]
	if obj then
		assert(obj.protocol == "UDP")
		if callback then
			obj.callback = callback
		end
	else
		create_udp_object(id, callback)
	end
	driver.udp_connect(id, addr, port)
end

function socket.udp_listen(addr, port, callback)
	local id = driver.udp_listen(addr, port)
	create_udp_object(id, callback)
	return id
end

function socket.udp_dial(addr, port, callback)
	local id = driver.udp_dial(addr, port)
	create_udp_object(id, callback)
	return id
end

socket.sendto = assert(driver.udp_send)
socket.udp_address = assert(driver.udp_address)
socket.netstat = assert(driver.info)
socket.resolve = assert(driver.resolve)

-- Replace the send-backlog warning handler (see socket_message[7]) for id.
function socket.warning(id, callback)
	local obj = socket_pool[id]
	assert(obj)
	obj.on_warning = callback
end

-- Register callback(id) to run once when id's close message arrives.
function socket.onclose(id, callback)
	socket_onclose[id] = callback
end

return socket
模块初始化和核心数据结构
引入依赖和常量
local driver = require "skynet.socketdriver" -- low-level C driver
local skynet = require "skynet"
local skynet_core = require "skynet.core"

local BUFFER_LIMIT = 128 * 1024 -- 128 KB: buffered size above which the socket is paused
核心数据结构
local socket = {} -- public API table

-- Pool of all socket objects, keyed by socket id.
local socket_pool = setmetatable({}, {
	-- When the pool is garbage-collected, close every socket still in it.
	__gc = function(p)
		for id in pairs(p) do
			driver.close(id)
			p[id] = nil
		end
	end
})

local socket_onclose = {} -- close callbacks, keyed by socket id
local socket_message = {} -- handlers, indexed by socket message type
协程管理和挂起机制
协程控制函数
-- Resume the coroutine parked on s by suspend(), if there is one.
local function wakeup(s)
	local co = s.co
	if co then
		s.co = nil
		skynet.wakeup(co) -- wake the suspended coroutine
	end
end

-- Stop the driver from reading s (no-op if s.pause is already set).
-- The log messages of the full source are omitted in this excerpt.
local function pause_socket(s, size)
	if s.pause ~= nil then return end
	driver.pause(s.id) -- ask the driver to stop receiving data
	s.pause = true
	skynet.yield() -- yield: more socket messages may already be queued
end

-- Park the running coroutine on s until wakeup(s); restart reading first if paused.
local function suspend(s)
	assert(not s.co)
	s.co = coroutine.running() -- remember the current coroutine
	if s.pause then
		driver.start(s.id) -- resume receiving data
		skynet.wait(s.co) -- wait to be woken
		s.pause = nil
	else
		skynet.wait(s.co) -- just wait
	end
	-- if socket.close() is waiting on this socket, wake it
	if s.closing then
		skynet.wakeup(s.closing)
	end
end
Socket消息类型处理
Skynet定义了7种socket消息类型
类型1: 数据到达 (SKYNET_SOCKET_TYPE_DATA)
-- Type 1: data arrived. Buffer it and wake the reader once its request is satisfied.
socket_message[1] = function(id, size, data)
	local s = socket_pool[id]
	if s == nil then
		skynet.error("socket: drop package from " .. id)
		driver.drop(data, size) -- discard the data
		return
	end

	local sz = driver.push(s.buffer, s.pool, data, size) -- push into the buffer; sz = bytes now buffered
	-- wake the waiting coroutine according to its read request
	local rr = s.read_required
	if type(rr) == "number" then -- waiting for a given number of bytes
		if sz >= rr then
			s.read_required = nil
			if sz > BUFFER_LIMIT then
				pause_socket(s, sz) -- buffer too large: pause receiving
			end
			wakeup(s) -- wake the reading coroutine
		end
	else
		-- other read modes (readline, buffer limit, pausing) omitted in this excerpt
	end
end
类型2: 连接建立 (SKYNET_SOCKET_TYPE_CONNECT)
-- Type 2: connection established (or listen confirmed).
socket_message[2] = function(id, ud, addr)
	local s = socket_pool[id]
	if s == nil then return end
	if not s.connected then
		if s.listen then -- listening socket: record the bound address and port
			s.addr = addr
			s.port = ud
		end
		s.connected = true -- mark as connected
		wakeup(s) -- wake the coroutine waiting for the connection
	end
end
类型3: 连接关闭 (SKYNET_SOCKET_TYPE_CLOSE)
-- Type 3: connection closed.
socket_message[3] = function(id)
	local s = socket_pool[id]
	if s then
		s.connected = false
		wakeup(s) -- wake the single coroutine waiting on this socket (s.co), if any
	else
		driver.close(id) -- unknown socket: just close it
	end
	-- run the close callback once, if one was registered
	local cb = socket_onclose[id]
	if cb then
		cb(id)
		socket_onclose[id] = nil
	end
end
类型4: 接受连接 (SKYNET_SOCKET_TYPE_ACCEPT)
-- Type 4: a new connection was accepted on listening socket id.
socket_message[4] = function(id, newid, addr)
	local s = socket_pool[id]
	if s == nil then
		driver.close(newid) -- the listening socket is gone: reject the connection
		return
	end
	s.callback(newid, addr) -- invoke the accept callback
end
其他类型处理
- 类型5: 错误处理
- 类型6: UDP数据包
- 类型7: 发送缓冲区警告
协议注册和消息分发
-- Register the socket protocol and dispatch each message by its type t.
skynet.register_protocol {
	name = "socket",
	id = skynet.PTYPE_SOCKET, -- PTYPE_SOCKET = 6
	unpack = driver.unpack, -- unpack with the driver
	dispatch = function (_, _, t, ...)
		socket_message[t](...) -- route to the handler for type t
	end
}
核心API方法
连接建立相关
-- Connect to addr:port; blocks until the connection completes (connect() is shown in the full source).
function socket.open(addr, port)
	local id = driver.connect(addr,port) -- start the low-level connect
	return connect(id) -- wait for the connection to complete
end

-- Create a listening socket and wait until the driver confirms it.
-- Also accepts a single "host:port" string, as in the full source.
function socket.listen(host, port, backlog)
	if port == nil then
		host, port = string.match(host, "([^:]+):(.+)$")
		port = tonumber(port)
	end
	local id = driver.listen(host, port, backlog) -- create the listening socket
	local s = {
		id = id,
		connected = false,
		listen = true,
	}
	assert(socket_pool[id] == nil) -- never overwrite a live socket object
	socket_pool[id] = s
	suspend(s) -- woken by socket_message[2] (listen confirmed) or [5] (error)
	return id, s.addr, s.port -- address and port recorded by socket_message[2]
end

-- Start receiving on id. For a listening socket, func is the accept callback,
-- called as func(newid, addr); for a data socket, func is nil.
function socket.start(id, func)
	driver.start(id) -- start receiving
	return connect(id, func)
end
数据读取相关
-- Read from socket id.
-- sz == nil: return whatever is buffered, waiting for some data if the buffer is empty.
-- Otherwise: return exactly sz bytes, waiting until enough data arrives.
-- On failure (closed/closing) returns false plus whatever data remains.
function socket.read(id, sz)
	local s = socket_pool[id]
	assert(s)
	if sz == nil then
		-- read all available data
		local ret = driver.readall(s.buffer, s.pool)
		if ret ~= "" then return ret end
		if not s.connected then return false, ret end
		assert(not s.read_required) -- only one pending read per socket
		s.read_required = 0 -- 0: any amount of data satisfies the request
		suspend(s) -- wait for data
		ret = driver.readall(s.buffer, s.pool)
		if ret ~= "" then
			return ret
		end
		return false, ret
	end

	-- read exactly sz bytes
	local ret = driver.pop(s.buffer, s.pool, sz)
	if ret then return ret end
	if s.closing or not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end
	assert(not s.read_required)
	s.read_required = sz -- record the read request
	suspend(s) -- wait until enough data is buffered
	ret = driver.pop(s.buffer, s.pool, sz)
	if ret then
		-- return only these sz bytes; bytes beyond them stay buffered for the next read
		return ret
	end
	return false, driver.readall(s.buffer, s.pool)
end

-- Read one line ending with sep (default "\n"), waiting if needed.
-- On disconnect returns false plus the remaining data.
function socket.readline(id, sep)
	sep = sep or "\n"
	local s = socket_pool[id]
	assert(s)
	local ret = driver.readline(s.buffer, s.pool, sep)
	if ret then return ret end
	if not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end
	assert(not s.read_required)
	s.read_required = sep -- line-reading mode
	suspend(s)
	if s.connected then
		return driver.readline(s.buffer, s.pool, sep)
	else
		return false, driver.readall(s.buffer, s.pool)
	end
end
数据写入和连接管理
socket.write = assert(driver.send) -- asynchronous send
socket.lwrite = assert(driver.lsend) -- asynchronous send through the low-priority queue

-- Close id. If connected, wait for an in-progress read (or the close message) before
-- dropping the socket object.
function socket.close(id)
	local s = socket_pool[id]
	if s == nil then return end
	driver.close(id) -- low-level close
	if s.connected then
		s.pause = false -- never resume this fd if it was paused
		if s.co then
			-- another coroutine is reading: wait for it to finish with the buffer
			assert(not s.closing)
			s.closing = coroutine.running()
			skynet.wait(s.closing)
		else
			suspend(s) -- wait for the close to be reported
		end
		s.connected = false
	end
	socket_pool[id] = nil
end
UDP相关功能
-- Create a UDP socket; callback(str, address) receives each datagram.
function socket.udp(callback, host, port)
	local id = driver.udp(host, port)
	create_udp_object(id, callback) -- register the UDP socket object (defined in the full source)
	return id
end

socket.sendto = assert(driver.udp_send) -- UDP send
业务调用链路分析
服务启动监听流程
-- Call site in gateserver.lua (simplified); gateserver uses socketdriver directly, not socket.lua
function CMD.open(source, conf)
	local address = conf.address or "0.0.0.0"
	local port = assert(conf.port)
	-- NOTE(review): `backlog` was an undefined global in this snippet; assumed to come from conf (may be nil)
	local backlog = conf.backlog
	socket = socketdriver.listen(address, port, backlog) -- create the listening socket
	socketdriver.start(socket) -- start accepting connections
end
客户端连接处理流程
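1. 客户端连接:底层驱动 accept 新连接 → 监听 socket 所在服务收到 SKYNET_SOCKET_TYPE_ACCEPT 消息。若使用 socket.lua,由 socket_message[4] 调用 socket.start(id, func) 时登记的 func(newid, addr);若使用 gateserver(如上所示,它直接使用 socketdriver 并注册自己的 socket 协议处理),则由 gateserver 内部转交 handler.connect(),业务再调用 gateserver.openclient() 开始接收该连接的数据。
2. 数据到达 → socket_message[1] → 数据压入缓冲区 → 检查读取需求(read_required) → 需求满足时 wakeup() 唤醒读取协程
3. 业务读取 → socket.read() → 缓冲区数据足够则立即返回 / 不足则 suspend() 挂起等待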
数据发送流程
-- Called from business code:
socket.write(fd, data) -- socket.write IS driver.send (socket.write = assert(driver.send)); the driver sends asynchronously
可以把socket.lua想象成一个高效的快递分拣中心:
- socket_pool = 快递货架(存放所有包裹)
- buffer = 临时存放区(数据缓冲区)
- 协程机制 = 智能调度系统(有人在等、且货已备齐时才通知)
- driver = 装卸工人(底层实际操作)
工作流程:
1. 收货(数据到达):
- 快递车(网络数据)到达 → 分拣员(socket_message)处理
- 根据标签(消息类型)放到对应货架
- 如果有人预订(read_required)且到货已满足预订数量,就打电话通知(wakeup)
2. 取货(数据读取):
- 客户(业务代码)来取件 → 看货架有没有
- 有货直接拿走 → 没货就登记需求(read_required)然后等待(suspend)
3. 发货(数据发送):
- 客户要寄件 → 直接交给装卸工(driver.send)处理
- 装卸工负责打包发送,不阻塞分拣中心
4. 特殊服务:
- 暂停服务:货架太满时暂停收货(pause_socket)
- 积压提醒:待发送数据堆积过多时发出警告(socket_message[7])
- 连接管理:新客户登记、老客户离开的接待流程