3. 新增和查询
...大约 3 分钟
1. Clause 构造 SQL
查询语句一般由多个子句(clause)组成,例如:
SELECT col1, col2, ...
FROM table_name
WHERE [condition]
GROUP BY col1
HAVING [condition]
一次性构造出完整的SQL语句较为复杂,可以将构造SQL的部分独立出来放在clause
包中。
clause/generator.go
定义SQL子句的生成规则。
package clause
import (
"fmt"
"strings"
)
// generator renders one SQL clause from its arguments and returns the clause
// text together with the values to bind to its placeholders.
type generator func(vals ...any) (string, []any)
// generators maps each clause Type to the generator that renders it.
var generators map[Type]generator
func init() {
generators = make(map[Type]generator)
generators[INSERT] = _insert
generators[VALUES] = _values
generators[SELECT] = _select
generators[LIMIT] = _limit
generators[WHERE] = _where
generators[ORDERBY] = _orderBy
}
// genBindVars returns num "?" placeholders separated by ", ".
// It returns the empty string when num is zero or negative.
func genBindVars(num int) string {
	if num <= 0 {
		return ""
	}
	// num-1 separated placeholders followed by the final one.
	return strings.Repeat("?, ", num-1) + "?"
}
// _insert renders "INSERT INTO $tableName ($fields)".
// vals[0] is the table name and vals[1] is a []string of column names, which
// are joined with "," (no space). The clause carries no bind values.
func _insert(vals ...any) (string, []any) {
	table, columns := vals[0], vals[1].([]string)
	stmt := fmt.Sprintf("INSERT INTO %s (%v)", table, strings.Join(columns, ","))
	return stmt, []any{}
}
// _values renders "VALUES (?, ...), (?, ...) ..." for the given rows.
// Each element of vals must be a []any holding one row's values. The
// placeholder group is generated from the first row with a non-empty
// placeholder string and reused for every following row. The returned bind
// values are all rows flattened in order.
func _values(vals ...any) (string, []any) {
	var (
		placeholders string
		stmt         strings.Builder
		args         []any
	)
	stmt.WriteString("VALUES ")
	for idx := range vals {
		row := vals[idx].([]any)
		if placeholders == "" {
			placeholders = genBindVars(len(row))
		}
		// Separate groups with ", " (written before every group but the first).
		if idx > 0 {
			stmt.WriteString(", ")
		}
		fmt.Fprintf(&stmt, "(%v)", placeholders)
		args = append(args, row...)
	}
	return stmt.String(), args
}
// _select renders "SELECT $fields FROM $tableName".
// vals[0] is the table name and vals[1] is a []string of column names, which
// are joined with "," (no space). The clause carries no bind values.
func _select(vals ...any) (string, []any) {
	table, columns := vals[0], vals[1].([]string)
	stmt := fmt.Sprintf("SELECT %v FROM %s", strings.Join(columns, ","), table)
	return stmt, []any{}
}
// _limit renders "LIMIT ?" and passes every argument through unchanged as the
// clause's bind values.
func _limit(vals ...any) (string, []any) {
	const stmt = "LIMIT ?"
	return stmt, vals
}
// _where renders "WHERE $desc".
// vals[0] is the condition text; any remaining arguments are returned as the
// bind values for the condition's placeholders.
func _where(vals ...any) (string, []any) {
	condition := vals[0]
	args := vals[1:]
	stmt := fmt.Sprintf("WHERE %s", condition)
	return stmt, args
}
// _orderBy renders "ORDER BY $expr" from vals[0]. The clause carries no bind
// values.
func _orderBy(vals ...any) (string, []any) {
	ordering := vals[0]
	stmt := fmt.Sprintf("ORDER BY %s", ordering)
	return stmt, []any{}
}
clause/clause.go
中拼接各子句。
package clause
import "strings"
// Clause stores the rendered SQL text and bind values of individual clauses,
// keyed by clause Type, so they can later be combined in a chosen order.
type Clause struct {
// sql holds the rendered text of each clause that has been set.
sql map[Type]string
// sqlVars holds the bind values belonging to each clause that has been set.
sqlVars map[Type][]any
}
// Type identifies a kind of SQL clause.
type Type int
// Supported clause types, numbered consecutively from zero via iota.
const (
INSERT Type = iota
VALUES
SELECT
LIMIT
WHERE
ORDERBY
)
// Set renders the clause identified by name from vars using the registered
// generator and stores the resulting SQL text and bind values, replacing any
// clause of the same type set earlier.
func (c *Clause) Set(name Type, vars ...any) {
	// Lazily allocate the maps so the zero Clause is usable.
	if c.sql == nil {
		c.sql = make(map[Type]string)
		c.sqlVars = make(map[Type][]any)
	}
	render := generators[name]
	c.sql[name], c.sqlVars[name] = render(vars...)
}
// Build joins the stored clauses in the order given by orders, separated by a
// single space, and returns the SQL text with the concatenated bind values.
// Clause types that have not been set are skipped.
func (c *Clause) Build(orders ...Type) (string, []any) {
	var (
		parts []string
		args  []any
	)
	for _, kind := range orders {
		part, found := c.sql[kind]
		if !found {
			continue
		}
		parts = append(parts, part)
		args = append(args, c.sqlVars[kind]...)
	}
	return strings.Join(parts, " "), args
}
Set:根据类型生成子句
Build:根据子句的顺序,构造出完整的 SQL 语句
1.1 单元测试
// testSelect sets the clauses out of order and checks that Build emits them
// in the requested order with the matching bind values.
func testSelect(t *testing.T) {
	const wantSQL = "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?"
	wantVars := []any{"Tom", 3}

	var c Clause
	c.Set(LIMIT, 3)
	c.Set(SELECT, "User", []string{"*"})
	c.Set(WHERE, "Name = ?", "Tom")
	c.Set(ORDERBY, "Age ASC")

	gotSQL, gotVars := c.Build(SELECT, WHERE, ORDERBY, LIMIT)
	t.Log(gotSQL, gotVars)
	if gotSQL != wantSQL {
		t.Fatal("failed to build SQL")
	}
	if !reflect.DeepEqual(gotVars, wantVars) {
		t.Fatal("failed to build SQLVars")
	}
}
// TestClause_Build runs the clause-building checks as named subtests.
func TestClause_Build(t *testing.T) {
	// testSelect already has the subtest signature, so pass it directly.
	t.Run("select", testSelect)
}
2. Insert 实现
Insert 可以将对象的字段值插入到对应的数据表中。
2.1 schema
schema.Schema 负责对象和数据表的映射,新增 RecordValue 方法,用于将对象转换成 INSERT 操作的 VALUES。
// RecordValue returns the values of dst's fields in the order of s.Fields,
// looking each one up by field name. dst may be a struct or a pointer to one;
// a pointer is dereferenced first.
func (s *Schema) RecordValue(dst any) []any {
	record := reflect.Indirect(reflect.ValueOf(dst))
	var values []any
	for i := range s.Fields {
		name := s.Fields[i].Name
		values = append(values, record.FieldByName(name).Interface())
	}
	return values
}
2.2 session
修改session.Session
数据结构
// Session holds the state used to build and run statements against a database.
type Session struct {
// db is the underlying database handle.
db *sql.DB
// dialect is the database dialect in use (NOTE(review): usage not shown here).
dialect dialect.Dialect
// refTable is the schema of the model the session currently refers to
// (presumably set by Model — confirm).
refTable *schema.Schema
// clause accumulates the SQL clauses for the statement being built.
clause clause.Clause
// sql and sqlVars hold the statement text and its bind values; both are
// reset by Clear.
sql strings.Builder
sqlVars []any
}
// Clear resets the session's accumulated clauses, statement text and bind
// values so the session can be reused for the next statement.
func (s *Session) Clear() {
	s.clause = clause.Clause{}
	s.sqlVars = nil
	s.sql.Reset()
}
新建session/record.go
用于实现记录的增删改查
// Insert stores each of vals as one row and returns the number of rows
// affected as reported by the executed statement.
//
// Every value is mapped to its table via Model(val).RefTable(); the INSERT
// clause is (re)set from that table's name and field names, and the value's
// field values become one group in the VALUES clause. Calling Insert with no
// values is a no-op that returns (0, nil).
func (s *Session) Insert(vals ...any) (int64, error) {
	// Without this guard an empty call would build and execute a bare
	// "VALUES " statement with no INSERT clause.
	if len(vals) == 0 {
		return 0, nil
	}

	recordVals := make([]any, 0, len(vals))
	for _, val := range vals {
		table := s.Model(val).RefTable()
		s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
		recordVals = append(recordVals, table.RecordValue(val))
	}
	s.clause.Set(clause.VALUES, recordVals...)

	sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
	res, err := s.Raw(sql, vars...).Exec()
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
3. Find 实现
Find 会将查询结果放入对象切片中,传入的是切片的指针,例如:
s := geeorm.NewEngine("sqlite3", "gee.db").NewSession()
var users []User
s.Find(&users)
3.1 session
session/record.go
// Find runs a SELECT for the element type of the slice that vals points to
// and appends one element per result row to that slice.
//
// vals must be a pointer to a slice of structs (e.g. *[]User). The statement
// is built from the SELECT clause plus any WHERE, ORDER BY and LIMIT clauses
// already set on the session. It returns the first error from the query, from
// scanning a row, from row iteration, or from closing the result set.
func (s *Session) Find(vals any) error {
	// Resolve the destination slice and the table schema of its element type.
	dstSlice := reflect.Indirect(reflect.ValueOf(vals))
	dstType := dstSlice.Type().Elem()
	table := s.Model(reflect.New(dstType).Elem().Interface()).RefTable()

	s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
	query, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
	rows, err := s.Raw(query, vars...).QueryRows()
	if err != nil {
		return err
	}
	// Release the result set on every exit path, including a failed Scan.
	// The explicit Close at the end still reports its error; closing twice is
	// harmless for database/sql rows.
	defer rows.Close()

	for rows.Next() {
		dst := reflect.New(dstType).Elem()
		// Collect a pointer to each mapped field so Scan can fill it in place.
		fieldVals := make([]any, 0, len(table.FieldNames))
		for _, name := range table.FieldNames {
			fieldVals = append(fieldVals, dst.FieldByName(name).Addr().Interface())
		}
		if err := rows.Scan(fieldVals...); err != nil {
			return err
		}
		dstSlice.Set(reflect.Append(dstSlice, dst))
	}
	// Next returning false can mean an iteration error rather than the end of
	// the result set.
	if err := rows.Err(); err != nil {
		return err
	}
	return rows.Close()
}
- 通过反射获取切片指针指向的切片
- 获取切片元素的类型
- 构建切片元素类型的实例,获取表结构
- 构造查询语句并查询
- 遍历查询结果,构建切片元素实例,存储查询结果并追加至切片中
3.2 单元测试
// Fixture records for the session tests: testInit inserts user1 and user2,
// and TestSession_Insert inserts user3.
var (
user1 = &User{"A", 1}
user2 = &User{"B", 2}
user3 = &User{"C", 3}
)
// testInit returns a session bound to the User model whose table has been
// dropped, recreated and seeded with user1 and user2. All three steps are
// attempted before any failure is reported.
func testInit(t *testing.T) *Session {
	t.Helper()
	s := newSession().Model(&User{})

	dropErr := s.DropTable()
	createErr := s.CreateTable()
	_, insertErr := s.Insert(user1, user2)

	for _, err := range []error{dropErr, createErr, insertErr} {
		if err != nil {
			t.Fatal("failed to init test records")
		}
	}
	return s
}
// TestSession_Insert checks that inserting a single record succeeds and
// reports exactly one affected row.
func TestSession_Insert(t *testing.T) {
	s := testInit(t)
	affected, err := s.Insert(user3)
	if err == nil && affected == 1 {
		return
	}
	t.Fatal("failed to create record")
}
// TestSession_Find checks that Find loads the two records seeded by testInit.
func TestSession_Find(t *testing.T) {
	s := testInit(t)
	var users []User
	err := s.Find(&users)
	if err != nil {
		t.Fatal("failed to query records")
	}
	if len(users) != 2 {
		t.Fatal("failed to query records")
	}
}
Reference
Powered by Waline v2.15.2