无奈!我用go写了个MySQL服务
一个程序员的“被迫营业”故事
序:同事的困境
嘿,各位程序员大佬们!今天我要给你们讲一个既荒诞又真实的故事——关于我如何被同事的易语言代码“逼”成了“MySQL协议专家”,注意引号,不是真专家。
事情是这样的,我有个同事,有一个程序是用易语言写的,古董级的。最近要做升级,要缓存一个4G左右的数据,可是易语言是32位程序,做不到,于是我就用go开发了一个HTTP接口,数据我缓存,他用易语言调用这个接口,结果遇到了一个世纪难题:易语言使用http读文件不是线程不安全!是的,你没听错,在2025年,还有人在为易语言的线程安全问题挠头。
灵感乍现:“既然易语言搞不定HTTP,那我们为什么不用MySQL协议呢?”
那天,他调试来调试去,线程安全问题就是解决不了!最后,我盯着他足足看了三十秒钟,突然脑子里闪过一道光——对啊!MySQL协议是成熟稳定的数据库协议,几乎所有语言都有完善的驱动支持,包括易语言!
于是,一个大胆的想法诞生了:写一个简单的MySQL协议服务端,将HTTP接口伪装成MySQL查询。这样,同事就可以用易语言通过MySQL驱动轻松调用,完美避开线程安全的坑!
开发历程:从0到1的MySQL协议实现
说干就干!我用Go语言(感谢Go的并发模型!)开始了这个"伪MySQL服务器"的开发。过程中遇到了不少挑战:
- 协议解析:MySQL协议虽然公开,但细节繁多,特别是握手认证部分
- 连接处理:要支持多客户端同时连接,Go的goroutine正好派上用场
- 兼容性处理:不同语言的MySQL驱动初始化时会发送各种查询,必须一一处理
功能特性:麻雀虽小,五脏俱全
经过一番努力,这个"不正经"的MySQL服务器已经具备了以下功能:
- ✅ MySQL协议兼容,可以被各种语言的MySQL驱动正常连接
- ✅ 支持基本的数据库操作命令
- ✅ 处理客户端初始化查询(比如字符集、校对规则等)
- ✅ 自定义SQL命令支持:这里用了两个示例接口,没有真实的业务代码
- ✅ 数据库和表的模拟(通过文件系统目录结构)
自定义命令使用示例
GET MB 命令 - 字符串回显功能
mysql> get mb hello world;
+--------+-------------+
| key | value |
+--------+-------------+
| result | hello world |
+--------+-------------+
1 row in set
2. GET MD5 命令 - MD5计算功能
mysql> get md5 test123;
+--------+----------------------------------+
| key | value |
+--------+----------------------------------+
| result | 22b75d6007e06f4a959d1b1d69b4c4bd |
+--------+----------------------------------+
1 row in set
核心代码解析
1. 自定义命令处理机制
自定义命令是这个服务的核心功能之一,下面我们来看看它是如何实现的:
// 处理自定义SQL命令
if strings.HasPrefix(strings.ToUpper(sql), "GET MB") {parts := strings.Fields(strings.ToUpper(sql))var input stringif len(parts) > 2 {input = strings.Join(parts[2:], " ") // 提取命令参数} else {input = "default" // 默认值处理}// 返回输入的字符串mbResult := input// 构建MySQL响应结构rows := [][]string{{"result", mbResult},}mr := MysqlResponse{// 定义字段结构Fs: []FieldProtocol{...},Rows: rows,}c.Conn.Write(mr.GetBytes()) // 发送响应
} else if strings.HasPrefix(strings.ToUpper(sql), "GET MD5") {// MD5命令处理逻辑...
} else {// 处理标准SQL命令...
}
这段代码展示了自定义命令的处理流程:
- 命令识别:通过字符串前缀判断命令类型
- 参数解析:提取命令后面的参数部分
- 业务处理:执行相应的业务逻辑(字符串回显或MD5计算)
- 响应构建:创建标准的MySQL协议响应结构
- 数据发送:将响应序列化为二进制数据并发送
2. MySQL协议响应结构
服务通过MysqlResponse
结构体来构建符合MySQL协议的响应数据:
// MysqlResponse MySQL响应结构
type MysqlResponse struct {Fs []FieldProtocol // 字段定义Rows [][]string // 数据行
}// GetBytes 将响应转换为二进制数据
func (mr *MysqlResponse) GetBytes() []byte {// 构建响应头部var buf bytes.Bufferbuf.WriteByte(0x01) // 包头// 写入字段定义for _, field := range mr.Fs {buf.Write(field.GetBytes())}// 写入EOF包buf.Write(getEOFPacket())// 写入数据行for _, row := range mr.Rows {buf.Write(getRowPacket(row))}// 写入最终的EOF包buf.Write(getEOFPacket())return buf.Bytes()
}
这个结构严格遵循MySQL的文本协议格式,包含:
- 字段定义:描述返回数据的结构
- 数据行:实际的查询结果
- 特殊标记:如EOF包,用于分隔不同的协议阶段
3. 连接处理与认证流程
服务使用Go的goroutine为每个客户端连接创建独立的处理线程:
// Start 启动连接处理
func (c *ConnHandle) Start() {defer func() {c.Conn.Close()wg.Done()}()// 发送握手包c.WriteOnePack()// 读取并解析客户端响应data, err := c.ReadOnePack()if err != nil {return}// 处理认证username, password, dbname, err := parseClientHandshakePacket(data)if err != nil || !isPassScrambleMysqlNativePassword(password, c.Salt) {c.writeErrorPacket(1045, "28000", "用户名或密码错误")return}// 发送认证成功响应c.writeOKPacket()// 主命令处理循环for !exit {err = c.handleNextCommand()if err != nil {break}}
}
这个连接处理流程包括:
- 握手初始化:服务器发送握手包给客户端
- 身份验证:验证用户名密码(本实现中简化了验证逻辑)
- 命令循环:持续接收并处理客户端命令
- 资源清理:连接结束时释放资源
4. 配置文件读取
服务通过readConfig
函数读取配置文件:
// readConfig 读取配置文件
func readConfig(configPath string) (string, string, error) {pz, err := config.ReadDefault(configPath)if err != nil {Loger.Error(err)return err}if dbRoot, err = pz.String("mysqld", "datadir"); err != nil {Loger.Error(err)return err}return nil
}
这个函数负责从配置文件中读取服务配置,目前只有数据目录路径
配置文件如下:
my.ini
[mysqld]
datadir=./data
程序源码如下:
package mainimport ("bytes""crypto/md5""crypto/rand""crypto/sha1""encoding/binary""encoding/hex""errors""flag""fmt""io""net""os""path/filepath""strconv""strings""sync""time""gitcode.com/jjgtmgx/mgxlog""github.com/larspensjo/config"
)// 最大数据包大小
const (MaxPacketSize = (1 << 24) - 1 // 服务器支持的最大数据包大小ProtocolVersion = 10 // MySQL协议版本,固定为10
)// 认证方式常量
const (// MysqlNativePassword 身份验证方式MysqlNativePassword = "mysql_native_password"
)// 能力标志常量
const (// CapabilityClientFoundRows 返回找到的行数而不是受影响的行数CapabilityClientFoundRows = 1 << 1// CapabilityClientConnectWithDB 可以在连接时指定数据库CapabilityClientConnectWithDB = 1 << 3// CapabilityClientProtocol41 新的4.1协议,必须支持CapabilityClientProtocol41 = 1 << 9// CapabilityClientSecureConnection 新的4.1身份验证方式CapabilityClientSecureConnection = 1 << 15// CapabilityClientMultiStatements 支持在COM_QUERY和COM_STMT_PREPARE中处理多个语句CapabilityClientMultiStatements = 1 << 16// CapabilityClientPluginAuth 客户端支持插件身份验证CapabilityClientPluginAuth = 1 << 19// CapabilityClientConnAttr 允许在Protocol::HandshakeResponse41中使用连接属性CapabilityClientConnAttr = 1 << 20// CapabilityClientDeprecateEOF 期望在文本结果集的行之后使用OK(而不是EOF)CapabilityClientDeprecateEOF = 1 << 24
)// 数据包类型常量
const (// ComQuit 客户端请求关闭连接ComQuit = 0x01// ComInitDB 客户端请求切换数据库ComInitDB = 0x02// ComQuery 客户端发送SQL查询ComQuery = 0x03// ComPing 客户端发送ping请求ComPing = 0x0e// ComSetOption 客户端设置选项ComSetOption = 0x1b// OKPacket OK数据包的头部标识OKPacket = 0x00// EOFPacket EOF数据包的头部标识EOFPacket = 0xfe// ErrPacket 错误数据包的头部标识ErrPacket = 0xff
)// 日志记录
var Loger, _ = mgxlog.NewMgxLog("runlog/", 10*1024*1024, 100, 3, 1000)var exit bool = false
var wg sync.WaitGroup
var tidchan = make(chan uint32)var dbRoot stringfunc main() {defer Loger.Flush()addr := flag.String("addr", ":3307", "http service address")config := flag.String("config", "./my.ini", "configuration file path")flag.Parse()// 读取 my.ini 获取 datadirif err := readConfig(*config); err != nil {return}go CreateTid()wg.Add(1)go StartServer(*addr)for {var cmd stringfmt.Scanf("%s", &cmd)if cmd == "exit" {exit = truebreak}fmt.Println("未知命令")fmt.Println("exit 退出程序")}wg.Wait()
}// readConfig 读取配置文件
func readConfig(configPath string) error {pz, err := config.ReadDefault(configPath)if err != nil {Loger.Error(err)return err}if dbRoot, err = pz.String("mysqld", "datadir"); err != nil {Loger.Error(err)return err}return nil
}// StartServer 启动服务器
func StartServer(addr string) {defer wg.Done()var netListen net.Listenerfor !exit {if netListen == nil {var err errorif netListen, err = net.Listen("tcp", addr); err != nil {Loger.Error(err)} else {go func() {for !exit {conn, err := netListen.Accept()if err != nil {continue}ch := ConnHandle{Conn: conn}go ch.Start()}}()}} else {time.Sleep(2 * time.Second)}}if netListen != nil {netListen.Close()}
}// CreateTid 生成事务ID
func CreateTid() {tid := uint32(1)for {tidchan <- tidtid++if tid > 999999999 {tid = 1}}
}var ServerVersion = "5.5.15"type ConnHandle struct {Conn net.Connsequence uint8Capabilities uint32SchemaName stringCharacterSet uint8User string
}// Start 启动连接处理
func (ch *ConnHandle) Start() {defer ch.Conn.Close()salt, err := ch.WriteOnePack()if err != nil {Loger.Error(ch.Conn.RemoteAddr().String(), err)return}b, err := ch.ReadOnePack()if err != nil {Loger.Error(ch.Conn.RemoteAddr().String(), err)return}user, _, authResponse, err := ch.parseClientHandshakePacket(true, b)if err != nil {Loger.Errorf("无法解析来自 %s 的客户端握手响应: %v", ch.Conn, err)return}if !isPassScrambleMysqlNativePassword(authResponse, salt, "root") {ch.writeErrorPacket(1045, "00001", "用户名或密码错误")return}ch.User = userif err := ch.writeOKPacket(0, 0, 0, 0); err != nil {Loger.Errorf("无法向 %s 写入OK包: %v", ch.Conn, err)return}for !exit {err := ch.handleNextCommand()if err != nil {Loger.Error(err)return}}
}// handleNextCommand 处理下一个命令
func (c *ConnHandle) handleNextCommand() error {c.sequence = 0data, err := c.readEphemeralPacket()if err != nil {return err}switch data[0] {case ComQuit:return errors.New("ComQuit")case ComInitDB:dbName := string(data[1:])if _, err := os.Stat(filepath.Join(dbRoot, dbName)); os.IsNotExist(err) {c.writeErrorPacket(1049, "42000", fmt.Sprintf("未知数据库 '%s'", dbName))return nil}c.SchemaName = dbNamec.writeOKPacket(0, 0, 0, 0)case ComQuery:sql := string(data[1:])sql = collapseSpaces(sql)sql = strings.TrimRight(sql, ";")sql = strings.TrimSpace(sql)sql = strings.Join(strings.Fields(sql), " ")// 处理自定义SQL命令: get mb xxxif strings.HasPrefix(strings.ToUpper(sql), "GET MB") {parts := strings.Fields(strings.ToUpper(sql))var input stringif len(parts) > 2 {input = strings.Join(parts[2:], " ")} else {input = "default"}// 返回输入的字符串mbResult := inputrows := [][]string{{"result", mbResult},}mr := MysqlResponse{Fs: []FieldProtocol{{Catalog: "def",Database: "",Table: "",OriginalTable: "",Name: "key",OriginalName: "key",Charset: 33,Length: 50,Type: 253,Flags: 1,Decimals: 0,},{Catalog: "def",Database: "",Table: "",OriginalTable: "",Name: "value",OriginalName: "value",Charset: 33,Length: 50,Type: 253,Flags: 1,Decimals: 0,},},Rows: rows,}c.Conn.Write(mr.GetBytes())} else if strings.HasPrefix(strings.ToUpper(sql), "GET MD5") {parts := strings.Fields(strings.ToUpper(sql))var input stringif len(parts) > 2 {input = strings.Join(parts[2:], " ")} else {input = "default"}// 计算MD5值hash := md5.Sum([]byte(input))md5Result := hex.EncodeToString(hash[:])rows := [][]string{{"result", md5Result},}mr := MysqlResponse{Fs: []FieldProtocol{{Catalog: "def",Database: "",Table: "",OriginalTable: "",Name: "key",OriginalName: "key",Charset: 33,Length: 50,Type: 253,Flags: 1,Decimals: 0,},{Catalog: "def",Database: "",Table: "",OriginalTable: "",Name: "value",OriginalName: "value",Charset: 33,Length: 50,Type: 253,Flags: 1,Decimals: 0,},},Rows: rows,}c.Conn.Write(mr.GetBytes())} else if strings.HasPrefix(strings.ToUpper(sql), "SELECT COLLATIONS") {rows := [][]string{}for i := 0; i < 1000; i++ {rows = append(rows, []string{strconv.Itoa(i)})}mr := MysqlResponse{Fs: []FieldProtocol{{Catalog: "def",Database: "information_schema",Table: "COLLATIONS",OriginalTable: "COLLATIONS",Name: "result",OriginalName: "result",Charset: 33,Length: 96,Type: 253,Flags: 1,Decimals: 0,},},Rows: rows,}c.Conn.Write(mr.GetBytes())} else if strings.HasPrefix(strings.ToUpper(sql), "SHOW DATABASES") {dbs, err := os.ReadDir(dbRoot)if err != nil {c.writeErrorPacket(1049, "HY000", "无法读取数据库目录")return nil}var rows [][]stringfor _, entry := range dbs {if entry.IsDir() {rows = append(rows, []string{entry.Name()})}}mr := MysqlResponse{Fs: []FieldProtocol{{Catalog: "def",Database: "information_schema",Table: "SCHEMATA",OriginalTable: "SCHEMATA",Name: "Database",OriginalName: "SCHEMA_NAME",Charset: 33,Length: 192,Type: 253,Flags: 1,Decimals: 0,},},Rows: rows,}c.Conn.Write(mr.GetBytes())} else if strings.HasPrefix(sql, "SELECT @@character_set_database, @@collation_database") {rows := [][]string{{"utf8mb4", "utf8mb4_general_ci"},}mr := MysqlResponse{Fs: []FieldProtocol{{Catalog: "def",Database: "information_schema",Table: "SCHEMATA",OriginalTable: "SCHEMATA",Name: "@@character_set_database",OriginalName: "@@character_set_database",Charset: 33,Length: 192,Type: 253,Flags: 1,Decimals: 0,},{Catalog: "def",Database: "information_schema",Table: "SCHEMATA",OriginalTable: "SCHEMATA",Name: "@@collation_database",OriginalName: "@@collation_database",Charset: 33,Length: 192,Type: 253,Flags: 1,Decimals: 0,},},Rows: rows,}c.Conn.Write(mr.GetBytes())} else if strings.HasPrefix(sql, "SHOW FULL TABLES") {if c.SchemaName == "" {c.writeErrorPacket(1046, "3D000", "未选择数据库")return nil}dbPath := filepath.Join(dbRoot, c.SchemaName)entries, err := os.ReadDir(dbPath)if err != nil {c.writeErrorPacket(1049, "42000", fmt.Sprintf("Unknown database '%s'", c.SchemaName))return nil}var rows [][]stringfor _, entry := range entries {if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".mgx") {tableName := strings.TrimSuffix(entry.Name(), ".mgx")rows = append(rows, []string{tableName, "BASE TABLE"})}}mr := MysqlResponse{Fs: []FieldProtocol{{Catalog: "def",Database: "information_schema",Table: "SCHEMATA",OriginalTable: "SCHEMATA",Name: "Tables_in_" + c.SchemaName,OriginalName: "Tables_in_" + c.SchemaName,Charset: 33,Length: 192,Type: 253,Flags: 1,Decimals: 0,},{Catalog: "def",Database: "information_schema",Table: "SCHEMATA",OriginalTable: "SCHEMATA",Name: "Table_type",OriginalName: "Table_type",Charset: 33,Length: 192,Type: 253,Flags: 1,Decimals: 0,},},Rows: rows,}c.Conn.Write(mr.GetBytes())} else {c.writeOKPacket(0, 0, 0, 0)}case ComPing:c.writeOKPacket(0, 0, 0, 0)case ComSetOption:default:}return nil
}// collapseSpaces 压缩字符串中的空格
func collapseSpaces(input string) string {var buf bytes.Bufferreader := strings.NewReader(input)prevWasSpace := falsefor {r, _, err := reader.ReadRune()if err != nil {break}if r == ' ' {if !prevWasSpace {buf.WriteRune(r)}prevWasSpace = true} else {buf.WriteRune(r)prevWasSpace = false}}return buf.String()
}// writeOKPacket 写入OK数据包
func (ch *ConnHandle) writeOKPacket(affectedRows, lastInsertID uint64, flags uint16, warnings uint16) error {buf := bytes.NewBuffer([]byte{})buf.WriteByte(OKPacket)binary.Write(buf, binary.LittleEndian, affectedRows)binary.Write(buf, binary.LittleEndian, lastInsertID)binary.Write(buf, binary.LittleEndian, flags)binary.Write(buf, binary.LittleEndian, warnings)return ch.writePacket(buf.Bytes())
}// writeErrorPacket 写入错误数据包
func (ch *ConnHandle) writeErrorPacket(errorCode uint16, sqlState, errMessage string) error {buf := bytes.NewBuffer([]byte{})buf.WriteByte(ErrPacket)binary.Write(buf, binary.LittleEndian, errorCode)buf.WriteByte('#')buf.WriteString(sqlState)buf.WriteString(errMessage)return ch.writePacket(buf.Bytes())
}// parseClientHandshakePacket 解析客户端握手数据包
func (ch *ConnHandle) parseClientHandshakePacket(firstTime bool, data []byte) (string, string, []byte, error) {pos := 0clientFlags, pos, ok := readUint32(data, pos)if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取客户端标志")}if clientFlags&CapabilityClientProtocol41 == 0 {return "", "", nil, errors.New("parseClientHandshakePacket: 仅支持协议4.1")}if firstTime {ch.Capabilities = clientFlags & (CapabilityClientDeprecateEOF | CapabilityClientFoundRows)}if clientFlags&CapabilityClientMultiStatements > 0 {ch.Capabilities |= CapabilityClientMultiStatements}_, pos, ok = readUint32(data, pos)if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取maxPacketSize")}characterSet, pos, ok := readByte(data, pos)if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取characterSet")}ch.CharacterSet = characterSetpos += 23username, pos, ok := readNullString(data, pos)if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取username")}var authResponse []byte// 只处理安全连接方式的身份验证响应if clientFlags&CapabilityClientSecureConnection != 0 {var l bytel, pos, ok = readByte(data, pos)if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response长度")}authResponse, pos, ok = readBytesCopy(data, pos, int(l))if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response")}} else {a := ""a, pos, ok = readNullString(data, pos)if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response")}authResponse = []byte(a)}if clientFlags&CapabilityClientConnectWithDB != 0 {dbname := ""dbname, pos, ok = readNullString(data, pos)if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取dbname")}ch.SchemaName = dbname}authMethod := MysqlNativePasswordif clientFlags&CapabilityClientPluginAuth != 0 {authMethod, pos, ok = readNullString(data, pos)if !ok {return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取authMethod")}}if authMethod == "" {authMethod = MysqlNativePassword}if clientFlags&CapabilityClientConnAttr != 0 {if _, _, err := parseConnAttrs(data, pos); err != nil {Loger.Error("解码客户端发送的连接属性: ", err)}}return username, authMethod, authResponse, nil
}// ReadOnePack 读取一个数据包
func (ch *ConnHandle) ReadOnePack() ([]byte, error) {var r io.Reader = ch.Conn.(io.Reader)length, err := ch.readHeaderFrom(r)if err != nil {return nil, err}if length < MaxPacketSize {buf := make([]byte, length)if _, err := io.ReadFull(r, buf); err != nil {return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")}return buf, nil}return nil, errors.New("readEphemeralPacketDirect doesn't support more than one packet")
}// WriteOnePack 写入一个数据包
func (ch *ConnHandle) WriteOnePack() ([]byte, error) {salt, _ := NewSalt()p01 := Packet01{Ver: 10,VerSion: ServerVersion,ServerId: <-tidchan,Salt: salt,SerFlag1: []byte{255, 247},Bm: 28,SerType: 0,SerFlag2: []byte{15, 128},}err := ch.writePacket(p01.GetBytes())return salt, err
}// writePacket 写入数据包
func (ch *ConnHandle) writePacket(data []byte) error {index := 0length := len(data)w := ch.Conn.(*net.TCPConn)for {packetLength := lengthif packetLength > MaxPacketSize {packetLength = MaxPacketSize}var header [4]byteheader[0] = byte(packetLength)header[1] = byte(packetLength >> 8)header[2] = byte(packetLength >> 16)header[3] = ch.sequenceif n, err := w.Write(header[:]); err != nil {return errors.New("Write(header) failed")} else if n != 4 {return errors.New("Write(header) returned a short write: < 4")}if n, err := w.Write(data[index : index+packetLength]); err != nil {return errors.New("Write(packet) failed")} else if n != packetLength {return errors.New("Write(packet) returned a short write")}ch.sequence++length -= packetLengthif length == 0 {if packetLength == MaxPacketSize {header[0] = 0header[1] = 0header[2] = 0header[3] = ch.sequenceif n, err := w.Write(header[:]); err != nil {return errors.New("Write(empty header) failed")} else if n != 4 {return errors.New("Write(empty header) returned a short write")}ch.sequence++}return nil}index += packetLength}
}// readHeaderFrom 从读取器中读取头部
func (ch *ConnHandle) readHeaderFrom(r io.Reader) (int, error) {var header [4]byteif _, err := io.ReadFull(r, header[:]); err != nil {if err == io.EOF {return 0, err}if strings.HasSuffix(err.Error(), "read: connection reset by peer") {return 0, io.EOF}return 0, errors.New("io.ReadFull(header size) failed")}sequence := uint8(header[3])if sequence != ch.sequence {return 0, errors.New("invalid sequence: expected " + strconv.Itoa(int(ch.sequence)) + ", got " + strconv.Itoa(int(sequence)))}ch.sequence++return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16), nil
}// readEphemeralPacket 读取临时数据包
func (c *ConnHandle) readEphemeralPacket() ([]byte, error) {var r io.Reader = c.Conn.(io.Reader)length, err := c.readHeaderFrom(r)if err != nil {return nil, err}if length == 0 {return nil, nil}data := make([]byte, length)if _, err := io.ReadFull(r, data); err != nil {return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")}if length < MaxPacketSize {return data, nil}for {next, err := c.readOnePacket()if err != nil {return nil, err}if len(next) == 0 {break}data = append(data, next...)if len(next) < MaxPacketSize {break}}return data, nil
}// readOnePacket 读取一个数据包
func (ch *ConnHandle) readOnePacket() ([]byte, error) {var r io.Reader = ch.Conn.(io.Reader)length, err := ch.readHeaderFrom(r)if err != nil {return nil, err}if length == 0 {return nil, nil}data := make([]byte, length)if _, err := io.ReadFull(r, data); err != nil {return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")}return data, nil
}// parseConnAttrs 解析连接属性
func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {var attrLen uint64attrLen, pos, ok := readLenEncInt(data, pos)if !ok {return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性变量长度")}var attrLenRead uint64attrs := make(map[string]string)for attrLenRead < attrLen {var keyLen bytekeyLen, pos, ok = readByte(data, pos)if !ok {return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性键长度")}attrLenRead += uint64(keyLen) + 1var connAttrKey []byteconnAttrKey, pos, ok = readBytesCopy(data, pos, int(keyLen))if !ok {return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性键")}var valLen bytevalLen, pos, ok = readByte(data, pos)if !ok {return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性值长度")}attrLenRead += uint64(valLen) + 1var connAttrVal []byteconnAttrVal, pos, ok = readBytesCopy(data, pos, int(valLen))if !ok {return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性值")}attrs[string(connAttrKey[:])] = string(connAttrVal[:])}return attrs, pos, nil
}// isPassScrambleMysqlNativePassword 验证密码是否匹配
func isPassScrambleMysqlNativePassword(reply, salt []byte, mysqlNativePassword string) bool {if len(reply) == 0 {return false}if mysqlNativePassword == "" {return false}mysqlNativePassword = NativePassword(mysqlNativePassword)if strings.Contains(mysqlNativePassword, "*") {mysqlNativePassword = mysqlNativePassword[1:]}hash, err := hex.DecodeString(mysqlNativePassword)if err != nil {return false}crypt := sha1.New()crypt.Write(salt)crypt.Write(hash)scramble := crypt.Sum(nil)for i := range scramble {scramble[i] ^= reply[i]}hashStage1 := scramblecrypt.Reset()crypt.Write(hashStage1)candidateHash2 := crypt.Sum(nil)return bytes.Equal(candidateHash2, hash)
}// NativePassword 生成原生密码格式
func NativePassword(password string) string {if len(password) == 0 {return ""}hash := sha1.New()hash.Write([]byte(password))s1 := hash.Sum(nil)hash.Reset()hash.Write(s1)s2 := hash.Sum(nil)s := strings.ToUpper(hex.EncodeToString(s2))return fmt.Sprintf("*%s", s)
}// NewSalt 生成新的盐值
func NewSalt() ([]byte, error) {salt := make([]byte, 20)if _, err := rand.Read(salt); err != nil {return nil, err}for i := 0; i < len(salt); i++ {salt[i] &= 0x7fif salt[i] == '\x00' || salt[i] == '$' {salt[i]++}}return salt, nil
}// Packet01 数据包结构type Packet01 struct {Ver byteVerSion stringServerId uint32Salt []byteSerFlag1 []byteBm int8SerType int16SerFlag2 []byte
}// GetBytes 获取数据包的字节表示
func (p01 *Packet01) GetBytes() []byte {b := bytes.NewBuffer([]byte{})b.WriteByte(p01.Ver)b.WriteString(p01.VerSion)b.WriteByte(byte(0))binary.Write(b, binary.BigEndian, p01.ServerId)b.Write(p01.Salt[:8])b.WriteByte(byte(0))binary.Write(b, binary.BigEndian, p01.SerFlag1)b.WriteByte(byte(p01.Bm))binary.Write(b, binary.BigEndian, p01.SerType)binary.Write(b, binary.BigEndian, p01.SerFlag2)b.WriteByte(byte(21))b.Write(bytes.Repeat([]byte{0}, 10))b.Write(p01.Salt[8:])b.WriteByte(byte(0))b.WriteString(MysqlNativePassword)b.WriteByte(byte(0))return b.Bytes()
}// MysqlResponse MySQL响应结构
type MysqlResponse struct {Fs []FieldProtocolRows [][]string
}// FieldProtocol 字段协议结构
type FieldProtocol struct {Catalog stringDatabase stringTable stringOriginalTable stringName stringOriginalName stringCharset intLength intType intFlags intDecimals int
}// GetBytes 获取字段协议的字节表示
func (fp *FieldProtocol) GetBytes(pk int) []byte {b := bytes.NewBuffer([]byte{0, 0, 0, 0})lt := EncodeLength(len(fp.Catalog))b.Write(lt)b.WriteString(fp.Catalog)lt = EncodeLength(len(fp.Database))b.Write(lt)b.WriteString(fp.Database)lt = EncodeLength(len(fp.Table))b.Write(lt)b.WriteString(fp.Table)lt = EncodeLength(len(fp.OriginalTable))b.Write(lt)b.WriteString(fp.OriginalTable)lt = EncodeLength(len(fp.Name))b.Write(lt)b.WriteString(fp.Name)lt = EncodeLength(len(fp.OriginalName))b.Write(lt)b.WriteString(fp.OriginalName)b.WriteByte(0x0c)b.WriteByte(byte(fp.Charset))b.WriteByte(0x00)b.Write(intToBytesLittleEndian(fp.Length))b.WriteByte(byte(fp.Type))lt = intToBytesLittleEndian(fp.Flags)b.WriteByte(lt[0])b.WriteByte(lt[1])b.WriteByte(byte(fp.Decimals))b.WriteByte(0x00)b.WriteByte(0x00)bs := b.Bytes()lb := intToBytesLittleEndian(len(bs) - 4)bs[0] = byte(lb[0])bs[1] = byte(lb[1])bs[2] = byte(lb[2])bs[3] = byte(pk)return bs
}// GetBytes 获取MySQL响应的字节表示
func (mr *MysqlResponse) GetBytes() []byte {pk := 1b := bytes.NewBuffer([]byte{})numberFields := EncodeLength(len(mr.Fs))numberPackets := intToBytesLittleEndian(len(numberFields))numberPackets[3] = byte(pk)b.Write(numberPackets)b.Write(numberFields)for _, f := range mr.Fs {pk++if pk > 255 {pk = 0}b.Write(f.GetBytes(pk))}pk++if pk > 255 {pk = 0}b.Write(getEOFPacket(pk))for _, r := range mr.Rows {pk++if pk > 255 {pk = 0}b.Write(getRowPacket(r, pk))}pk++if pk > 255 {pk = 0}b.Write(getEOFPacket(pk))return b.Bytes()
}// getRowPacket 获取行数据包
func getRowPacket(strs []string, pk int) []byte {b := bytes.NewBuffer([]byte{0, 0, 0, 0})for _, s := range strs {lt := EncodeLength(len(s))b.Write(lt)b.WriteString(s)}bs := b.Bytes()lb := intToBytesLittleEndian(len(bs) - 4)bs[0] = byte(lb[0])bs[1] = byte(lb[1])bs[2] = byte(lb[2])bs[3] = byte(pk)return bs
}// getEOFPacket 获取EOF数据包
func getEOFPacket(pk int) []byte {return []byte{0x05, 0x00, 0x00, byte(pk), 0xFE, 0x00, 0x00, 0x22, 0x00}
}// EncodeLength 编码长度
func EncodeLength(length int) []byte {var buf bytes.Bufferif length < 251 {buf.WriteByte(byte(length))} else if length <= 65535 {buf.WriteByte(0xfd)binary.Write(&buf, binary.LittleEndian, uint16(length))} else {buf.WriteByte(0xfe)binary.Write(&buf, binary.LittleEndian, uint64(length))}return buf.Bytes()
}// intToBytesLittleEndian 将整数转换为小端字节序
func intToBytesLittleEndian(num int) []byte {var buf bytes.Bufferbinary.Write(&buf, binary.LittleEndian, uint32(num))return buf.Bytes()
}// readByte 读取一个字节
func readByte(data []byte, pos int) (byte, int, bool) {if pos >= len(data) {return 0, 0, false}return data[pos], pos + 1, true
}// readBytesCopy 读取并复制字节
func readBytesCopy(data []byte, pos int, size int) ([]byte, int, bool) {if pos+size > len(data) {return nil, 0, false}result := make([]byte, size)copy(result, data[pos:pos+size])return result, pos + size, true
}// readNullString 读取以null结尾的字符串
func readNullString(data []byte, pos int) (string, int, bool) {end := bytes.IndexByte(data[pos:], 0)if end == -1 {return "", 0, false}return string(data[pos : pos+end]), pos + end + 1, true
}// readUint32 读取uint32
func readUint32(data []byte, pos int) (uint32, int, bool) {if pos+4 > len(data) {return 0, 0, false}return binary.LittleEndian.Uint32(data[pos : pos+4]), pos + 4, true
}// readLenEncInt 读取长度编码的整数
func readLenEncInt(data []byte, pos int) (uint64, int, bool) {if pos >= len(data) {return 0, 0, false}switch data[pos] {case 0xfc:if pos+3 > len(data) {return 0, 0, false}return uint64(data[pos+1]) | uint64(data[pos+2])<<8, pos + 3, truecase 0xfd:if pos+4 > len(data) {return 0, 0, false}return uint64(data[pos+1]) |uint64(data[pos+2])<<8 |uint64(data[pos+3])<<16, pos + 4, truecase 0xfe:if pos+9 > len(data) {return 0, 0, false}return uint64(data[pos+1]) |uint64(data[pos+2])<<8 |uint64(data[pos+3])<<16 |uint64(data[pos+4])<<24 |uint64(data[pos+5])<<32 |uint64(data[pos+6])<<40 |uint64(data[pos+7])<<48 |uint64(data[pos+8])<<56, pos + 9, true}return uint64(data[pos]), pos + 1, true
}
写在最后:当程序员被逼急了…
这个项目让我深刻体会到:程序员的创造力往往来自于解决同事的"奇怪需求"。谁能想到,一个易语言的线程安全问题,最终会催生一个MySQL协议服务器?
所以,下次如果你的同事提出一个看似不合理的需求,不妨换个角度想想——也许这正是你提升技能、探索未知领域的好机会!
最后,如果你也有类似的"被逼无奈"的开发经历,欢迎在评论区分享,让我们一起吐槽,不对,让我们一起学习!
P.S. 同事的易语言项目现在运行得很稳定,他对我说:“还有很多地方可以这样优化,帮我再写几个接口吧!” 我沉默不语…