当前位置: 首页 > news >正文

简化DB操作:Golang 通用仓库模式

介绍

本代码包提供一个用于数据库操作的通用仓库 (GenericRepository),利用 Golang 和 GORM (Go ORM) 实现。该仓库设计用于简化数据库的 CRUD (创建、读取、更新、删除) 操作,支持批处理、冲突处理、分页查询等高级功能。

主要功能

  1. 创建记录 (Create): 插入单个模型实例到数据库。
  2. 创建记录(冲突时更新) (CreateOnConflict): 插入单个模型实例到数据库,如果存在冲突(例如主键冲突),则更新指定的字段。
  3. 批量创建记录 (CreateBatch): 批量插入模型实例到数据库,提高大量数据处理的效率。
  4. 批量创建记录(冲突时更新) (CreateBatchOnConflict): 批量插入模型实例,如果存在冲突,则更新指定的字段。
  5. 检索记录 (Retrieve): 根据指定参数查询数据库,并将结果填充到提供的输出变量中。
  6. 分页检索记录 (RetrievePage): 根据指定参数进行分页查询,并将结果填充到提供的输出变量中。
  7. 检索单条记录 (RetrieveOne): 根据指定参数查询单条记录。
  8. 更新记录 (Update): 更新数据库中的现有记录。
  9. 按参数更新记录 (UpdateByParams): 根据提供的参数更新符合条件的记录。
  10. 删除记录 (Delete): 删除数据库中的指定记录。
  11. 按参数删除记录 (DeleteByParams): 根据提供的参数删除符合条件的记录。
  12. 记录计数 (Count): 根据指定参数计算符合条件的记录总数。

设计理念

  • 灵活性:通过反射和接口调用,支持多种类型的模型操作。
  • 性能:支持批处理操作,减少数据库交互次数,优化性能。
  • 易用性:提供高级功能如冲突处理和分页查询,简化常见的数据库操作。

使用示例

如何在应用程序中使用这个通用的DAO层:

package main

import (
    "context"
    "log"

    "your_project/dao" // 确保此路径与您的实际项目结构匹配
    "your_project/models" // 确保此路径与您的实际项目结构匹配
    repository "your_project/common" // 确保此路径与您的实际项目结构匹配
    "gorm.io/gorm"
)

func main() {
    // 初始化数据库连接
    db := dao.InitDB()
    sqlDB, err := db.DB()
    if err != nil {
        log.Fatal("Error getting underlying sql.DB:", err)
    }
    defer sqlDB.Close() // 确保在函数结束时关闭数据库连接

    // 创建GenericRepository实例
    repo := repository.NewGenericRepository(db, &models.User{})

    // 创建一个新用户
    newUser := models.User{Name: "John Doe", Email: "john@example.com"}
    err = repo.Create(context.Background(), &newUser)
    if err != nil {
        log.Println("Error creating user:", err)
    }

    // 检索用户
    var users []models.User
    query := models.User{Name: "John Doe"}
    err = repo.Retrieve(context.Background(), &query, &users)
    if err != nil {
        log.Println("Error retrieving users:", err)
    }

    // 更新用户
    newUser.Email = "new.email@example.com"
    err = repo.Update(context.Background(), &newUser)
    if err != nil {
        log.Println("Error updating user:", err)
    }

    // 删除用户
    err = repo.Delete(context.Background(), &newUser)
    if err != nil {
        log.Println("Error deleting user:", err)
    }
}

代码解析

1. 模型定义

首先,我们定义一个用户模型(User)作为示例:

package models

import "gorm.io/gorm"

type User struct {
    gorm.Model
    Name  string `db:"name"`
    Email string `db:"email"`
}

2. 数据库初始化与迁移 (dao.go)

这部分负责创建数据库连接,并提供一个自动迁移所有模型的函数。

package dao

import (
	"log"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// InitDB 初始化数据库连接
func InitDB() *gorm.DB {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		log.Fatalf("Failed to connect database: %v", err)
	}
	// Set logger to log SQL statements
	db.Logger = logger.Default.LogMode(logger.Info)
	return db
}

// AutoMigrate 用于自动迁移提供的模型
func AutoMigrate(db *gorm.DB, models ...any) {
	if err := db.AutoMigrate(models...); err != nil {
		log.Fatalf("Failed to auto-migrate models: %v", err)
	}
}

3. 反射查询处理器 (common/processor.go)

接下来,我们创建一个反射查询处理器 ReflectiveQueryProcessor,该处理器负责根据模型的反射信息构建CRUD操作:

package repository

import (
	"reflect"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type ReflectiveQueryProcessor struct{}

func (rqp *ReflectiveQueryProcessor) Count(db *gorm.DB, params any) (int64, error) {
	query := rqp.QueryBuilder(db, params)
	var count int64
	query = query.Model(params)
	if err := query.Count(&count).Error; err != nil {
		return 0, err
	}
	return count, nil
}

func (rqp *ReflectiveQueryProcessor) Insert(db *gorm.DB, model any) *gorm.DB {
	return db.Create(model)
}

func (rqp *ReflectiveQueryProcessor) InsertOnConflict(db *gorm.DB, model any,
	conflictKeys []string, updateColumns []string,
) *gorm.DB {
	return db.Clauses(clause.OnConflict{
		Columns:   rqp.toColumns(conflictKeys),             // 指定哪些字段冲突
		DoUpdates: clause.AssignmentColumns(updateColumns), // 指定发生冲突时更新哪些字段
	}).Create(model)
}

func (rqp *ReflectiveQueryProcessor) InsertBatch(db *gorm.DB, models any) *gorm.DB {
	return db.Create(models)
}

func (rqp *ReflectiveQueryProcessor) InsertBatchOnConflict(db *gorm.DB, models any,
	conflictKeys []string, updateColumns []string,
) *gorm.DB {
	return db.Clauses(clause.OnConflict{
		Columns:   rqp.toColumns(conflictKeys),             // 指定哪些字段冲突
		DoUpdates: clause.AssignmentColumns(updateColumns), // 指定发生冲突时更新哪些字段
	}).Create(models)
}

// Helper function to convert field names to GORM clause.Columns
func (rqp *ReflectiveQueryProcessor) toColumns(fieldNames []string) []clause.Column {
	columns := make([]clause.Column, len(fieldNames))
	for i, fieldName := range fieldNames {
		columns[i] = clause.Column{Name: fieldName}
	}
	return columns
}

func (rqp *ReflectiveQueryProcessor) Find(db *gorm.DB, params any) *gorm.DB {
	query := rqp.QueryBuilder(db, params)
	return query
}

func (rqp *ReflectiveQueryProcessor) Update(db *gorm.DB, model any) *gorm.DB {
	return db.Save(model)
}

func (rqp *ReflectiveQueryProcessor) UpdateByParams(db *gorm.DB, params any, model any) *gorm.DB {
	query := rqp.QueryBuilder(db, params)
	return query.Updates(model)
}

func (rqp *ReflectiveQueryProcessor) Remove(db *gorm.DB, model any) *gorm.DB {
	return db.Delete(model)
}

func (rqp *ReflectiveQueryProcessor) RemoveByParams(db *gorm.DB, params any, model any) *gorm.DB {
	query := rqp.QueryBuilder(db, params)
	return query.Delete(model)
}

// QueryBuilder builds a query based on the provided parameters.
func (rqp *ReflectiveQueryProcessor) QueryBuilder(db *gorm.DB, params any) *gorm.DB {
	val := reflect.ValueOf(params)
	if val.Kind() == reflect.Ptr {
		val = val.Elem()
	}

	for i := 0; i < val.NumField(); i++ {
		field := val.Type().Field(i)
		valueField := val.Field(i)

		if !valueField.IsZero() {
			dbFieldName := field.Tag.Get("db")
			if dbFieldName == "" {
				dbFieldName = strings.ToLower(field.Name)
			}
			db = db.Where(dbFieldName+" = ?", valueField.Interface())
		}
	}
	return db
}

4. 通用数据访问对象 (common/repository.go)

我们定义 GenericRepository 类,它使用 ReflectiveQueryProcessor 来执行数据库操作:

package repository

import (
	"context"
	"log"
	"reflect"

	"github.com/pkg/errors"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

const DefaultBatchSize = 1000

type GenericRepository struct {
	DB             *gorm.DB
	Model          any
	BatchSize      int
	QueryProcessor *ReflectiveQueryProcessor
}

func NewGenericRepository(db *gorm.DB, model any) *GenericRepository {
	return &GenericRepository{
		DB:             db,
		Model:          model,
		BatchSize:      DefaultBatchSize,
		QueryProcessor: &ReflectiveQueryProcessor{},
	}
}

func (gr *GenericRepository) Count(ctx context.Context, params any) (int64, error) {
	if count, err := gr.QueryProcessor.Count(gr.DB, params); err != nil {
		log.Printf("Error counting records: %v", err)
		return 0, err
	} else {
		return count, nil
	}
}

func (gr *GenericRepository) Create(ctx context.Context, model any) error {
	if err := gr.QueryProcessor.Insert(gr.DB, model).Error; err != nil {
		log.Printf("Error creating record: %v", err)
		return err
	}
	return nil
}

func (gr *GenericRepository) CreateOnConflict(ctx context.Context, model any,
	conflictKeys []string, updateColumns []string,
) error {
	if err := gr.QueryProcessor.InsertOnConflict(gr.DB, model, conflictKeys, updateColumns).Error; err != nil {
		log.Printf("Error creating record on conflict: %v", err)
		return err
	}
	return nil
}

func (gr *GenericRepository) CreateBatch(ctx context.Context, models any) error {
	processBatch := func(tx *gorm.DB) error {
		return gr.BatchProcess(tx, models, tx.Create)
	}

	return gr.DB.Transaction(processBatch)
}

func (gr *GenericRepository) CreateBatchOnConflict(ctx context.Context, models any, conflictKeys []string, updateColumns []string) error {
	processBatch := func(tx *gorm.DB) error {
		return gr.BatchProcess(tx, models, func(batch any) *gorm.DB {
			return tx.Clauses(clause.OnConflict{
				Columns:   gr.QueryProcessor.toColumns(conflictKeys),
				DoUpdates: clause.AssignmentColumns(updateColumns),
			}).Create(batch)
		})
	}

	return gr.DB.Transaction(processBatch)
}

func (gr *GenericRepository) BatchProcess(tx *gorm.DB, models any, dbFunc func(any) *gorm.DB) error {
	sliceValue := reflect.ValueOf(models)
	if sliceValue.Kind() != reflect.Slice {
		return errors.New("input data should be a slice type")
	}

	total := sliceValue.Len()
	batchSize := gr.BatchSize
	if batchSize <= 0 {
		batchSize = DefaultBatchSize
	}
	for i := 0; i < total; i += batchSize {
		end := i + batchSize
		if end > total {
			end = total
		}

		batch := sliceValue.Slice(i, end).Interface()
		if err := dbFunc(batch).Error; err != nil {
			return err
		}
	}

	return nil
}

func (gr *GenericRepository) Retrieve(ctx context.Context, params any, out any) error {
	db := gr.QueryProcessor.Find(gr.DB, params).WithContext(ctx)
	if err := db.Find(out).Error; err != nil {
		log.Printf("Error retrieving records: %v", err)
		return err
	}
	return nil
}

func (gr *GenericRepository) RetrievePage(ctx context.Context, params any, pageSize int, page int, out any) error {
	db := gr.QueryProcessor.Find(gr.DB, params).WithContext(ctx)
	if err := db.Offset((page - 1) * pageSize).Limit(pageSize).Find(out).Error; err != nil {
		log.Printf("Error retrieving paginated records: %v", err)
		return err
	}
	return nil
}

func (gr *GenericRepository) RetrieveOne(ctx context.Context, params any, out any) error {
	db := gr.QueryProcessor.Find(gr.DB, params).WithContext(ctx)
	if err := db.First(out).Error; err != nil {
		log.Printf("Error retrieving single record: %v", err)
		return err
	}
	return nil
}

func (gr *GenericRepository) Update(ctx context.Context, model any) error {
	if err := gr.QueryProcessor.Update(gr.DB, model).Error; err != nil {
		log.Printf("Error updating record: %v", err)
		return err
	}
	return nil
}

func (gr *GenericRepository) UpdateByParams(ctx context.Context, params any, model any) error {
	if err := gr.QueryProcessor.UpdateByParams(gr.DB, params, model).Error; err != nil {
		log.Printf("Error updating records by params: %v", err)
		return err
	}
	return nil
}

func (gr *GenericRepository) Delete(ctx context.Context, model any) error {
	if err := gr.QueryProcessor.Remove(gr.DB, model).Error; err != nil {
		log.Printf("Error deleting record: %v", err)
		return err
	}
	return nil
}

func (gr *GenericRepository) DeleteByParams(ctx context.Context, params any) error {
	if err := gr.QueryProcessor.RemoveByParams(gr.DB, params, gr.Model).Error; err != nil {
		log.Printf("Error deleting records by params: %v", err)
		return err
	}
	return nil
}

总结

在上述实现中,我们通过创建一个通用的数据访问层(DAO),提高了代码的复用性和维护性。这种结构使得对各种模型进行数据库操作变得更加直接和灵活,同时也简化了代码的管理。以下是对整个实现的总结和一些关键点的强调:

1. 模型定义的标准化

模型中的每个字段都使用了 db 标签来指定其在数据库表中对应的列名。这是一种标准化处理,使得反射机制能够正确识别和映射字段。

2. 反射查询处理器的灵活性

ReflectiveQueryProcessor 类通过反射动态处理模型,自动构建CRUD操作。这减少了为每个模型手动编写CRUD操作的需要,同时也降低了代码出错的风险。

  • 查询: 利用模型的字段值(如果非零)来构建查询条件。
  • 插入: 直接利用GORM的 Create 方法插入模型。
  • 更新: 使用GORM的 Save 方法更新模型。
  • 删除: 使用GORM的 Delete 方法删除模型。
3. 通用数据访问对象(GenericRepository)

GenericRepository 提供了一个统一的接口来处理所有模型的CRUD操作。这种设计模式(Repository模式)有助于隔离业务逻辑和数据访问代码,使得业务逻辑更加清晰,数据访问更加灵活。

4. 应用程序的简洁性

在主程序中,通过实例化 GenericRepository 并调用其方法来执行具体的数据库操作。这使得主程序不必关心数据存储的细节,而可以专注于业务逻辑。

5. 扩展性和维护性

此架构易于扩展和维护。添加新的模型或修改现有模型时,通常不需要修改数据访问层的代码。此外,如果需要替换数据库访问技术(例如从GORM迁移到其他ORM),则主要修改集中在 ReflectiveQueryProcessor 中,不会影响到业务逻辑层。

后续步骤

后续可以进一步改进和扩展当前的实现:

  • 单元测试: 为 ReflectiveQueryProcessorGenericRepository 编写单元测试,确保各种操作的正确性。
  • 错误处理: 强化错误处理机制,确保所有可能的数据库错误都能被妥善处理,并反馈给用户。
  • 性能优化: 分析和优化数据库操作的性能,特别是对于复杂的查询和大型数据集。
  • 安全性: 确保代码对SQL注入和其他潜在的安全问题有足够的防护。

通过这些实现和改进,我们可以确保应用程序的数据访问层既强大又可靠,能够支持复杂且多变的业务需求。

相关文章:

  • 【家政平台开发(33)】库存管理模块开发实战:从基础搭建到智能管控
  • 简单实现逆波兰表达式求值
  • C++_智能指针
  • 如何从零构建一个自己的 CentOS 基础镜像
  • WinForm真入门(14)——ListView控件详解
  • Work Experience
  • java相关技术总结
  • 在 openEuler 24.03 (LTS) 操作系统上添加 ollama 作为系统服务的步骤
  • 如何在Android系统上单编ko?
  • c++基础知识二
  • 剑指offer经典题目(三)
  • 基于springboot的“协同过滤算法的高考择校推荐系统”的设计与实现(源码+数据库+文档+PPT)
  • 使用模板报错:_G.unicode.len(orgline.text_stripped:gsub(“ “,““))
  • JavaScript保留小数位及提示toFixed未定义
  • 解决文件夹解压中文字符产生乱码的问题
  • SQLI漏洞公开报告分析
  • JS 数组解构
  • 无人机飞控的二次开发,视觉定位
  • 空杯见月,满杯见己
  • 全文 - MLIR Toy Tutorial Chapter 4: 使用 interfaces 开启 通用变换