当前位置: 首页 > news >正文

go-zero自动生成repository文件和测试用例

文章目录

  • repository的作用
  • 自动生成repository文件
    • repo模板文件
    • repo_test模板文件
    • 生成结果
    • 运行测试用例

repository的作用

在软件开发中,尤其是在采用分层架构或者领域驱动设计(DDD)的项目里,repository(仓库)是一个关键概念,它起到数据访问层和业务逻辑层之间的桥梁作用,负责处理数据的持久化与检索,让业务逻辑层无需直接与数据库或其他数据存储交互。例如下面代码中的 UserRepo 中的几个方法示例,展示了它是如何进行数据操作的:

// 获取用户信息
func (r *UserRepo) GetUserInfo(id int64) (*model.User, error) {
    m := r.svcCtx.Model.User
    res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Where(m.IsDeleted.Eq(0)).First()
    if err != nil {
        logx.WithContext(r.ctx).Errorf("Failed to get data by id: %v", err)
        return nil, err
    }
    return res, nil
}

// 创建用户
func (r *UserRepo) CreateUser(models *model.User) (*model.User, error) {
    m := r.svcCtx.Model.User
    err := m.WithContext(r.ctx).Debug().Create(models)
    if err != nil {
        logx.WithContext(r.ctx).Errorf("Failed to insert data: %v", err)
        return nil, err
    }
    return models, nil
}

以上代码中, UserRepo 就是一个典型的仓库(repository)示例。它具备以下几个方面的作用:

  • 数据访问封装:UserRepo 把对 User 模型的数据库操作封装起来,像查询、插入、更新和删除等操作。这样,业务逻辑层就能通过调用 UserRepo 的方法来完成数据操作,而无需关心具体的数据库交互细节。
  • 数据持久化与检索:UserRepo 提供了一系列方法,用于从数据库中获取用户信息、统计用户数量、创建新用户、更新用户信息以及删除用户等操作。这些方法负责将业务逻辑层的请求转化为数据库查询或更新操作。
  • 上下文管理:UserRepo 借助 context.Context 来管理请求的上下文,保证数据库操作能在请求的生命周期内完成,并且可以处理超时和取消操作。

简单来说,repository 是一种设计模式,它把数据访问逻辑封装起来,让业务逻辑层和数据访问层解耦,从而提高代码的可维护性和可测试性。UserRepo 就是一个实现了该模式的具体类,它提供了对 User 模型的各种数据操作方法。

自动生成repository文件

在之前的这篇文章《go-zero框架基本配置和错误码封装》中,可以使用GEN 自动生成 GORM 模型结构体文件和查询方法。接下来,在此基础上,我们来通过一个脚本可以自动生成所需数据表的repository文件。

新增一个 generate_repo_files.go文件,写入如下脚本:

package main

const repoDir = "./internal/repo/mysql"

// toCamelCase 将下划线分隔的字符串转换为驼峰命名
func toCamelCase(s string) string {
	parts := strings.Split(s, "_")
	var result strings.Builder
	for _, part := range parts {
		if len(part) > 0 {
			result.WriteString(strings.ToUpper(part[:1]) + part[1:])
		}
	}
	return result.String()
}

func generateRepoFile(baseName string) error {
	// 生成文件名
	fileName := strings.ToLower(baseName) + "_repo.go"
	filePath := filepath.Join(repoDir, fileName)

	// 检查文件是否存在
	if _, err := os.Stat(filePath); err == nil {
		fmt.Printf("File %s already exists, skipping generation.\n", filePath)
		return nil
	}

	// 创建文件
	file, err := os.Create(filePath)
	if err != nil {
		return fmt.Errorf("failed to create file %s: %v", filePath, err)
	}
	defer func(file *os.File) {
		err := file.Close()
		if err != nil {
			fmt.Printf("Failed to close file %s: %v\n", filePath, err)
		}
	}(file)

	// 生成 repo 名称, 将数据表的下划线转为驼峰写法
	repoName := toCamelCase(baseName)

	// 使用模板生成文件内容
	// 读取模板文件内容
	templateFilePath := "template_files/repo_template.txt"
	_, fileDir, _, _ := runtime.Caller(0)
	templateData, err := os.ReadFile(filepath.Join(fileDir, "../", templateFilePath))
	if err != nil {
		log.Fatalf("repo模板文件读取失败: %s", err)
	}

	// 将字节切片转换为字符串并打印
	repoTemplate := string(templateData)

	tmpl, err := template.New("repo").Parse(repoTemplate)
	if err != nil {
		return fmt.Errorf("failed to parse template: %v", err)
	}

	data := struct {
		RepoName string
	}{
		RepoName: repoName,
	}

	if err := tmpl.Execute(file, data); err != nil {
		return fmt.Errorf("failed to execute template: %v", err)
	}

	fmt.Printf("Successfully generated file %s\n", filePath)
	return nil
}

func main() {
	// 初始化go-zero的配置
	var c config.Config
	configFile := config.GetConfigFile("") //调用自定义的GetConfigFile方法,读取当前配置的env信息
	conf.MustLoad(configFile, &c)

	tablePrefix := "" // 表名前缀

	// 连接数据库
	db, _ := gorm.Open(mysql.Open(c.Mysql.DataSource), &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			TablePrefix:   tablePrefix, // 表名前缀
			SingularTable: true,        // 使用单数表名,启用该选项,会区分 user 和 users 表为两个不同的数据表
		},
	})

	//获取全部表名
	var tableNames []string
	db.Raw("SHOW TABLES").Scan(&tableNames)

	for _, tableName := range tableNames {
		// 生成 repo 文件
		baseName := strings.TrimPrefix(tableName, tablePrefix)
		if err := generateRepoFile(baseName); err != nil {
			fmt.Printf("Error generating repo file for table %s: %v\n", tableName, err)
		}
		if err := generateRepoTestFile(baseName); err != nil {
			fmt.Printf("Error generating repo test file for table %s: %v\n", tableName, err)
		}
	}
}

以上代码是通过读取模板文件template_files/repo_template.txt 将相关的变量替换为指定的数据表名,然后生成统一的增删改查方法。

repo模板文件

接下来定义这个模板文件template_files/repo_template.txt

package mysql

type {{.RepoName}}Repo struct {
	logx.Logger
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

func New{{.RepoName}}Repo(ctx context.Context, sCtx *svc.ServiceContext) *{{.RepoName}}Repo {
	return &{{.RepoName}}Repo{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: sCtx,
	}
}

// 查询数据详情
func (r *{{.RepoName}}Repo) Get{{.RepoName}}Info(id int64) (*model.{{.RepoName}}, error) {
	m := r.svcCtx.Model.{{.RepoName}}
	res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Where(m.IsDeleted.Eq(0)).First()
	if err != nil {
		logx.WithContext(r.ctx).Errorf("Failed to get data by id: %v", err)
		return nil, err
	}
	return res, nil
}

// 查询数据总数
func (r *{{.RepoName}}Repo) Get{{.RepoName}}Count(queryEqualConditions map[string]interface{}, queryMoreConditions map[string]interface{}) (int64, error) {
	m := r.svcCtx.Model.{{.RepoName}}
	queryEqualConditions["is_deleted"] = 0
	querys := m.WithContext(r.ctx).Debug()

	//基础查询条件
	querys = querys.Where(field.Attrs(queryEqualConditions))
	//更多查询条件
	if len(queryMoreConditions) > 0 {
		//todo your logic
	}
	return querys.Count()
}

// 查询数据列表
func (r *{{.RepoName}}Repo) Get{{.RepoName}}ListByConditions(queryEqualConditions map[string]interface{}, queryMoreConditions map[string]interface{}, page, pageSize int) ([]*model.{{.RepoName}}, int64, error) {
	offset := (page - 1) * pageSize
	limit := pageSize
	m := r.svcCtx.Model.{{.RepoName}}
	queryEqualConditions["is_deleted"] = 0
	querys := m.WithContext(r.ctx).Debug()

	//基础查询条件
	querys = querys.Where(field.Attrs(queryEqualConditions))
	//更多查询条件
	if len(queryMoreConditions) > 0 {
		//todo your logic
	}
	return querys.Order(m.ID.Desc()).FindByPage(offset, limit)
}

// 查询一批id和创建时间的map
func (r *{{.RepoName}}Repo) Get{{.RepoName}}CreateTimeMapByIds(Ids []int64) (map[int64]string, error) {
	if len(Ids) == 0 {
		return nil, nil
	}
	Ids = utils.RemoveDuplicatesInt64Slice(Ids)

	m := r.svcCtx.Model.{{.RepoName}}
	res, err := m.WithContext(r.ctx).Debug().Where(
		m.ID.In(Ids...),
		m.IsDeleted.Eq(0),
	).Find()
	if err != nil {
		logx.WithContext(r.ctx).Errorf("Failed to get create_time by ids: %v", err)
		return nil, err
	}

	mapData := make(map[int64]string, len(res))
	for _, v := range res {
		mapData[v.ID] = (v.CreateTime).Format("2006-01-02 15:04:05")
	}
	return mapData, nil
}

// 创建数据
func (r *{{.RepoName}}Repo) Create{{.RepoName}}(models *model.{{.RepoName}}) (*model.{{.RepoName}}, error) {
	m := r.svcCtx.Model.{{.RepoName}}
	err := m.WithContext(r.ctx).Debug().Create(models)
	if err != nil {
		logx.WithContext(r.ctx).Errorf("Failed to insert data: %v", err)
		return nil, err
	}
	return models, nil
}

// 更新数据
func (r *{{.RepoName}}Repo) Update{{.RepoName}}(models *model.{{.RepoName}}, id int64) (gen.ResultInfo, error) {
	m := r.svcCtx.Model.{{.RepoName}}
	res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Updates(models)
	if err != nil {
		logx.WithContext(r.ctx).Errorf("Failed to update data: %v", err)
		return res, err
	}
	return res, nil
}

// (软)删除数据
func (r *{{.RepoName}}Repo) Delete{{.RepoName}}(id int64) (gen.ResultInfo, error) {
	m := r.svcCtx.Model.{{.RepoName}}
	res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Updates(&model.{{.RepoName}}{
		IsDeleted: 1,
	})
	if err != nil {
		logx.WithContext(r.ctx).Errorf("Failed to soft delete data: %v", err)
		return res, err
	}
	return res, nil
}

/*
// (硬)删除数据
func (r *{{.RepoName}}Repo) ForeverDelete{{.RepoName}}(id int64) (gen.ResultInfo, error) {
	m := r.svcCtx.Model.{{.RepoName}}
	res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Delete()
	if err != nil {
		logx.WithContext(r.ctx).Errorf("Failed to hard delete data: %v", err)
		return res, err
	}
	return res, nil
}
*/

这个模板文件可以通过提前手动写好一个示例的go文件的常用的方法,然后保存为一个txt文件,把示例的数据表名改为{{.RepoName}} 用来替换其他的表名。

repo_test模板文件

同样的,可以提前写好一份针对以上repository各个方法的测试用例,然后替换为变量标识符,可以写出如下的测试用例的模板文件:template_files/repo_test_template.txt

package mysql_test

func TestGet{{.RepoName}}Info(t *testing.T) {
	ts := utils.NewTestCtx()
	repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
	info, err := repo.Get{{.RepoName}}Info(1)
	t.Log("info: ", utils.EchoJson(info))
	t.Log("error: ", err)
	assert.NoError(t, err)
}

func TestGet{{.RepoName}}Count(t *testing.T) {
	ts := utils.NewTestCtx()
	repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
	condsEq := make(map[string]interface{})
	condsMore := make(map[string]interface{})
	count, err := repo.Get{{.RepoName}}Count(condsEq, condsMore)
	t.Log("count: ", count)
	t.Log("error: ", err)
	assert.NoError(t, err)
}

func TestGet{{.RepoName}}ListByConditions(t *testing.T) {
	ts := utils.NewTestCtx()
	repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
	condsEq := make(map[string]interface{})
	condsMore := make(map[string]interface{})
	list, count, err := repo.Get{{.RepoName}}ListByConditions(condsEq, condsMore, 1, 1)
	t.Log("list: ", utils.EchoJson(list))
	t.Log("count: ", count)
	t.Log("error: ", err)
	assert.NoError(t, err)
}

func TestGet{{.RepoName}}CreateTimeMapByIds(t *testing.T) {
	ts := utils.NewTestCtx()
	repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
	ids := []int64{1, 2, 3}
	createTimeMap, err := repo.Get{{.RepoName}}CreateTimeMapByIds(ids)
	t.Log("createTimeMap: ", utils.EchoJson(createTimeMap))
	t.Log("error: ", err)
	assert.NoError(t, err)
}

func TestCreate{{.RepoName}}(t *testing.T) {
	ts := utils.NewTestCtx()
	repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
	resp, err := repo.Create{{.RepoName}}(&model.{{.RepoName}}{
		//todo 补充需要添加的字段
	})
	t.Log("resp: ", utils.EchoJson(resp))
	t.Log("error: ", err)
	assert.NoError(t, err)
}

func TestUpdate{{.RepoName}}(t *testing.T) {
	ts := utils.NewTestCtx()
	repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
	resp, err := repo.Update{{.RepoName}}(&model.{{.RepoName}}{
		//todo 补充需要添加的字段
	}, 1)
	t.Log("resp: ", utils.EchoJson(resp))
	t.Log("error: ", err)
	assert.NoError(t, err)
}

func TestDelete{{.RepoName}}(t *testing.T) {
	ts := utils.NewTestCtx()
	repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
	resp, err := repo.Delete{{.RepoName}}(0)
	t.Log("resp: ", utils.EchoJson(resp))
	t.Log("error: ", err)
	assert.NoError(t, err)
}

以上只是我根据自己的业务习惯整理出来的几个常用的查询数据库的方法,你也可以适当的调整增加更多的方法,然后封装到模板文件中。

生成结果

运行generate_repo_files.go 这个脚本,就可以生成设置的所有数据表的repository文件和测试用例。运行试试:

image-20250228113143813

运行测试用例

上面生成的测试用例,可以根据需要直接调试对应的repo文件。例如需要调试user_repo_test.go里面的方法,可以分别使用如下两种方式运行:

// 测试全部方法: go test user_repo_test.go -v
// 测试指定方法: go test user_repo_test.go -v -run TestGetUserInfo

image-20250409171935463

相关文章:

  • 无人机击落技术难点与要点分析!
  • 探索 OpenHarmony 开源硬件的学习路径:从入门到实战的全攻略
  • 14. git clone
  • MySQL 架构设计:数据库的“城市规划指南“
  • ubuntu18.04安装miniforge3
  • 基于Python的网络爬虫技术研究
  • OpenBayes 一周速览|1分钟生成完整音乐,DiffRhythm人声伴奏一键搞定; Stable Virtual Camera重塑3D视频创作
  • 按键消抖(用状态机实现)
  • Elasticsearch 学习规划
  • 技术优化实战解析:Stream重构与STAR法则应用指南
  • 16. git push
  • [ctfshow web入门] web33
  • Manifold-IJ 2022.1.21 版本解析:IntelliJ IDEA 的 Java 增强插件指南
  • QEMU源码全解析 —— 块设备虚拟化(17)
  • Redis - 字典(Hash)结构和 rehash 机制
  • Java NIO之Buffer
  • [wifi SAE]wpa3-personal
  • [raspberrypi 0w and respeaker 2mic]实时音频波形
  • UE5 运行时动态将玩家手部模型设置为相机的子物体
  • HTML视频和音频
  • 世界工厂网靠谱吗/优化大师手机版下载
  • app建设网站公司/百度广告怎么做
  • 帮别人做设计图的网站/2021年新闻摘抄
  • 做网站有个名字叫小廖/我想学做互联网怎么入手
  • 华泰保险公司官方网站电话/推广注册app赚钱平台
  • 网站要精细是什么意思/seo和sem