3. 新增和查询

Kesa...大约 3 分钟golang

day3-insert-query

1. Clause 构造 SQL

查询语句一般由多个子句(clause)组成,例如:

SELECT col1, col2, ...
	FROM table_name
	WHERE [condition]
	GROUP BY col1
	HAVING [condition]

一次性构造出完整的 SQL 语句较为复杂,因此可以把构造 SQL 子句的逻辑独立出来,放在 clause 包中:先分别生成各个子句,再按需要的顺序拼接成完整语句。

clause/generator.go 定义了各类 SQL 子句的生成规则。

package clause

import (
	"fmt"
	"strings"
)

// generator renders one kind of SQL clause from its arguments and returns the
// SQL fragment together with the bind variables for its placeholders.
type generator func(vals ...any) (string, []any)

// generators maps each clause Type to the function that renders it. Clause.Set
// looks generators up here by Type.
var generators = map[Type]generator{
	INSERT:  _insert,
	VALUES:  _values,
	SELECT:  _select,
	LIMIT:   _limit,
	WHERE:   _where,
	ORDERBY: _orderBy,
}

// genBindVars returns num "?" placeholders separated by ", " (for example
// "?, ?, ?"). It returns "" when num is zero or negative.
func genBindVars(num int) string {
	if num <= 0 {
		return ""
	}
	return "?" + strings.Repeat(", ?", num-1)
}

// _insert renders "INSERT INTO $tableName ($fields)".
// vals[0] is the table name and vals[1] is a []string of column names, which
// are joined with ",". The clause has no bind variables of its own.
func _insert(vals ...any) (string, []any) {
	table, columns := vals[0], vals[1].([]string)
	columnList := strings.Join(columns, ",")
	sql := fmt.Sprintf("INSERT INTO %s (%v)", table, columnList)
	return sql, []any{}
}

// _values renders "VALUES (?, ...), (?, ...), ...".
// Each element of vals must be a []any holding one row's bind variables; the
// rows' variables are concatenated, in order, into the returned slice. The
// placeholder group is generated from the first row with a non-zero length and
// reused for every later row.
func _values(vals ...any) (string, []any) {
	var (
		placeholders string
		groups       = make([]string, 0, len(vals))
		args         []any
	)
	for _, raw := range vals {
		row := raw.([]any)
		if placeholders == "" {
			placeholders = genBindVars(len(row))
		}
		groups = append(groups, fmt.Sprintf("(%v)", placeholders))
		args = append(args, row...)
	}
	return "VALUES " + strings.Join(groups, ", "), args
}

// _select renders "SELECT $fields FROM $tableName".
// vals[0] is the table name and vals[1] is a []string of column names, which
// are joined with ",". The clause has no bind variables of its own.
func _select(vals ...any) (string, []any) {
	table := vals[0]
	columnList := strings.Join(vals[1].([]string), ",")
	noArgs := []any{}
	return fmt.Sprintf("SELECT %v FROM %s", columnList, table), noArgs
}

// _limit renders "LIMIT ?". The row count is not inlined into the SQL text;
// the arguments are returned unchanged as the clause's bind variables.
func _limit(vals ...any) (string, []any) {
	const limitSQL = "LIMIT ?"
	return limitSQL, vals
}

// _where renders "WHERE $desc". vals[0] is the condition text and is inlined
// verbatim, so it must not be built from untrusted input. Any remaining values
// are the bind variables for the condition's placeholders.
func _where(vals ...any) (string, []any) {
	condition := vals[0]
	args := vals[1:]
	return fmt.Sprintf("WHERE %s", condition), args
}

// _orderBy renders "ORDER BY $desc". vals[0] is the sort expression and is
// inlined verbatim, so it must not be built from untrusted input. The clause
// has no bind variables.
func _orderBy(vals ...any) (string, []any) {
	sortExpr := vals[0]
	noArgs := []any{}
	return fmt.Sprintf("ORDER BY %s", sortExpr), noArgs
}

clause/clause.go 负责保存生成好的子句,并按指定顺序拼接成完整的 SQL 语句。

package clause

import "strings"

// Type identifies the kind of SQL clause stored in a Clause.
type Type int

// Clause kinds understood by Set and Build. Their values are assigned by iota
// in declaration order.
const (
	INSERT Type = iota
	VALUES
	SELECT
	LIMIT
	WHERE
	ORDERBY
)

// Clause collects rendered SQL clauses keyed by Type, so that Build can later
// join them in whatever order a statement requires. The maps start out nil.
type Clause struct {
	sql     map[Type]string // SQL fragment per clause type
	sqlVars map[Type][]any  // bind variables per clause type
}

// Set renders the clause of kind name from vars, using the generator
// registered for that kind. It stores the resulting SQL fragment and bind
// variables, replacing any clause of the same kind that was set earlier.
//
// Set panics with a descriptive message if no generator is registered for
// name. That is a programming error; without the check the panic would be an
// opaque nil-function dereference.
func (c *Clause) Set(name Type, vars ...any) {
	gen := generators[name]
	if gen == nil {
		panic("clause: no generator registered for clause type")
	}
	if c.sql == nil { // lazy initialisation keeps the zero Clause usable
		c.sql = make(map[Type]string)
		c.sqlVars = make(map[Type][]any)
	}
	sql, sqlVars := gen(vars...)
	c.sql[name] = sql
	c.sqlVars[name] = sqlVars
}

// Build joins the stored clauses in the order given by orders, separated by
// single spaces, and concatenates their bind variables in the same order.
// Kinds that were never Set are skipped. When nothing is emitted, the SQL is
// "" and the variables slice is nil.
func (c *Clause) Build(orders ...Type) (string, []any) {
	parts := make([]string, 0, len(orders))
	var args []any
	for _, kind := range orders {
		fragment, present := c.sql[kind]
		if !present {
			continue
		}
		parts = append(parts, fragment)
		args = append(args, c.sqlVars[kind]...)
	}
	return strings.Join(parts, " "), args
}

  • Set:根据类型生成子句
  • Build:根据子句的顺序,构造出完整的 SQL 语句

1.1 单元测试

// testSelect checks two things. First, Build emits SELECT, WHERE, ORDER BY and
// LIMIT in the requested order, even though they were Set in a different
// order. Second, the bind variables follow that same order.
func testSelect(t *testing.T) {
	var clause Clause
	clause.Set(LIMIT, 3)
	clause.Set(SELECT, "User", []string{"*"})
	clause.Set(WHERE, "Name = ?", "Tom")
	clause.Set(ORDERBY, "Age ASC")

	sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT)
	t.Log(sql, vars)
	const wantSQL = "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?"
	if sql != wantSQL {
		t.Fatalf("failed to build SQL: got %q, want %q", sql, wantSQL)
	}
	wantVars := []any{"Tom", 3}
	if !reflect.DeepEqual(vars, wantVars) {
		t.Fatalf("failed to build SQLVars: got %#v, want %#v", vars, wantVars)
	}
}

// TestClause_Build runs the Clause.Build scenarios as named subtests.
func TestClause_Build(t *testing.T) {
	t.Run("select", testSelect)
}

2. Insert 实现

Insert 可以将对象的字段值插入到对应的数据表中。

2.1 schema

schema.Schema 负责对象和数据表之间的映射。新增 RecordValue 方法,按字段顺序取出对象的字段值,作为 INSERT 语句中 VALUES 子句的参数。

// RecordValue returns the values of dst's fields in s.Fields order. The result
// is suitable as one row of bind variables for an INSERT ... VALUES clause.
// dst may be a struct or a pointer to one; the pointer is dereferenced first.
func (s *Schema) RecordValue(dst any) []any {
	record := reflect.Indirect(reflect.ValueOf(dst))
	var values []any
	for i := range s.Fields {
		column := record.FieldByName(s.Fields[i].Name)
		values = append(values, column.Interface())
	}
	return values
}

2.2 session

修改 session.Session 的数据结构,新增 clause 字段,并在 Clear 中将其一并重置。

// Session holds the state used to compose and run SQL statements against db:
// the raw SQL text and its bind variables, a clause builder, and the schema of
// the model bound to the session.
type Session struct {
	db       *sql.DB         // underlying database handle
	dialect  dialect.Dialect // database dialect; not used in the code shown here
	refTable *schema.Schema  // schema of the bound model; presumably what RefTable returns — confirm
	clause   clause.Clause   // clause builder for the statement being composed; reset by Clear
	sql      strings.Builder // raw SQL text; reset by Clear
	sqlVars  []any           // bind variables for sql; reset by Clear
}

// Clear discards the per-statement state so the session can compose its next
// statement: the clause builder, the bind variables and the raw SQL text. The
// db handle, dialect and refTable are kept.
func (s *Session) Clear() {
	s.clause = clause.Clause{}
	s.sqlVars = nil
	s.sql.Reset()
}

新建session/record.go用于实现记录的增删改查

// Insert writes one row per element of vals in a single INSERT statement and
// returns the number of rows affected.
//
// All values are expected to share one type. Each value's table is resolved
// through Model, but every iteration overwrites the INSERT clause, so the
// statement uses the last value's table and columns.
//
// If vals is empty, Insert is a no-op that returns 0, nil. Otherwise it would
// send the incomplete statement "VALUES " to the database.
func (s *Session) Insert(vals ...any) (int64, error) {
	if len(vals) == 0 {
		return 0, nil
	}

	recordVals := make([]any, 0, len(vals))
	for _, val := range vals {
		table := s.Model(val).RefTable()
		s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
		// Each element is one row ([]any) consumed by the VALUES generator.
		recordVals = append(recordVals, table.RecordValue(val))
	}

	s.clause.Set(clause.VALUES, recordVals...)
	sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
	res, err := s.Raw(sql, vars...).Exec()
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}

3. Find 实现

Find 会将查询结果放入对象切片中,传入的是切片的指针,例如:

// Example: open a session and load the rows of the User table into users.
// Find takes a pointer to the slice so that it can append to it. Its error
// return is ignored here for brevity.
// NOTE(review): check NewEngine's signature. If it also returns an error, it
// cannot be chained with NewSession like this.
s := geeorm.NewEngine("sqlite3", "gee.db").NewSession()
var users []User
s.Find(&users)

3.1 session

session/record.go

// Find runs a SELECT on the table mapped from the element type of the slice
// that vals points to. Any WHERE, ORDER BY and LIMIT clauses already set on the
// session are applied. One element is appended to that slice per returned row.
//
// vals must be a pointer to a slice of structs (for example *[]User); existing
// elements are kept. Each row is scanned into the struct fields named by the
// table's FieldNames, in that order.
func (s *Session) Find(vals any) error {
	dstSlice := reflect.Indirect(reflect.ValueOf(vals))
	dstType := dstSlice.Type().Elem()
	// Resolve the table schema from a zero value of the element type.
	table := s.Model(reflect.New(dstType).Elem().Interface()).RefTable()

	s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
	sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
	rows, err := s.Raw(sql, vars...).QueryRows()
	if err != nil {
		return err
	}
	// Close on every exit path, including a Scan failure, so the underlying
	// connection is released. Rows.Close is idempotent, so the explicit Close
	// at the end stays safe.
	defer rows.Close()

	for rows.Next() {
		// reflect.New(...).Elem() is addressable, so each field's address can
		// be handed to Scan.
		dst := reflect.New(dstType).Elem()
		fieldVals := make([]any, 0, len(table.FieldNames))
		for _, name := range table.FieldNames {
			fieldVals = append(fieldVals, dst.FieldByName(name).Addr().Interface())
		}
		if err := rows.Scan(fieldVals...); err != nil {
			return err
		}
		dstSlice.Set(reflect.Append(dstSlice, dst))
	}
	// rows.Next returns false both at the end of the results and on an error;
	// Err tells them apart.
	if err := rows.Err(); err != nil {
		return err
	}
	return rows.Close()
}
  1. 通过反射获取切片指针指向的切片
  2. 获取切片元素的类型
  3. 构建切片元素类型的实例,获取表结构
  4. 构造查询语句并查询
  5. 遍历查询结果,构建切片元素实例,存储查询结果并追加至切片中

3.2 单元测试

// Test fixtures. testInit seeds the table with user1 and user2, and
// TestSession_Insert then inserts user3 on top of them.
// NOTE(review): the positional literals assume User's fields are (Name, Age)
// in that order — confirm against the User definition.
var (
	user1 = &User{"A", 1}
	user2 = &User{"B", 2}
	user3 = &User{"C", 3}
)

// testInit returns a session bound to User whose table has been dropped,
// recreated and seeded with user1 and user2. It fails the test immediately,
// naming the step and its error, if any setup step fails.
func testInit(t *testing.T) *Session {
	t.Helper()
	s := newSession().Model(&User{})
	if err := s.DropTable(); err != nil {
		t.Fatalf("failed to init test records: drop table: %v", err)
	}
	if err := s.CreateTable(); err != nil {
		t.Fatalf("failed to init test records: create table: %v", err)
	}
	if _, err := s.Insert(user1, user2); err != nil {
		t.Fatalf("failed to init test records: insert: %v", err)
	}
	return s
}

// TestSession_Insert checks that inserting one more record into the seeded
// table reports exactly one affected row.
func TestSession_Insert(t *testing.T) {
	s := testInit(t)
	affected, err := s.Insert(user3)
	if err != nil {
		t.Fatalf("failed to create record: %v", err)
	}
	if affected != 1 {
		t.Fatalf("failed to create record: affected = %d, want 1", affected)
	}
}

// TestSession_Find checks that Find loads both records seeded by testInit.
func TestSession_Find(t *testing.T) {
	s := testInit(t)
	var users []User
	if err := s.Find(&users); err != nil {
		t.Fatalf("failed to query records: %v", err)
	}
	if len(users) != 2 {
		t.Fatalf("failed to query records: got %d users, want 2", len(users))
	}
}

Reference

  1. https://geektutu.com/post/geeorm-day3.html
上次编辑于:
评论
  • 按正序
  • 按倒序
  • 按热度
Powered by Waline v2.15.2