go-zero自动生成repository文件和测试用例
文章目录
- repository的作用
- 自动生成repository文件
- repo模板文件
- repo_test模板文件
- 生成结果
- 运行测试用例
repository的作用
在软件开发中,尤其是在采用分层架构或者领域驱动设计(DDD)的项目里,repository(仓库)是一个关键概念,它起到数据访问层和业务逻辑层之间的桥梁作用,负责处理数据的持久化与检索,让业务逻辑层无需直接与数据库或其他数据存储交互。例如下面代码中的 UserRepo
中的几个方法示例,展示了它是如何进行数据操作的:
// 获取用户信息
func (r *UserRepo) GetUserInfo(id int64) (*model.User, error) {
m := r.svcCtx.Model.User
res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Where(m.IsDeleted.Eq(0)).First()
if err != nil {
logx.WithContext(r.ctx).Errorf("Failed to get data by id: %v", err)
return nil, err
}
return res, nil
}
// 创建用户
func (r *UserRepo) CreateUser(models *model.User) (*model.User, error) {
m := r.svcCtx.Model.User
err := m.WithContext(r.ctx).Debug().Create(models)
if err != nil {
logx.WithContext(r.ctx).Errorf("Failed to insert data: %v", err)
return nil, err
}
return models, nil
}
以上代码中, UserRepo
就是一个典型的仓库(repository)示例。它具备以下几个方面的作用:
- 数据访问封装:UserRepo 把对 User 模型的数据库操作封装起来,像查询、插入、更新和删除等操作。这样,业务逻辑层就能通过调用 UserRepo 的方法来完成数据操作,而无需关心具体的数据库交互细节。
- 数据持久化与检索:UserRepo 提供了一系列方法,用于从数据库中获取用户信息、统计用户数量、创建新用户、更新用户信息以及删除用户等操作。这些方法负责将业务逻辑层的请求转化为数据库查询或更新操作。
- 上下文管理:UserRepo 借助 context.Context 来管理请求的上下文,保证数据库操作能在请求的生命周期内完成,并且可以处理超时和取消操作。
简单来说,repository 是一种设计模式,它把数据访问逻辑封装起来,让业务逻辑层和数据访问层解耦,从而提高代码的可维护性和可测试性。UserRepo 就是一个实现了该模式的具体类,它提供了对 User 模型的各种数据操作方法。
自动生成repository文件
在之前的这篇文章《go-zero框架基本配置和错误码封装》中,可以使用GEN 自动生成 GORM 模型结构体文件和查询方法。接下来,在此基础上,我们来通过一个脚本可以自动生成所需数据表的repository文件。
新增一个 generate_repo_files.go
文件,写入如下脚本:
package main
const repoDir = "./internal/repo/mysql"
// toCamelCase 将下划线分隔的字符串转换为驼峰命名
func toCamelCase(s string) string {
parts := strings.Split(s, "_")
var result strings.Builder
for _, part := range parts {
if len(part) > 0 {
result.WriteString(strings.ToUpper(part[:1]) + part[1:])
}
}
return result.String()
}
func generateRepoFile(baseName string) error {
// 生成文件名
fileName := strings.ToLower(baseName) + "_repo.go"
filePath := filepath.Join(repoDir, fileName)
// 检查文件是否存在
if _, err := os.Stat(filePath); err == nil {
fmt.Printf("File %s already exists, skipping generation.\n", filePath)
return nil
}
// 创建文件
file, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("failed to create file %s: %v", filePath, err)
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
fmt.Printf("Failed to close file %s: %v\n", filePath, err)
}
}(file)
// 生成 repo 名称, 将数据表的下划线转为驼峰写法
repoName := toCamelCase(baseName)
// 使用模板生成文件内容
// 读取模板文件内容
templateFilePath := "template_files/repo_template.txt"
_, fileDir, _, _ := runtime.Caller(0)
templateData, err := os.ReadFile(filepath.Join(fileDir, "../", templateFilePath))
if err != nil {
log.Fatalf("repo模板文件读取失败: %s", err)
}
// 将字节切片转换为字符串并打印
repoTemplate := string(templateData)
tmpl, err := template.New("repo").Parse(repoTemplate)
if err != nil {
return fmt.Errorf("failed to parse template: %v", err)
}
data := struct {
RepoName string
}{
RepoName: repoName,
}
if err := tmpl.Execute(file, data); err != nil {
return fmt.Errorf("failed to execute template: %v", err)
}
fmt.Printf("Successfully generated file %s\n", filePath)
return nil
}
func main() {
// 初始化go-zero的配置
var c config.Config
configFile := config.GetConfigFile("") //调用自定义的GetConfigFile方法,读取当前配置的env信息
conf.MustLoad(configFile, &c)
tablePrefix := "" // 表名前缀
// 连接数据库
db, _ := gorm.Open(mysql.Open(c.Mysql.DataSource), &gorm.Config{
NamingStrategy: schema.NamingStrategy{
TablePrefix: tablePrefix, // 表名前缀
SingularTable: true, // 使用单数表名,启用该选项,会区分 user 和 users 表为两个不同的数据表
},
})
//获取全部表名
var tableNames []string
db.Raw("SHOW TABLES").Scan(&tableNames)
for _, tableName := range tableNames {
// 生成 repo 文件
baseName := strings.TrimPrefix(tableName, tablePrefix)
if err := generateRepoFile(baseName); err != nil {
fmt.Printf("Error generating repo file for table %s: %v\n", tableName, err)
}
if err := generateRepoTestFile(baseName); err != nil {
fmt.Printf("Error generating repo test file for table %s: %v\n", tableName, err)
}
}
}
以上代码是通过读取模板文件template_files/repo_template.txt
将相关的变量替换为指定的数据表名,然后生成统一的增删改查方法。
repo模板文件
接下来定义这个模板文件template_files/repo_template.txt
:
package mysql
type {{.RepoName}}Repo struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func New{{.RepoName}}Repo(ctx context.Context, sCtx *svc.ServiceContext) *{{.RepoName}}Repo {
return &{{.RepoName}}Repo{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: sCtx,
}
}
// 查询数据详情
func (r *{{.RepoName}}Repo) Get{{.RepoName}}Info(id int64) (*model.{{.RepoName}}, error) {
m := r.svcCtx.Model.{{.RepoName}}
res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Where(m.IsDeleted.Eq(0)).First()
if err != nil {
logx.WithContext(r.ctx).Errorf("Failed to get data by id: %v", err)
return nil, err
}
return res, nil
}
// 查询数据总数
func (r *{{.RepoName}}Repo) Get{{.RepoName}}Count(queryEqualConditions map[string]interface{}, queryMoreConditions map[string]interface{}) (int64, error) {
m := r.svcCtx.Model.{{.RepoName}}
queryEqualConditions["is_deleted"] = 0
querys := m.WithContext(r.ctx).Debug()
//基础查询条件
querys = querys.Where(field.Attrs(queryEqualConditions))
//更多查询条件
if len(queryMoreConditions) > 0 {
//todo your logic
}
return querys.Count()
}
// 查询数据列表
func (r *{{.RepoName}}Repo) Get{{.RepoName}}ListByConditions(queryEqualConditions map[string]interface{}, queryMoreConditions map[string]interface{}, page, pageSize int) ([]*model.{{.RepoName}}, int64, error) {
offset := (page - 1) * pageSize
limit := pageSize
m := r.svcCtx.Model.{{.RepoName}}
queryEqualConditions["is_deleted"] = 0
querys := m.WithContext(r.ctx).Debug()
//基础查询条件
querys = querys.Where(field.Attrs(queryEqualConditions))
//更多查询条件
if len(queryMoreConditions) > 0 {
//todo your logic
}
return querys.Order(m.ID.Desc()).FindByPage(offset, limit)
}
// 查询一批id和创建时间的map
func (r *{{.RepoName}}Repo) Get{{.RepoName}}CreateTimeMapByIds(Ids []int64) (map[int64]string, error) {
if len(Ids) == 0 {
return nil, nil
}
Ids = utils.RemoveDuplicatesInt64Slice(Ids)
m := r.svcCtx.Model.{{.RepoName}}
res, err := m.WithContext(r.ctx).Debug().Where(
m.ID.In(Ids...),
m.IsDeleted.Eq(0),
).Find()
if err != nil {
logx.WithContext(r.ctx).Errorf("Failed to get create_time by ids: %v", err)
return nil, err
}
mapData := make(map[int64]string, len(res))
for _, v := range res {
mapData[v.ID] = (v.CreateTime).Format("2006-01-02 15:04:05")
}
return mapData, nil
}
// 创建数据
func (r *{{.RepoName}}Repo) Create{{.RepoName}}(models *model.{{.RepoName}}) (*model.{{.RepoName}}, error) {
m := r.svcCtx.Model.{{.RepoName}}
err := m.WithContext(r.ctx).Debug().Create(models)
if err != nil {
logx.WithContext(r.ctx).Errorf("Failed to insert data: %v", err)
return nil, err
}
return models, nil
}
// 更新数据
func (r *{{.RepoName}}Repo) Update{{.RepoName}}(models *model.{{.RepoName}}, id int64) (gen.ResultInfo, error) {
m := r.svcCtx.Model.{{.RepoName}}
res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Updates(models)
if err != nil {
logx.WithContext(r.ctx).Errorf("Failed to update data: %v", err)
return res, err
}
return res, nil
}
// (软)删除数据
func (r *{{.RepoName}}Repo) Delete{{.RepoName}}(id int64) (gen.ResultInfo, error) {
m := r.svcCtx.Model.{{.RepoName}}
res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Updates(&model.{{.RepoName}}{
IsDeleted: 1,
})
if err != nil {
logx.WithContext(r.ctx).Errorf("Failed to soft delete data: %v", err)
return res, err
}
return res, nil
}
/*
// (硬)删除数据
func (r *{{.RepoName}}Repo) ForeverDelete{{.RepoName}}(id int64) (gen.ResultInfo, error) {
m := r.svcCtx.Model.{{.RepoName}}
res, err := m.WithContext(r.ctx).Debug().Where(m.ID.Eq(id)).Delete()
if err != nil {
logx.WithContext(r.ctx).Errorf("Failed to hard delete data: %v", err)
return res, err
}
return res, nil
}
*/
这个模板文件可以通过提前手动写好一个示例的go文件的常用的方法,然后保存为一个txt文件,把示例的数据表名改为{{.RepoName}}
用来替换其他的表名。
repo_test模板文件
同样的,可以提前写好一份针对以上repository各个方法的测试用例,然后替换为变量标识符,可以写出如下的测试用例的模板文件:template_files/repo_test_template.txt
package mysql_test
func TestGet{{.RepoName}}Info(t *testing.T) {
ts := utils.NewTestCtx()
repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
info, err := repo.Get{{.RepoName}}Info(1)
t.Log("info: ", utils.EchoJson(info))
t.Log("error: ", err)
assert.NoError(t, err)
}
func TestGet{{.RepoName}}Count(t *testing.T) {
ts := utils.NewTestCtx()
repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
condsEq := make(map[string]interface{})
condsMore := make(map[string]interface{})
count, err := repo.Get{{.RepoName}}Count(condsEq, condsMore)
t.Log("count: ", count)
t.Log("error: ", err)
assert.NoError(t, err)
}
func TestGet{{.RepoName}}ListByConditions(t *testing.T) {
ts := utils.NewTestCtx()
repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
condsEq := make(map[string]interface{})
condsMore := make(map[string]interface{})
list, count, err := repo.Get{{.RepoName}}ListByConditions(condsEq, condsMore, 1, 1)
t.Log("list: ", utils.EchoJson(list))
t.Log("count: ", count)
t.Log("error: ", err)
assert.NoError(t, err)
}
func TestGet{{.RepoName}}CreateTimeMapByIds(t *testing.T) {
ts := utils.NewTestCtx()
repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
ids := []int64{1, 2, 3}
createTimeMap, err := repo.Get{{.RepoName}}CreateTimeMapByIds(ids)
t.Log("createTimeMap: ", utils.EchoJson(createTimeMap))
t.Log("error: ", err)
assert.NoError(t, err)
}
func TestCreate{{.RepoName}}(t *testing.T) {
ts := utils.NewTestCtx()
repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
resp, err := repo.Create{{.RepoName}}(&model.{{.RepoName}}{
//todo 补充需要添加的字段
})
t.Log("resp: ", utils.EchoJson(resp))
t.Log("error: ", err)
assert.NoError(t, err)
}
func TestUpdate{{.RepoName}}(t *testing.T) {
ts := utils.NewTestCtx()
repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
resp, err := repo.Update{{.RepoName}}(&model.{{.RepoName}}{
//todo 补充需要添加的字段
}, 1)
t.Log("resp: ", utils.EchoJson(resp))
t.Log("error: ", err)
assert.NoError(t, err)
}
func TestDelete{{.RepoName}}(t *testing.T) {
ts := utils.NewTestCtx()
repo := mysql.New{{.RepoName}}Repo(ts.Ctx, ts.SvcCtx)
resp, err := repo.Delete{{.RepoName}}(0)
t.Log("resp: ", utils.EchoJson(resp))
t.Log("error: ", err)
assert.NoError(t, err)
}
以上只是我根据自己的业务习惯整理出来的几个常用的查询数据库的方法,你也可以适当的调整增加更多的方法,然后封装到模板文件中。
生成结果
运行generate_repo_files.go
这个脚本,就可以生成设置的所有数据表的repository文件和测试用例。运行试试:
运行测试用例
上面生成的测试用例,可以根据需要直接调试对应的repo文件。例如需要调试user_repo_test.go
里面的方法,可以分别使用如下两种方式运行:
// 测试全部方法: go test user_repo_test.go -v
// 测试指定方法: go test user_repo_test.go -v -run TestGetUserInfo