当前位置: 首页 > news >正文

[手写系列]Go手写db — — 第三版(实现分组、排序、聚合函数等)

[手写系列]Go手写db — — 第三版

第一版文章地址:https://blog.csdn.net/weixin_45565886/article/details/147839627
第二版文章地址:https://blog.csdn.net/weixin_45565886/article/details/150869791

  • 🏠整体项目Github地址:https://github.com/ziyifast/ZiyiDB
  • 🚀请大家多多支持,也欢迎大家star⭐️和共同维护这个项目~

序言:只要接触过后端开发,必不可少会使用到关系型数据库,比如:MySQL、Oracle等,那么我们经常使用的字段默认值、以及聚合函数底层是如何实现的呢?本文会给大家提供一些思路,实现相关功能。

主要介绍如何在 ZiyiDB之前的基础上,实现更多新功能,给大家提供实现数据库的简单思路,以及数据库底层实现的流程,后续更多功能,大家可以参考着实现。

一、功能列表

  1. 默认值支持(DEFAULT 关键字)
  2. 聚合函数支持(COUNT, SUM, AVG, MAX, MIN)
  3. Group by分组能力
  4. Order by 排序能力

二、实现细节

1. 默认值实现

设计思路

默认值是数据库中一个重要的数据完整性特性。当插入数据时,如果没有为某列提供值,数据库会自动使用该列的默认值。

在 ZiyiDB 中,默认值的实现需要考虑以下几点:

  • 语法解析:在 CREATE TABLE 语句中识别 DEFAULT 关键字和默认值
  • 存储:在表结构中保存每列的默认值
  • 执行:在 INSERT 语句中应用默认值

1.在lexer/token.go中新增default字符,然后在lexer/lexer.go的lookupIdentifier方法中新增对于default的case语句,用于匹配识别用户输入的SQL

token.go:
在这里插入图片描述
lexer.go:
在这里插入图片描述
2. internal/ast/ast.go抽象语法树中新增DefaultExpression,同时列定义中新增默认值字段,用于存储列的默认值
在这里插入图片描述
在这里插入图片描述
3. parser中的parseCreateTableStatement函数新增对create SQL中默认值的读取和封装,解析用户输入SQL中的字段默认值类型和value
在这里插入图片描述
4. internal/storage/memory.go 存储引擎处理Insert方法时,新增对默认值的处理。
在这里插入图片描述

代码实现

1.语法解析层(Parser)

在 internal/parser/parser.go 中,parseCreateTableStatement 方法被增强以支持默认值:

// parseCreateTableStatement 解析CREATE TABLE语句
func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) {stmt := &ast.CreateTableStatement{Token: p.curToken}// ... 其他代码// 解析列定义for !p.peekTokenIs(lexer.RPAREN) {p.nextToken()if !p.curTokenIs(lexer.IDENT) {return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)}col := ast.ColumnDefinition{Name: p.curToken.Literal,}if !p.expectPeek(lexer.INT) &&!p.expectPeek(lexer.TEXT) &&!p.expectPeek(lexer.FLOAT) &&!p.expectPeek(lexer.DATETIME) {return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)}col.Type = string(p.curToken.Type)if p.peekTokenIs(lexer.PRIMARY) {p.nextToken()if !p.expectPeek(lexer.KEY) {return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)}col.Primary = true}if p.peekTokenIs(lexer.DEFAULT) {p.nextToken() // 消费 DEFAULT 关键字p.nextToken() // 移动到默认值表达式开始位置// 解析复杂默认值表达式(支持函数调用、数学表达式等)defaultValue, err := p.parseExpression()if err != nil {return nil, fmt.Errorf("Invalid default value for column '%s': %v", col.Name, err)}// 创建 DefaultExpression 节点col.Default = &ast.DefaultExpression{Token: p.curToken,Value: defaultValue,}}stmt.Columns = append(stmt.Columns, col)if p.peekTokenIs(lexer.COMMA) {p.nextToken()}}// ... 其他代码
}

2.AST 定义

在 internal/ast/ast.go 中,我们添加了 DefaultExpression 类型来表示默认值:

// DefaultExpression 表示DEFAULT表达式
type DefaultExpression struct {Token lexer.TokenValue Expression
}func (de *DefaultExpression) expressionNode()      {}
func (de *DefaultExpression) TokenLiteral() string { return de.Token.Literal }

同时,ColumnDefinition 结构也被更新以包含默认值:

// ColumnDefinition 表示列定义
type ColumnDefinition struct {Name     stringType     stringPrimary  boolNullable boolDefault  interface{} //列默认值
}

3.存储引擎实现

在 internal/storage/memory.go 中,Insert 方法被增强以支持默认值:

// Insert 插入数据
func (b *MemoryBackend) Insert(stmt *ast.InsertStatement) error {table, exists := b.tables[stmt.TableName]if !exists {return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)}// 构建列名到表列索引的映射colIndexMap := make(map[string]int)for idx, col := range table.Columns {colIndexMap[col.Name] = idx}// 初始化行数据(长度为表的总列数)row := make([]ast.Cell, len(table.Columns))// 处理插入列列表(用户显式指定的列或隐式全列)var insertCols []*ast.Identifier//用户SQL需要插入的列名、值的映射userColMap := make(map[string]ast.Expression)if len(stmt.Columns) > 0 {insertCols = stmt.Columnsfor i, col := range stmt.Columns {userColMap[col.Token.Literal] = stmt.Values[i]}} else {// 未指定列时默认使用表的所有列insertCols = make([]*ast.Identifier, len(table.Columns))for i, col := range table.Columns {insertCols[i] = &ast.Identifier{Value: col.Name}userColMap[col.Name] = stmt.Values[i]}}// 检查值数量与指定列数量是否匹配if len(stmt.Values) != len(insertCols) {return fmt.Errorf("Column count doesn't match value count at row 1 (got %d, want %d)", len(stmt.Values), len(insertCols))}// 转换值// 填充行数据(处理用户值或默认值)for i, tableCol := range table.Columns {// 优先使用用户提供的值,否则使用默认值var expr ast.Expressionexpr = userColMap[tableCol.Name]if expr == nil && tableCol.Default != nil {expr = tableCol.Default.(*ast.DefaultExpression).Value}//获取当前列名colName := table.Columns[i].NametableColIdx, ok := colIndexMap[colName]if !ok {return fmt.Errorf("Unknown column '%s' in INSERT statement", colName)}// 转换值类型value, err := evaluateExpression(expr)if err != nil {return fmt.Errorf("invalid value for column '%s': %v", colName, err)}// 类型转换switch v := value.(type) {case string:if tableCol.Type == "INT" {intVal, err := strconv.ParseInt(v, 10, 32)if err != nil {return fmt.Errorf("Incorrect integer value: '%s' for column '%s'", v, tableCol.Name)}row[tableColIdx] = ast.Cell{Type: ast.CellTypeInt, IntValue: int32(intVal)}} else {row[tableColIdx] = ast.Cell{Type: ast.CellTypeText, TextValue: v}}case int32:row[tableColIdx] = ast.Cell{Type: ast.CellTypeInt, IntValue: v}case float32:row[tableColIdx] = ast.Cell{Type: ast.CellTypeFloat, FloatValue: v}case time.Time:row[tableColIdx] = ast.Cell{Type: ast.CellTypeDateTime, TimeValue: v.Format("2006-01-02 15:04:05")}default:return fmt.Errorf("Unsupported value type: %T for column '%s'", value, tableCol.Name)}}// ... 其他代码
}

测试

测试SQL:

-- 创建带默认值的表
CREATE TABLE users (id INT PRIMARY KEY,name TEXT,age INT DEFAULT 18,score FLOAT,ctime DATETIME DEFAULT '2023-07-04 12:00:00'
);-- 插入部分列数据(未指定的列将使用默认值)
INSERT INTO users (id, name, score) VALUES (1, 'Alice', 90.0);
INSERT INTO users (id, name, age, score) VALUES (2, 'Bob', 25, 85.5);-- 查询数据验证默认值
SELECT * FROM users;

效果:
在这里插入图片描述

2. 聚合函数实现

设计思路

聚合函数是 SQL 中用于对一组值执行计算并返回单个值的函数。在 ZiyiDB 中,我们实现了以下聚合函数:

  • COUNT:计算行数
  • SUM:计算数值列的总和
  • AVG:计算数值列的平均值
  • MAX:找出列中的最大值
  • MIN:找出列中的最小值

聚合函数的实现需要考虑以下几点:
语法解析:在 SELECT 语句中识别函数调用
执行逻辑:在存储引擎中计算聚合结果
结果返回:以统一的格式返回结果

这里以count聚合函数为例,其他聚合函数同理

  1. internal/ast/ast.go中新增FunctionCall函数调用类型,用于后续执行函数调用,比如count、max等聚合函数
    在这里插入图片描述
  2. internal/parser/parser.go中新增对函数类型的解析和封装
    在这里插入图片描述
  3. internal/storage/memory.go存储引擎Select方法中新增对聚合函数的判断
    在这里插入图片描述
    同时memory.go中添加calculateFunctionResults方法,实现对函数的执行和底层实现
    在这里插入图片描述
    在这里插入图片描述

代码实现

  1. 语法解析层(Parser)

在 internal/parser/parser.go 中,我们增强了 parseSelectStatement 方法来支持函数调用:

// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {stmt := &ast.SelectStatement{Token: p.curToken}// 解析选择列表for !p.peekTokenIs(lexer.FROM) {p.nextToken()if p.curToken.Type == lexer.ASTERISK {stmt.Fields = append(stmt.Fields, &ast.StarExpression{})break}expr, err := p.parseExpression()if err != nil {return nil, err}stmt.Fields = append(stmt.Fields, expr)if p.peekTokenIs(lexer.COMMA) {p.nextToken()}}// ... 其他代码
}

parseExpression 方法也进行了增强,以支持函数调用的解析:

// parseExpression 解析表达式
func (p *Parser) parseExpression() (ast.Expression, error) {switch p.curToken.Type {// ... 其他情况case lexer.IDENT:if p.peekTokenIs(lexer.LPAREN) {return p.parseFunctionCall()}return &ast.Identifier{Token: p.curToken,Value: p.curToken.Literal,}, nil// ...}
}// parseFunctionCall 解析函数调用
func (p *Parser) parseFunctionCall() (ast.Expression, error) {fn := &ast.FunctionCall{Token:  p.curToken,Name:   p.curToken.Literal,Params: []ast.Expression{},}// 检查下一个token是否为左括号if !p.expectPeek(lexer.LPAREN) {return nil, fmt.Errorf("expected ( after function name")}// 如果是右括号,说明没有参数if p.peekTokenIs(lexer.RPAREN) {p.nextToken()return fn, nil}// 解析参数列表for !p.peekTokenIs(lexer.RPAREN) {p.nextToken()param, err := p.parseExpression()if err != nil {return nil, err}fn.Params = append(fn.Params, param)if p.peekTokenIs(lexer.COMMA) {p.nextToken()} else if !p.peekTokenIs(lexer.RPAREN) {return nil, fmt.Errorf("expected comma or closing parenthesis in function call")}}if !p.expectPeek(lexer.RPAREN) {return nil, fmt.Errorf("Missing closing parenthesis for function call")}return fn, nil
}
  1. AST 定义

在 internal/ast/ast.go 中,我们添加了 FunctionCall 类型来表示函数调用:

// FunctionCall 表示函数调用
type FunctionCall struct {Token  lexer.TokenName   stringParams []Expression
}func (fc *FunctionCall) expressionNode()      {}
func (fc *FunctionCall) TokenLiteral() string { return fc.Token.Literal }
  1. 存储引擎实现

在 internal/storage/memory.go 中,Select 方法被增强以支持聚合函数:

// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*ast.Results, error) {table, exists := b.tables[stmt.TableName]if !exists {return nil, fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)}results := &ast.Results{Columns: make([]ast.ResultColumn, 0),Rows:    make([][]ast.Cell, 0),}// 检查是否为聚合函数查询isAggregation := falsevar aggregateFunc *ast.FunctionCall// 处理select列表if len(stmt.Fields) == 1 {// 检查是否为 SELECT *if _, ok := stmt.Fields[0].(*ast.StarExpression); ok {// SELECT *for _, col := range table.Columns {results.Columns = append(results.Columns, ast.ResultColumn{Name: col.Name,Type: col.Type,})}} else if fn, ok := stmt.Fields[0].(*ast.FunctionCall); ok {// 处理函数调用isAggregation = trueaggregateFunc = fnresults.Columns = append(results.Columns, ast.ResultColumn{Name: fn.Name,Type: "FUNCTION",})}// ... 其他情况}// ... 其他情况// 如果是聚合函数查询,直接计算结果if isAggregation {// 处理WHERE子句filteredRows := make([][]ast.Cell, 0)for _, row := range table.Rows {if stmt.Where != nil {match, err := evaluateWhereCondition(stmt.Where, row, table.Columns)if err != nil {return nil, err}if !match {continue}}filteredRows = append(filteredRows, row)}functionResult := calculateFunctionResults(aggregateFunc, table, filteredRows)results.Rows = [][]ast.Cell{functionResult}return results, nil}// ... 非聚合函数的处理
}

每个聚合函数都有对应的计算方法:

// calculateFunctionResults 计算函数结果
func calculateFunctionResults(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {// 根据函数类型计算结果switch strings.ToUpper(fn.Name) {case "COUNT":return calculateCount(fn, table, rows)case "SUM":return calculateSum(fn, table, rows)case "AVG":return calculateAvg(fn, table, rows)case "MAX":return calculateMax(fn, table, rows)case "MIN":return calculateMin(fn, table, rows)default:return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: Unknown function '%s'", fn.Name)}}}
}// calculateCount 计算COUNT函数结果
func calculateCount(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {return []ast.Cell{{Type: ast.CellTypeInt, IntValue: int32(len(rows))}}
}// calculateSum 计算SUM函数结果
func calculateSum(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {// 处理 SUM(column) 情况if len(fn.Params) != 1 {return []ast.Cell{{Type: ast.CellTypeText, TextValue: "ERROR: SUM function requires exactly one parameter"}}}var columnName string// 检查参数类型switch param := fn.Params[0].(type) {case *ast.Identifier:columnName = param.Valuedefault:return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: SUM function requires a column name, got %T", param)}}}// 查找列索引colIndex := -1for i, col := range table.Columns {if col.Name == columnName {colIndex = ibreak}}if colIndex == -1 {return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: Unknown column '%s'", columnName)}}}// 计算SUM值var sumInt int32 = 0var sumFloat float32 = 0.0hasFloat := falsefor _, row := range rows {cell := row[colIndex]switch cell.Type {case ast.CellTypeInt:sumInt += cell.IntValuecase ast.CellTypeFloat:// 如果之前有整数,需要转换为浮点数if !hasFloat {sumFloat = float32(sumInt)hasFloat = true}sumFloat += cell.FloatValue}}// 返回结果if hasFloat {return []ast.Cell{{Type: ast.CellTypeFloat, FloatValue: sumFloat}}}return []ast.Cell{{Type: ast.CellTypeInt, IntValue: sumInt}}
}
// ... 其他聚合函数的实现

测试

测试SQL:

-- 创建测试表
CREATE TABLE users (id INT PRIMARY KEY, name TEXT, age INT);-- 插入测试数据
INSERT INTO users VALUES (1, 'Alice', 20);
INSERT INTO users VALUES (2, 'Bob', 25);
INSERT INTO users VALUES (3, 'Charlie', 30);-- 使用聚合函数
SELECT COUNT(*) FROM users;
SELECT SUM(age) FROM users;
SELECT AVG(age) FROM users;
SELECT MAX(age) FROM users;
SELECT MIN(age) FROM users;-- 带WHERE条件的聚合函数
SELECT COUNT(*) FROM users WHERE age > 25;
SELECT SUM(age) FROM users WHERE age >= 25;

效果:
在这里插入图片描述

3. group by 实现

设计思路

1.语法解析:

首先在internal/lexer/token.go中新增group by关键字

在这里插入图片描述然后在internal/lexer/lexer.go词法分析器的lookupIdentifier方法中新增对group by关键字的识别
在这里插入图片描述
接下来在internal/parser/parser.go词法分析器中的parseSelectStatement方法中添加 GROUP 和 BY 关键字的解析,将其解析并封装为ast的一部分
在这里插入图片描述
在 internal/ast/ast.go 中添加 GroupBy 字段到 SelectStatement 结构体
在这里插入图片描述
2. 执行引擎:

首先在internal/storage/memory.go存储引擎中的Select方法实现对分组逻辑的调用
在这里插入图片描述
接着selectWithGroupBy方法,实现底层分组原理,按指定列对数据进行分组
在这里插入图片描述

在这里插入图片描述
3. internal/storage/memory.go中的selectWithGroupBy对聚合函数进行处理,确保查询结果列是聚合函数列或者分组列
在这里插入图片描述

代码实现

  1. 在词法分析器中添加新的关键字
// internal/lexer/token.go
const (// ... 其他关键字GROUP   TokenType = "GROUP"BY      TokenType = "BY"
)// internal/lexer/lexer.go
func (l *Lexer) lookupIdentifier(ident string) TokenType {switch strings.ToUpper(ident) {// ... 其他关键字case "GROUP":return GROUPcase "BY":return BYdefault:return IDENT}
}
  1. 在 AST 中添加新的结构体以支持 GROUP BY
// internal/ast/ast.go// SelectStatement 表示SELECT语句
type SelectStatement struct {Token     lexer.TokenFields    []ExpressionTableName stringWhere     ExpressionGroupBy   []Expression    // 添加 GroupBy 字段
}
  1. 在语法分析器中添加对 GROUP BY 子句的解析
// internal/parser/parser.go// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {stmt := &ast.SelectStatement{Token: p.curToken}// ... 解析选择列表和 FROM 子句 ...// 解析WHERE子句if p.peekTokenIs(lexer.WHERE) {p.nextToken()whereExpr, err := p.parseWhereClause()if err != nil {return nil, err}stmt.Where = whereExpr}// 解析GROUP BY子句if p.peekTokenIs(lexer.GROUP) {p.nextToken() // 跳过 GROUPif !p.expectPeek(lexer.BY) {return nil, fmt.Errorf("expected BY after GROUP")}// 解析GROUP BY字段列表for {p.nextToken()if !p.curTokenIs(lexer.IDENT) {return nil, fmt.Errorf("expected identifier in GROUP BY clause")}expr := &ast.Identifier{Token: p.curToken,Value: p.curToken.Literal,}stmt.GroupBy = append(stmt.GroupBy, expr)if !p.peekTokenIs(lexer.COMMA) {break}p.nextToken() // 跳过逗号}}return stmt, nil
}
  1. 在存储引擎中实现 GROUP BY 的执行逻辑
// internal/storage/memory.go// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*Results, error) {table, exists := b.tables[stmt.TableName]if !exists {return nil, fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)}// 如果有 GROUP BY 子句if len(stmt.GroupBy) > 0 {return b.selectWithGroupBy(stmt, table)}// ... 原有的查询逻辑 ...
}// selectWithGroupBy 处理带有 GROUP BY 的查询
func (b *MemoryBackend) selectWithGroupBy(stmt *ast.SelectStatement, table *Table) (*Results, error) {results := &Results{Columns: make([]ResultColumn, 0),Rows:    make([][]Cell, 0),}// 验证 GROUP BY 字段存在于表中groupByIndices := make([]int, len(stmt.GroupBy))for i, expr := range stmt.GroupBy {if identifier, ok := expr.(*ast.Identifier); ok {found := falsefor j, col := range table.Columns {if col.Name == identifier.Value {groupByIndices[i] = jfound = truebreak}}if !found {return nil, fmt.Errorf("Unknown column '%s' in 'group statement'", identifier.Value)}} else {return nil, fmt.Errorf("GROUP BY only supports column names")}}// 构建结果列for _, expr := range stmt.Fields {switch e := expr.(type) {case *ast.Identifier:found := falsefor _, col := range table.Columns {if col.Name == e.Value {results.Columns = append(results.Columns, ResultColumn{Name: col.Name,Type: col.Type,})found = truebreak}}if !found {return nil, fmt.Errorf("Unknown column '%s' in 'field list'", e.Value)}case *ast.FunctionCall:results.Columns = append(results.Columns, ResultColumn{Name: e.Name,Type: "FUNCTION",})case *ast.StarExpression:for _, col := range table.Columns {results.Columns = append(results.Columns, ResultColumn{Name: col.Name,Type: col.Type,})}default:return nil, fmt.Errorf("Unsupported select expression type")}}// 处理WHERE子句filteredRows := make([][]Cell, 0)for _, row := range table.Rows {if stmt.Where != nil {match, err := evaluateWhereCondition(stmt.Where, row, table.Columns)if err != nil {return nil, err}if !match {continue}}filteredRows = append(filteredRows, row)}// 按 GROUP BY 字段分组groups := make(map[string][][]Cell)for _, row := range filteredRows {// 构建分组键groupKey := ""for _, idx := range groupByIndices {groupKey += row[idx].String() + "|"}// 将行添加到对应的组中groups[groupKey] = append(groups[groupKey], row)}// 为每个组计算结果for _, groupRows := range groups {if len(groupRows) == 0 {continue}resultRow := make([]Cell, len(results.Columns))colIndex := 0// 处理非聚合字段(GROUP BY 字段)for _, expr := range stmt.Fields {if identifier, ok := expr.(*ast.Identifier); ok {// 检查是否为 GROUP BY 字段isGroupByField := falsefor _, groupByExpr := range stmt.GroupBy {if groupByIdent, ok := groupByExpr.(*ast.Identifier); ok {if groupByIdent.Value == identifier.Value {isGroupByField = truebreak}}}if isGroupByField {// 对于 GROUP BY 字段,取第一个值(所有行应该相同)for k, tableCol := range table.Columns {if tableCol.Name == identifier.Value {resultRow[colIndex] = groupRows[0][k]break}}}colIndex++}}// 处理聚合函数for i, expr := range stmt.Fields {if fn, ok := expr.(*ast.FunctionCall); ok {functionResult := calculateFunctionResults(fn, table, groupRows)resultRow[i] = functionResult[0]}}results.Rows = append(results.Rows, resultRow)}return results, nil
}

测试

测试SQL:

CREATE TABLE sales (id INT PRIMARY KEY, product TEXT, category TEXT, amount FLOAT);
INSERT INTO sales VALUES (1, 'Apple', 'Fruit', 10.5);
INSERT INTO sales VALUES (2, 'Banana', 'Fruit', 8.0);
INSERT INTO sales VALUES (3, 'Carrot', 'Vegetable', 5.2);
INSERT INTO sales VALUES (4, 'Broccoli', 'Vegetable', 7.3);
INSERT INTO sales VALUES (5, 'Orange', 'Fruit', 9.8);
SELECT category, COUNT(*) FROM sales GROUP BY category;
SELECT category, SUM(amount) FROM sales GROUP BY category;
SELECT category, AVG(amount) FROM sales GROUP BY category;

效果:
在这里插入图片描述

4. order by 实现

设计思路

与group by实现基本一致

1.语法解析:

在词法分析器中添加 ORDER、BY、ASC 和 DESC 关键字

  • internal/lexer/token.go:
    在这里插入图片描述
  • internal/lexer/lexer.go的lookupIdentifier方法:
    在这里插入图片描述

在语法分析器中解析 ORDER BY 子句:
在这里插入图片描述

在 internal/ast/ast.go中添加 OrderBy 字段到 SelectStatement 结构体
在这里插入图片描述

2.执行引擎:

在internal/storage/memory.go存储引擎的Select方法中实现对order by的解析调用:
在这里插入图片描述
同时实现排序逻辑,使用 Go 标准库的 sort.Slice 进行排序同时实现自定义比较函数以支持不同数据类型的比较:
在这里插入图片描述

在这里插入图片描述

代码实现

  1. 在词法分析器中添加新的关键字
// internal/lexer/token.go
const (// ... 其他关键字ORDER   TokenType = "ORDER"ASC     TokenType = "ASC"DESC    TokenType = "DESC"
)// internal/lexer/lexer.go
func (l *Lexer) lookupIdentifier(ident string) TokenType {switch strings.ToUpper(ident) {// ... 其他关键字case "ORDER":return ORDERcase "ASC":return ASCcase "DESC":return DESCdefault:return IDENT}
}
  1. 在 AST 中添加新的结构体以支持 ORDER BY
// internal/ast/ast.go// SelectStatement 表示SELECT语句
type SelectStatement struct {Token     lexer.TokenFields    []ExpressionTableName stringWhere     ExpressionOrderBy   []OrderByClause // 添加 OrderBy 字段
}// OrderByClause 表示 ORDER BY 子句中的排序项
type OrderByClause struct {Expression ExpressionDirection  string // "ASC" 或 "DESC"
}
  1. 在语法分析器中添加对 ORDER BY 子句的解析
// internal/parser/parser.go// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {stmt := &ast.SelectStatement{Token: p.curToken}// ... 解析选择列表、FROM 子句和 WHERE 子句 ...// 解析GROUP BY子句(如果有的话)if p.peekTokenIs(lexer.GROUP) {// ... GROUP BY 解析逻辑 ...}// 解析ORDER BY子句if p.peekTokenIs(lexer.ORDER) {orderExprs, err := p.parseOrderByClause()if err != nil {return nil, err}stmt.OrderBy = orderExprs}return stmt, nil
}// parseOrderByClause 解析ORDER BY子句
func (p *Parser) parseOrderByClause() ([]ast.OrderByClause, error) {// 跳过 ORDER 关键字if !p.expectPeek(lexer.ORDER) {return nil, fmt.Errorf("expected ORDER keyword")}// 跳过 BY 关键字if !p.expectPeek(lexer.BY) {return nil, fmt.Errorf("expected BY keyword")}var orderExprs []ast.OrderByClausefor {p.nextToken()// 解析表达式(列名)if !p.curTokenIs(lexer.IDENT) {return nil, fmt.Errorf("expected identifier in ORDER BY clause")}expr := &ast.Identifier{Token: p.curToken,Value: p.curToken.Literal,}orderClause := ast.OrderByClause{Expression: expr,Direction:  "ASC", // 默认升序}// 检查是否有 ASC 或 DESCif p.peekTokenIs(lexer.ASC) || p.peekTokenIs(lexer.DESC) {p.nextToken()orderClause.Direction = p.curToken.Literal}orderExprs = append(orderExprs, orderClause)// 如果没有逗号,说明结束了if !p.peekTokenIs(lexer.COMMA) {break}p.nextToken() // 跳过逗号}return orderExprs, nil
}
  1. 在存储引擎中实现 ORDER BY 的执行逻辑
// internal/storage/memory.go// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*Results, error) {// ... 原有的查询逻辑 ...// 处理 ORDER BYif len(stmt.OrderBy) > 0 {var err errorresults.Rows, err = b.orderBy(results.Rows, results.Columns, stmt.OrderBy, table.Columns)if err != nil {return nil, err}}return results, nil
}// orderBy 根据 ORDER BY 子句对结果进行排序
func (b *MemoryBackend) orderBy(rows [][]Cell, resultCols []ResultColumn, orderBy []ast.OrderByClause, tableCols []ast.ColumnDefinition) ([][]Cell, error) {// 创建列名到索引的映射colIndexMap := make(map[string]int)for i, col := range resultCols {colIndexMap[col.Name] = i}// 创建排序键的索引和方向type sortKey struct {index     intdirection string}var sortKeys []sortKeyfor _, ob := range orderBy {identifier, ok := ob.Expression.(*ast.Identifier)if !ok {return nil, fmt.Errorf("ORDER BY only supports column names")}index, exists := colIndexMap[identifier.Value]if !exists {return nil, fmt.Errorf("Unknown column '%s' in 'order clause'", identifier.Value)}sortKeys = append(sortKeys, sortKey{index:     index,direction: ob.Direction,})}// 使用 sort.Slice 进行排序sort.Slice(rows, func(i, j int) bool {for _, key := range sortKeys {left := rows[i][key.index]right := rows[j][key.index]// 比较两个值result, err := compareValues(left, right, "<")if err != nil {// 如果比较出错,保持原有顺序return false}if result {// 如果是升序,返回 true// 如果是降序,返回 falsereturn key.direction == "ASC"} else {// 检查是否相等equal, _ := compareValues(left, right, "=")if !equal {// 如果是降序,返回 true// 如果是升序,返回 falsereturn key.direction == "DESC"}// 如果相等,继续比较下一个排序键}}// 所有键都相等,保持原有顺序return false})return rows, nil
}

测试

测试SQL:

CREATE TABLE sales (id INT PRIMARY KEY, product TEXT, category TEXT, amount FLOAT);
INSERT INTO sales VALUES (1, 'Apple', 'Fruit', 10.5);
INSERT INTO sales VALUES (2, 'Banana', 'Fruit', 8.0);
INSERT INTO sales VALUES (3, 'Carrot', 'Vegetable', 5.2);
INSERT INTO sales VALUES (4, 'Broccoli', 'Vegetable', 7.3);
INSERT INTO sales VALUES (5, 'Orange', 'Fruit', 9.8);SELECT * FROM sales ORDER BY amount;
SELECT * FROM sales ORDER BY amount DESC;
SELECT * FROM sales ORDER BY category, amount DESC;

效果:
在这里插入图片描述


文章转载自:

http://SsHZJHoP.cbpkr.cn
http://Rku4RYjK.cbpkr.cn
http://1GxahC9X.cbpkr.cn
http://oHgr8zMx.cbpkr.cn
http://9COVqR9R.cbpkr.cn
http://2LjyKXyR.cbpkr.cn
http://Kal7oKdX.cbpkr.cn
http://cIbR7cNm.cbpkr.cn
http://Wzx0dCjr.cbpkr.cn
http://OpwrUJUs.cbpkr.cn
http://xpy3bqcZ.cbpkr.cn
http://YPqA7uwN.cbpkr.cn
http://eTI4zp9h.cbpkr.cn
http://7ISP1rZD.cbpkr.cn
http://gZQcJIus.cbpkr.cn
http://IthKk7Rx.cbpkr.cn
http://ScEKTdvy.cbpkr.cn
http://7cOUa4Fj.cbpkr.cn
http://qiRIHcmo.cbpkr.cn
http://nURzYJyX.cbpkr.cn
http://IWBWhSje.cbpkr.cn
http://kdaxPlPh.cbpkr.cn
http://B52liggU.cbpkr.cn
http://4ePbP4mg.cbpkr.cn
http://0FpML99E.cbpkr.cn
http://ods5hRBH.cbpkr.cn
http://s67rbCG4.cbpkr.cn
http://uA9COIfj.cbpkr.cn
http://SOuodCKm.cbpkr.cn
http://byU24vik.cbpkr.cn
http://www.dtcms.com/a/372581.html

相关文章:

  • 【74LS112+08同步十六进制和九进制0-8、8-0显示】2022-12-3
  • C++在控制台打印不同颜色的文本:让日志输出更炫酷
  • ego(3)---根据关键点求解B样条控制点
  • AutoHotkey下载安装并运行第一个脚本
  • ASP4644S电源芯片在商业卫星载荷通讯项目中的成本效益分析
  • HTTPS优化简单总结
  • 磁共振成像原理(理论):信号产生和探测(3)
  • 写程序or打游戏(组合计数)
  • 生成式AI基石之一:变分自编码器(VAE)详解:从架构到数学的深度指南
  • VXLAN集中式网关实验案例
  • 培训学校押金原路退回-企业自动运营——东方仙盟
  • Ubuntu系统的备份和恢复方法
  • 【已解决】Linux中程序脚本可以手动执行成功,但加在rc.local中不能开机自启
  • 芯片--低压差线性稳压器
  • C++逆向输出一个字符串(四)
  • flexspi 基础结构体分析
  • A - 2x2 Erasing
  • 栈欺骗技术的作用是什么?
  • 细说分布式ID
  • nginx自动剔除与恢复
  • tmi8150B控制ir_cut
  • 【期末复习】嵌入式——S5PV210开发板
  • 基于brpc的轻量级服务注册中心设计与实现
  • 作用域報錯
  • 代码随想录学习摘抄day7(二叉树11-21)
  • 固态硬盘——M.2接口技术
  • 数字化浪潮下,传统加工厂如何智能化转型?
  • Miniflux – RSS 订阅
  • Nginx主配置文件
  • 架构进阶——解读121页IT规划咨询项目规划报告【附全文阅读】