4. 链式操作
...大约 3 分钟
1. 实现 Update、Delete和Count
1.1 子句生成
clause/clause.go, generator.go
新增更新,删除和计数功能。
func _update(vals ...any) (string, []any) {
tableName := vals[0]
m := vals[1].(map[string]any)
var (
keys []string
vars []any
)
for k, v := range m {
keys = append(keys, k+" = ?")
vars = append(vars, v)
}
return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars
}
func _delete(vals ...any) (string, []any) {
return fmt.Sprintf("DELETE FROM %s", vals[0]), []any{}
}
func _count(vals ...any) (string, []any) {
return _select(vals[0], []string{"count(*)"})
}
_update
:入参为表名和map
,map
存储需要修改的数据_delete
:入参为表名_count
:复用_select
1.2 Update
session/record
:
// Update updates records of table
// support map[string]any
// kv list: string, any, string, any, ...
func (s *Session) Update(kvs ...any) (int64, error) {
m, ok := kvs[0].(map[string]any)
if !ok {
m = make(map[string]any)
for i := 0; i < len(kvs); i += 2 {
m[kvs[i].(string)] = kvs[i+1]
}
}
s.clause.Set(clause.UPDATE, s.RefTable().Name, m)
sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE)
res, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return res.RowsAffected()
}
支持map
类型或者键值对列表,流程:
- 从入参中获取map,若失败则将键值对列表转换成 map
- 构建 UPDATE 语句,并执行
- 返回执行结果
1.3 Delete
func (s *Session) Delete() (int64, error) {
s.clause.Set(clause.DELETE, s.RefTable().Name)
sql, vars := s.clause.Build(clause.DELETE, clause.WHERE)
res, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return res.RowsAffected()
}
1.4 Count
func (s *Session) Count() (int64, error) {
s.clause.Set(clause.COUNT, s.RefTable().Name)
sql, vars := s.clause.Build(clause.COUNT, clause.WHERE)
row := s.Raw(sql, vars...).QueryRow()
var tmp int64
if err := row.Scan(&tmp); err != nil {
return 0, err
}
return tmp, nil
}
2. 链式调用(chain)
链式调用:某个对象调用某个方法后,将该对象的引用/指针返回,即可以继续调用该对象的其他方法,是一种简化代码的编程方式,能够使代码更简洁、易读。
应用场景:当某个对象需要一次调用多个方法来设置其属性时,非常适合改造为链式调用。
SQL 语句的构建过程包含多个步骤,并且可以进行自由组合,适合采用链式调用。
session/record.go
新增 Limit, Where, Order By :
func (s *Session) Limit(num int) *Session {
s.clause.Set(clause.LIMIT, num)
return s
}
func (s *Session) Where(desc string, args ...any) *Session {
var vars []any
s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...)
return s
}
func (s *Session) OrderBy(desc string) *Session {
s.clause.Set(clause.ORDERBY, desc)
return s
}
3. First 只返回一条记录
func (s *Session) First(val any) error {
dst := reflect.Indirect(reflect.ValueOf(val))
dstSlice := reflect.New(reflect.SliceOf(dst.Type())).Elem()
if err := s.Limit(1).Find(dstSlice.Addr().Interface()); err != nil {
return err
}
if dstSlice.Len() == 0 {
return errors.New("NOT FOUND")
}
dst.Set(dstSlice.Index(0))
return nil
}
- 获取传入对象的类型
- 构建切片
- 将查询结果限制为1条,并将结果写入到切片中
- 将切片中的数据拷贝到原对象中
4. 单元测试
session/record_test.go
func TestSession_Limit(t *testing.T) {
s := testInit(t)
var users []User
err := s.Limit(1).Find(&users)
if err != nil || len(users) != 1 {
t.Fatal("failed to query with limit")
}
}
func TestSession_Update(t *testing.T) {
s := testInit(t)
affected, _ := s.Where("Name = ?", "A").Update("Age", 10)
u := &User{}
_ = s.OrderBy("Age DESC").First(u)
if affected != 1 || u.Age != 10 {
t.Fatal("failed to update")
}
}
func TestSession_Delete(t *testing.T) {
s := testInit(t)
affected, _ := s.Where("Name = ?", "A").Delete()
if affected != 1 {
t.Fatal("failed to delete")
}
}
func TestSession_Count(t *testing.T) {
s := testInit(t)
c, _ := s.Where("Name = ?", "A").Count()
if c != 1 {
t.Fatal("failed to count")
}
}
Reference
Powered by Waline v2.15.2