[手写系列]Go手写db — — 第六版(实现表连接)
[手写系列]Go手写db — — 第六版(实现表连接)
第一版文章:[手写系列]Go手写db — — 完整教程_go手写数据库-CSDN博客
第二版文章:[手写系列]Go手写db — — 第二版-CSDN博客
第三版文章:[手写系列]Go手写db — — 第三版(实现分组、排序、聚合函数等)-CSDN博客
第四版文章:[手写系列]Go手写db — — 第四版(实现事务、网络模块)
第五版文章:[手写系列]Go手写db — — 第五版(实现数据库操作模块)
整体项目Github地址:https://github.com/ziyifast/ZiyiDB
- 请大家多多支持,也欢迎大家star⭐️和共同维护这个项目~
本文主要介绍如何在 ZiyiDB 第五版的基础上,实现表连接(JOIN)功能,包括INNER JOIN、LEFT JOIN、RIGHT JOIN等,通过这些功能,将使得ZiyiDB成为一个更完整的数据库系统。
一、功能列表
主要实现两个表之间的连接操作。
- 新增对INNER JOIN的支持
- 新增对LEFT JOIN的支持
- 新增对RIGHT JOIN的支持
- 支持带WHERE条件的JOIN查询
二、实现细节
功能点一:实现join操作
实现思路
- internal/lexer/token.go新增表关联标识符
- internal/lexer/lexer.go词法分析器
- lookupIdentifier方法新增case,将用户传入的SQL转换为TokenType
- NextToken方法新增case,用于识别表名.字段名
- readNumber方法也新增判断,用于辅助对表名.字段名的支持
- internal/ast/ast.go抽象语法树新增JoinClause类型,同时SelectStatement结构体新增JoinClause字段
- internal/parser/parser.go语法解析器
- parseSelectStatement方法新增对Join子句的判断
- 新增parseJoinClause方法,解析Join子句,并将其转换为抽象语法树结构
- internal/storage/memory.go底层存储引擎
- 新增selectWithJoin方法
- Select方法里新增判断,如果语法树包含Join部分,那么就走selectWithJoin逻辑
- getColumnValue中新增对表名.列名的处理
代码实现
1. 词法分析器调整
- internal/lexer/token.go新增对应关键字
INNER TokenType = "INNER"
LEFT TokenType = "LEFT"
RIGHT TokenType = "RIGHT"
JOIN TokenType = "JOIN"
ON TokenType = "ON"
DOT TokenType = "." // dot operator, used for table.column
- internal/lexer/lexer.go新增对应case以及优化readNumber函数以支持对表名.字段名的读取
// lookupIdentifier 查找标识符类型
// 将标识符转换为对应的标记类型
// 识别 SQL 关键字
func (l *Lexer) lookupIdentifier(ident string) TokenType {switch strings.ToUpper(ident) {...case "INNER":return INNERcase "LEFT":return LEFTcase "RIGHT":return RIGHTcase "JOIN":return JOINcase "ON":return ONdefault:return IDENT}
}// NextToken 获取下一个词法单元
func (l *Lexer) NextToken() Token {var tok Token// 跳过空白字符l.skipWhitespace()// 检查是否为注释if l.ch == '-' && l.peekChar() == '-' {return l.readComment()}switch l.ch {case '=', '>', '<', '!':// 处理操作符tok = l.readOperator()return tokcase ',':tok = Token{Type: COMMA, Literal: ","}case ';':tok = Token{Type: SEMI, Literal: ";"}case '(':tok = Token{Type: LPAREN, Literal: "("}case ')':tok = Token{Type: RPAREN, Literal: ")"}case '*':tok = Token{Type: ASTERISK, Literal: "*"}case '.':// 检查是否是数字的一部分(小数点)if isDigit(l.peekChar()) {// 以点开头的小数,如 .123num := l.readNumber()if strings.Contains(num, ".") {tok.Type = FLOAT} else {tok.Type = INT}tok.Literal = numreturn tok} else {// 单独的点操作符,用于表名.列名tok = Token{Type: DOT, Literal: "."}}...l.readChar()return tok
}// 读取数字:支持对浮点数的读取
// readNumber consumes a numeric literal starting at the current character and
// returns the consumed text.
//
// Digits and at most one '.' are accepted; scanning stops at the first other
// character, at a second '.', or when l.ch is 0. If the current character is
// a '.' that is not followed by a digit, nothing is consumed and "" is
// returned, because that dot is the table.column operator rather than part of
// a number.
func (l *Lexer) readNumber() string {
	// A lone dot is an operator, not the start of a fractional literal.
	if l.ch == '.' && !isDigit(l.peekChar()) {
		return ""
	}

	var literal bytes.Buffer
	seenDot := false
	for l.ch != 0 {
		switch {
		case isDigit(l.ch):
			// Plain digit: always part of the literal.
		case l.ch == '.' && !seenDot:
			// Only the first decimal point belongs to the literal.
			seenDot = true
		default:
			return literal.String()
		}
		literal.WriteRune(l.ch)
		l.readChar()
	}
	return literal.String()
}
...
2. 抽象语法树调整
internal/ast/ast.go:
// JoinClause represents the JOIN clause of a SELECT statement.
type JoinClause struct {
	Token     lexer.Token
	JoinType  string     // "INNER", "LEFT", "RIGHT"
	TableName string     // name of the right-hand table
	On        Expression // ON condition
}

// expressionNode is an empty marker method.
// NOTE(review): presumably present so *JoinClause satisfies Expression — confirm.
func (jc *JoinClause) expressionNode() {}

// TokenLiteral returns the literal text of the clause's token.
func (jc *JoinClause) TokenLiteral() string { return jc.Token.Literal }

// SelectStatement represents a SELECT statement.
// It represents a SELECT query and holds the selected fields, the table name
// and the WHERE condition.
type SelectStatement struct {
	Token     lexer.Token
	Fields    []Expression
	TableName string
	Join      *JoinClause // JOIN clause, added to support joined-table queries
	Where     Expression
	GroupBy   []Expression    // GROUP BY expressions, added to support grouping
	OrderBy   []OrderByClause // ORDER BY clauses, added to support sorting
}
3. 语法解析器调整
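internal/parser/parser.go:
(注:下方贴出的代码与上一节 internal/ast/ast.go 中的定义相同;parseSelectStatement 对 Join 子句的判断以及 parseJoinClause 的实现此处未贴出,请参考项目仓库中的 internal/parser/parser.go。)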
// JoinClause represents the JOIN clause of a SELECT statement.
type JoinClause struct {
	Token     lexer.Token
	JoinType  string     // "INNER", "LEFT", "RIGHT"
	TableName string     // name of the right-hand table
	On        Expression // ON condition
}

// expressionNode is an empty marker method.
// NOTE(review): presumably present so *JoinClause satisfies Expression — confirm.
func (jc *JoinClause) expressionNode() {}

// TokenLiteral returns the literal text of the clause's token.
func (jc *JoinClause) TokenLiteral() string { return jc.Token.Literal }

// SelectStatement represents a SELECT statement.
// It represents a SELECT query and holds the selected fields, the table name
// and the WHERE condition.
type SelectStatement struct {
	Token     lexer.Token
	Fields    []Expression
	TableName string
	Join      *JoinClause // JOIN clause, added to support joined-table queries
	Where     Expression
	GroupBy   []Expression    // GROUP BY expressions, added to support grouping
	OrderBy   []OrderByClause // ORDER BY clauses, added to support sorting
}
4. 存储引擎调整
internal/storage/memory.go:
// Select 查询数据,支持事务
func (b *MemoryBackend) Select(databaseName string, stmt *ast.SelectStatement, txn *Transaction) (*Results, error) {b.Mu.RLock()db, dbExists := b.Databases[databaseName]b.Mu.RUnlock()if !dbExists {return nil, fmt.Errorf("database '%s' does not exist", databaseName)}db.mu.RLock()table, tableExists := db.Tables[stmt.TableName]db.mu.RUnlock()if !tableExists {return nil, fmt.Errorf("table '%s' doesn't exist in database '%s'", stmt.TableName, databaseName)}// 判断是否有表连接操作if stmt.Join != nil {return b.selectWithJoin(databaseName, stmt, txn)}results := &Results{Columns: make([]ResultColumn, 0),Rows: make([][]Cell, 0),}...return results, nil
}// selectWithJoin 处理JOIN查询
func (b *MemoryBackend) selectWithJoin(databaseName string, stmt *ast.SelectStatement, txn *Transaction) (*Results, error) {b.Mu.RLock()db, dbExists := b.Databases[databaseName]b.Mu.RUnlock()if !dbExists {return nil, fmt.Errorf("database '%s' does not exist", databaseName)}// 获取左表db.mu.RLock()leftTable, leftTableExists := db.Tables[stmt.TableName]db.mu.RUnlock()if !leftTableExists {return nil, fmt.Errorf("table '%s' doesn't exist in database '%s'", stmt.TableName, databaseName)}// 获取右表db.mu.RLock()rightTable, rightTableExists := db.Tables[stmt.Join.TableName]db.mu.RUnlock()if !rightTableExists {return nil, fmt.Errorf("table '%s' doesn't exist in database '%s'", stmt.Join.TableName, databaseName)}// 构建结果列results := &Results{Columns: make([]ResultColumn, 0),Rows: make([][]Cell, 0),}// 创建列映射,用于后续过滤行数据columnMap := make(map[string]int) // 列名到索引的映射// 处理SELECT * 情况if len(stmt.Fields) == 1 {if _, ok := stmt.Fields[0].(*ast.StarExpression); ok {// 添加左表的所有列for i, col := range leftTable.Columns {columnName := fmt.Sprintf("%s.%s", stmt.TableName, col.Name)results.Columns = append(results.Columns, ResultColumn{Name: columnName,Type: col.Type,})columnMap[columnName] = i}// 添加右表的所有列for i, col := range rightTable.Columns {columnName := fmt.Sprintf("%s.%s", stmt.Join.TableName, col.Name)results.Columns = append(results.Columns, ResultColumn{Name: columnName,Type: col.Type,})columnMap[columnName] = len(leftTable.Columns) + i}}} else {// 处理具体列的选择for _, expr := range stmt.Fields {if identifier, ok := expr.(*ast.Identifier); ok {// 处理带表名前缀的列名 (table.column)parts := strings.Split(identifier.Value, ".")var tableName, columnName stringif len(parts) == 2 {tableName = parts[0]columnName = parts[1]} else {// 不带表名前缀的列名columnName = identifier.Value// 先在左表中查找found := falsefor _, col := range leftTable.Columns {if col.Name == columnName {tableName = stmt.TableNamefound = truebreak}}// 如果左表中没找到,在右表中查找if !found {for _, col := range rightTable.Columns {if col.Name == columnName {tableName = 
stmt.Join.TableNamefound = truebreak}}}if !found {return nil, fmt.Errorf("Unknown column '%s' in 'field list'", identifier.Value)}}// 查找对应的表和列var _ *Tablevar tableColumns []ast.ColumnDefinitionif tableName == stmt.TableName {_ = leftTabletableColumns = leftTable.Columns} else if tableName == stmt.Join.TableName {_ = rightTabletableColumns = rightTable.Columns} else {return nil, fmt.Errorf("Unknown table '%s' in field list", tableName)}found := falsefor i, col := range tableColumns {if col.Name == columnName {results.Columns = append(results.Columns, ResultColumn{Name: identifier.Value, // 保持原始名称(可能包含表前缀)Type: col.Type,})// 计算在组合行中的索引位置if tableName == stmt.TableName {columnMap[identifier.Value] = i} else {columnMap[identifier.Value] = len(leftTable.Columns) + i}found = truebreak}}if !found {return nil, fmt.Errorf("Unknown column '%s' in 'field list'", identifier.Value)}}}}// 获取左表和右表的所有行leftRows := make([][]Cell, 0)rightRows := make([][]Cell, 0)// 获取左表可见行for _, row := range leftTable.Rows {visibleRow := b.getVisibleRow(row, txn)if visibleRow != nil {leftRows = append(leftRows, visibleRow)}}// 获取右表可见行for _, row := range rightTable.Rows {visibleRow := b.getVisibleRow(row, txn)if visibleRow != nil {rightRows = append(rightRows, visibleRow)}}// 执行JOIN操作var joinedRows [][]Cellswitch stmt.Join.JoinType {case "INNER":joinedRows = b.innerJoin(leftRows, rightRows, leftTable, rightTable, stmt.Join.On, stmt.Where)case "LEFT":joinedRows = b.leftJoin(leftRows, rightRows, leftTable, rightTable, stmt.Join.On, stmt.Where)case "RIGHT":joinedRows = b.rightJoin(leftRows, rightRows, leftTable, rightTable, stmt.Join.On, stmt.Where)default:return nil, fmt.Errorf("Unsupported JOIN type: %s", stmt.Join.JoinType)}// 过滤行数据,只保留SELECT子句中指定的列_, ok := stmt.Fields[0].(*ast.StarExpression)if !(len(stmt.Fields) == 1 && ok) {// 如果不是SELECT *,则需要过滤列filteredRows := make([][]Cell, len(joinedRows))for i, row := range joinedRows {filteredRow := make([]Cell, len(results.Columns))for j, col := range 
results.Columns {if colIndex, exists := columnMap[col.Name]; exists && colIndex < len(row) {filteredRow[j] = row[colIndex]} else {// 如果找不到列,设置为默认值filteredRow[j] = Cell{Type: CellTypeText, TextValue: "NULL"}}}filteredRows[i] = filteredRow}results.Rows = filteredRows} else {// 如果是SELECT *,直接使用所有列results.Rows = joinedRows}// 处理 ORDER BYif len(stmt.OrderBy) > 0 {var err errorresults.Rows, err = b.orderBy(results.Rows, results.Columns, stmt.OrderBy, append(leftTable.Columns, rightTable.Columns...))if err != nil {return nil, err}}return results, nil
}// innerJoin 执行INNER JOIN
func (b *MemoryBackend) innerJoin(leftRows, rightRows [][]Cell, leftTable, rightTable *Table, onCondition, whereCondition ast.Expression) [][]Cell {resultRows := make([][]Cell, 0)for _, leftRow := range leftRows {for _, rightRow := range rightRows {// 构建组合行用于条件评估combinedRow := append(leftRow, rightRow...)combinedColumns := append(leftTable.Columns, rightTable.Columns...)// 判断是否满足ON条件match, err := evaluateWhereCondition(onCondition, combinedRow, combinedColumns)if err != nil {//如果不满足,则跳过continue}if match {// 如果有WHERE条件,也需满足if whereCondition != nil {whereMatch, err := evaluateWhereCondition(whereCondition, combinedRow, combinedColumns)if err != nil || !whereMatch {continue}}resultRows = append(resultRows, combinedRow)}}}return resultRows
}// leftJoin 执行LEFT JOIN
func (b *MemoryBackend) leftJoin(leftRows, rightRows [][]Cell, leftTable, rightTable *Table, onCondition, whereCondition ast.Expression) [][]Cell {resultRows := make([][]Cell, 0)for _, leftRow := range leftRows {matched := falsefor _, rightRow := range rightRows {// 构建组合行用于条件评估combinedRow := append(leftRow, rightRow...)combinedColumns := append(leftTable.Columns, rightTable.Columns...)// 评估ON条件match, err := evaluateWhereCondition(onCondition, combinedRow, combinedColumns)if err != nil {// 如果评估出错,跳过这一对行continue}if match {matched = true// 如果有WHERE条件,也需满足if whereCondition != nil {whereMatch, err := evaluateWhereCondition(whereCondition, combinedRow, combinedColumns)if err != nil || !whereMatch {continue}}resultRows = append(resultRows, combinedRow)}}// 如果没有匹配的右行,添加左行和NULL值的右行if !matched {nullRightRow := make([]Cell, len(rightTable.Columns))for i := range nullRightRow {nullRightRow[i] = Cell{Type: CellTypeText, TextValue: "NULL"}}combinedRow := append(leftRow, nullRightRow...)// LEFT JOIN中,即使ON条件不匹配,也要考虑WHERE条件if whereCondition != nil {combinedColumns := append(leftTable.Columns, rightTable.Columns...)whereMatch, err := evaluateWhereCondition(whereCondition, combinedRow, combinedColumns)if err != nil || !whereMatch {continue}}resultRows = append(resultRows, combinedRow)}}return resultRows
}// rightJoin 执行RIGHT JOIN
func (b *MemoryBackend) rightJoin(leftRows, rightRows [][]Cell, leftTable, rightTable *Table, onCondition, whereCondition ast.Expression) [][]Cell {resultRows := make([][]Cell, 0)for _, rightRow := range rightRows {matched := falsefor _, leftRow := range leftRows {// 构建组合行用于条件评估combinedRow := append(leftRow, rightRow...)combinedColumns := append(leftTable.Columns, rightTable.Columns...)// 评估ON条件match, err := evaluateWhereCondition(onCondition, combinedRow, combinedColumns)if err != nil {// 如果评估出错,跳过这一对行continue}if match {matched = true// 如果有WHERE条件,也需满足if whereCondition != nil {whereMatch, err := evaluateWhereCondition(whereCondition, combinedRow, combinedColumns)if err != nil || !whereMatch {continue}}resultRows = append(resultRows, combinedRow)}}// 如果没有匹配的左行,添加NULL值的左行和右行if !matched {nullLeftRow := make([]Cell, len(leftTable.Columns))for i := range nullLeftRow {nullLeftRow[i] = Cell{Type: CellTypeText, TextValue: "NULL"}}combinedRow := append(nullLeftRow, rightRow...)// RIGHT JOIN中,即使ON条件不匹配,也要考虑WHERE条件if whereCondition != nil {combinedColumns := append(leftTable.Columns, rightTable.Columns...)whereMatch, err := evaluateWhereCondition(whereCondition, combinedRow, combinedColumns)if err != nil || !whereMatch {continue}}resultRows = append(resultRows, combinedRow)}}return resultRows
}
测试
测试命令:
-- Create the database and switch to it
create database test;
use test;
-- Create the test tables
CREATE TABLE users (id INT PRIMARY KEY, name TEXT, age INT);
CREATE TABLE orders (id INT PRIMARY KEY, user_id INT, product TEXT, amount FLOAT);
-- Insert the test data
INSERT INTO users VALUES (1, 'Alice', 25);
INSERT INTO users VALUES (2, 'Bob', 30);
INSERT INTO users VALUES (3, 'Charlie', 22);
INSERT INTO orders VALUES (1, 1, 'Laptop', 1200.0);
INSERT INTO orders VALUES (2, 1, 'Mouse', 25.0);
INSERT INTO orders VALUES (3, 2, 'Keyboard', 75.0);
INSERT INTO orders VALUES (4, 4, 'Monitor', 300.0);
-- Test INNER JOIN
SELECT users.name, orders.product, orders.amount FROM users INNER JOIN orders ON users.id = orders.user_id;
-- Test LEFT JOIN
SELECT users.name, orders.product, orders.amount FROM users LEFT JOIN orders ON users.id = orders.user_id;
-- Test RIGHT JOIN
SELECT users.name, orders.product, orders.amount FROM users RIGHT JOIN orders ON users.id = orders.user_id;
效果:
功能点二:实现带where的join操作
实现思路
目前表连接操作相关关键字和底层逻辑都已经实现,因此我们只需要在解析器中调整对where子句的解析即可。
internal/parser/parser.go解析器修改parseWhereClause方法,支持对where子句中包含表名.列名的解析
代码实现
internal/parser/parser.go:
// parseWhereClause parses the single condition that follows the WHERE keyword.
//
// The left operand is read with parseExpression (used, per the original
// note, so that table.column operands are supported). What follows decides
// the result:
//   - BETWEEN: the rest is delegated to parseBetweenExpression.
//   - LIKE: the next token must be a STRING; one leading quote character and
//     the final character are stripped from it to form the pattern, and an
//     *ast.LikeExpression is returned.
//   - any token accepted by isBasicOperator: the right operand is read with
//     parseExpression and an *ast.BinaryExpression is returned.
//
// Any other token, or a non-string LIKE pattern, yields a syntax error.
func (p *Parser) parseWhereClause() (ast.Expression, error) {
	const syntaxErrFormat = "You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'"

	// Consume the WHERE keyword.
	p.nextToken()

	left, err := p.parseExpression()
	if err != nil {
		return nil, err
	}

	if p.peekTokenIs(lexer.BETWEEN) {
		p.nextToken()
		return p.parseBetweenExpression(left)
	}

	// Advance onto the operator token and remember it.
	p.nextToken()
	operator := p.curToken

	switch {
	case p.curTokenIs(lexer.LIKE):
		p.nextToken()
		if !p.curTokenIs(lexer.STRING) {
			return nil, fmt.Errorf(syntaxErrFormat, p.curToken.Literal)
		}
		// Drop the surrounding quotes from the string literal, if present.
		pattern := p.curToken.Literal
		if len(pattern) >= 2 {
			if first := pattern[0]; first == '\'' || first == '"' {
				pattern = pattern[1 : len(pattern)-1]
			}
		}
		return &ast.LikeExpression{
			Token:   operator,
			Left:    left,
			Pattern: pattern,
		}, nil

	case !p.isBasicOperator():
		// Not one of the basic comparison operators (=, >, <, >=, <=, != ...).
		return nil, fmt.Errorf(syntaxErrFormat, operator.Type)
	}

	p.nextToken()
	right, err := p.parseExpression()
	if err != nil {
		return nil, err
	}
	return &ast.BinaryExpression{
		Token:    operator,
		Left:     left,
		Operator: operator.Literal,
		Right:    right,
	}, nil
}
测试
测试命令:
-- Create the database and switch to it
create database test;
use test;
-- Create the test tables
CREATE TABLE users (id INT PRIMARY KEY, name TEXT, age INT);
CREATE TABLE orders (id INT PRIMARY KEY, user_id INT, product TEXT, amount FLOAT);
-- Insert the test data
INSERT INTO users VALUES (1, 'Alice', 25);
INSERT INTO users VALUES (2, 'Bob', 30);
INSERT INTO users VALUES (3, 'Charlie', 22);
INSERT INTO orders VALUES (1, 1, 'Laptop', 1200.0);
INSERT INTO orders VALUES (2, 1, 'Mouse', 25.0);
INSERT INTO orders VALUES (3, 2, 'Keyboard', 75.0);
INSERT INTO orders VALUES (4, 4, 'Monitor', 300.0);
-- JOIN with a WHERE condition
SELECT users.name, orders.product, orders.amount FROM users INNER JOIN orders ON users.id = orders.user_id WHERE orders.amount > 50;
效果: