3. 服务注册
...大约 4 分钟
1. 结构体映射为服务
对 net/rpc
而言,一个函数需要能够被远程调用,需要满足如下五个条件:
- the method’s type is exported. – 方法所属类型是导出的。
- the method is exported. – 方法是导出的。
- the method has two arguments, both exported (or builtin) types. – 两个入参,均为导出或内置类型。
- the method’s second argument is a pointer. – 第二个入参必须是一个指针。
- the method has return type error. – 返回值为 error 类型。
客户端发起的调用形式为 ServiceMethod: ServiceName.MethodName,服务端不可能为所有的 ServiceName 进行硬编码来构建实例,此时就需要使用反射动态地进行构建。
2. 通过反射实现 service
2.1 注册 service
geerpc/service.go
:
// methodType describes a single method of a service together with the
// reflection data needed to build its argument and reply values.
type methodType struct {
	method    reflect.Method // the method itself
	ArgType   reflect.Type   // type of the argument parameter
	ReplyType reflect.Type   // type of the reply parameter
	numCalls  uint64         // call counter; read and written via sync/atomic
}

// NumCalls returns the current value of the call counter, loaded atomically.
func (m *methodType) NumCalls() uint64 {
	calls := atomic.LoadUint64(&m.numCalls)
	return calls
}

// newArgv builds a fresh value for the argument parameter.
// When ArgType is a pointer type the result is a pointer to a new zero
// element; otherwise it is an addressable zero value of ArgType.
func (m *methodType) newArgv() reflect.Value {
	if m.ArgType.Kind() != reflect.Ptr {
		return reflect.New(m.ArgType).Elem()
	}
	return reflect.New(m.ArgType.Elem())
}

// newReplyv builds a fresh pointer value for the reply parameter.
// ReplyType is expected to be a pointer type. When it points to a map or a
// slice, the pointee is initialised to an empty map or slice instead of nil.
func (m *methodType) newReplyv() reflect.Value {
	elemType := m.ReplyType.Elem()
	replyv := reflect.New(elemType)
	switch elemType.Kind() {
	case reflect.Map:
		replyv.Elem().Set(reflect.MakeMap(elemType))
	case reflect.Slice:
		replyv.Elem().Set(reflect.MakeSlice(elemType, 0, 0))
	}
	return replyv
}
methodType 各字段含义:
- method:方法本身
- ArgType、ReplyType:第一个入参(请求参数)和第二个入参(返回值)的类型
- numCalls:方法调用次数
// service holds the reflection data for one registered receiver.
type service struct {
// name is the receiver's type name, taken after dereferencing a pointer receiver.
name string
// typ is the receiver's reflect.Type; its method set is scanned during registration.
typ reflect.Type
// rcvr is the receiver value, passed as the first argument when a method is called.
rcvr reflect.Value
// method maps a method name to its registered methodType.
method map[string]*methodType
}
service 各字段含义:
- name:结构体名
- typ:结构体类型
- rcvr:结构体实例,即方法的接收者
- method:注册的方法
// newService wraps rcvr in a service and registers its eligible methods.
// The service name is the receiver's type name with any pointer dereferenced.
// If that name is not exported, the process is terminated through log.Fatalf.
func newService(rcvr any) *service {
	rv := reflect.ValueOf(rcvr)
	svc := &service{
		name: reflect.Indirect(rv).Type().Name(),
		typ:  reflect.TypeOf(rcvr),
		rcvr: rv,
	}
	if !ast.IsExported(svc.name) {
		log.Fatalf("rpc server: %s is not a valid service name", svc.name)
	}
	svc.registerMethods()
	return svc
}
// registerMethods scans the method set of s.typ and stores every method that
// can be called remotely in s.method, keyed by method name.
//
// A method is accepted when it has exactly two parameters besides the
// receiver, returns a single value of type error, takes a pointer as its
// second (reply) parameter, and both parameter types pass
// isExportedOrBuiltinType. All other methods are skipped silently.
func (s *service) registerMethods() {
	// The error interface type does not change between iterations.
	errorType := reflect.TypeOf((*error)(nil)).Elem()

	s.method = make(map[string]*methodType)
	for i := 0; i < s.typ.NumMethod(); i++ {
		method := s.typ.Method(i)
		mType := method.Type
		// In(0) is the receiver, so a two-parameter method has three inputs.
		if mType.NumIn() != 3 || mType.NumOut() != 1 {
			continue
		}
		if mType.Out(0) != errorType {
			continue
		}
		argType, replyType := mType.In(1), mType.In(2)
		// newReplyv dereferences ReplyType with Elem, so a non-pointer reply
		// would panic or produce a value of the wrong type at call time.
		if replyType.Kind() != reflect.Ptr {
			continue
		}
		if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
			continue
		}
		s.method[method.Name] = &methodType{
			method:    method,
			ArgType:   argType,
			ReplyType: replyType,
		}
		log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
	}
}
// isExportedOrBuiltinType reports whether t, after stripping every level of
// pointer indirection, is either an exported named type or a type with an
// empty package path (builtin or unnamed types).
func isExportedOrBuiltinType(t reflect.Type) bool {
	// Pointer types are unnamed and have an empty PkgPath, so without this
	// loop a pointer to an unexported type would be accepted.
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return ast.IsExported(t.Name()) || t.PkgPath() == ""
}
registerMethods
选择符合条件的方法:
- 两个入参
- 入参类型是导出类型或内置类型
- 一个返回值,类型为 error
// call invokes method m with the service receiver, argv and replyv, after
// atomically incrementing the method's call counter. It returns the error
// produced by the method, or nil when the method returned a nil error.
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
	atomic.AddUint64(&m.numCalls, 1)
	in := []reflect.Value{s.rcvr, argv, replyv}
	out := m.method.Func.Call(in)
	// A nil result fails the assertion and leaves err nil.
	err, _ := out[0].Interface().(error)
	return err
}
call
调用方法。
2.2 单元测试
package geerpc
import (
"fmt"
"reflect"
"testing"
)
// Foo is a minimal receiver type used by the service tests.
type Foo int

// Args carries the two operands of a sum.
type Args struct {
	Num1 int
	Num2 int
}

// Sum stores Num1+Num2 in *reply and always returns nil.
func (f Foo) Sum(args Args, reply *int) error {
	total := args.Num1 + args.Num2
	*reply = total
	return nil
}

// sum is the unexported counterpart of Sum with the same behaviour.
// TestNewService expects exactly one registered method, so this one is
// expected to be skipped by registration.
func (f Foo) sum(args Args, reply *int) error {
	total := args.Num1 + args.Num2
	*reply = total
	return nil
}
// _assert panics when condition is false. The panic value is the string
// "assertion failed: " followed by msg formatted with v.
func _assert(condition bool, msg string, v ...any) {
	if condition {
		return
	}
	panic(fmt.Sprintf("assertion failed: "+msg, v...))
}
// TestNewService checks that a service built from *Foo exposes exactly one
// method and that this method is Sum.
func TestNewService(t *testing.T) {
	var foo Foo
	svc := newService(&foo)
	registered := len(svc.method)
	_assert(registered == 1, "wrong service method, expect 1, got %d", registered)
	sumMethod := svc.method["Sum"]
	_assert(sumMethod != nil, "wrong Method, Sum should not be nil")
}
// TestMethodType_Call invokes Foo.Sum through the service with 1 and 3 and
// checks the nil error, the reply of 4, and a call count of 1.
func TestMethodType_Call(t *testing.T) {
	var foo Foo
	svc := newService(&foo)
	sumMethod := svc.method["Sum"]

	argv := sumMethod.newArgv()
	replyv := sumMethod.newReplyv()
	argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3}))

	err := svc.call(sumMethod, argv, replyv)
	ok := err == nil && *replyv.Interface().(*int) == 4 && sumMethod.NumCalls() == 1
	_assert(ok, "failed to call Foo.Sum")
}
2.3 server
geerpc/server.go
// Server represents an RPC server.
type Server struct {
// serviceMap stores registered services keyed by service name.
serviceMap sync.Map
}
// Register publishes in the server the set of methods of the receiver value
// that newService accepts. The service is stored under the receiver's type
// name; an error is returned when a service with that name already exists.
func (server *Server) Register(rcvr any) error {
	s := newService(rcvr)
	if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup {
		// Keep a space after the colon so the service name stays readable.
		return errors.New("rpc: service already defined: " + s.name)
	}
	return nil
}
// Register publishes the receiver's methods in the DefaultServer.
// It forwards rcvr to DefaultServer.Register and returns that call's error unchanged.
func Register(rcvr any) error {
return DefaultServer.Register(rcvr)
}
Register
用于注册服务。
// findService resolves a "Service.Method" string into the registered service
// and its method descriptor. The string is split at its last dot. An error is
// returned when there is no dot, when no service with that name is registered,
// or when the service has no method with that name.
func (server *Server) findService(serviceMethod string) (*service, *methodType, error) {
	sep := strings.LastIndex(serviceMethod, ".")
	if sep < 0 {
		return nil, nil, errors.New("rpc server: service/method request ill-formed: " + serviceMethod)
	}
	serviceName := serviceMethod[:sep]
	methodName := serviceMethod[sep+1:]

	loaded, found := server.serviceMap.Load(serviceName)
	if !found {
		return nil, nil, errors.New("rpc server: can't find service " + serviceName)
	}

	svc := loaded.(*service)
	if mtype := svc.method[methodName]; mtype != nil {
		return svc, mtype, nil
	}
	return nil, nil, errors.New("rpc server: can't find method " + methodName)
}
findService 通过 ServiceMethod(形如 Service.Method 的字符串)获取 service 和 methodType 实例。
// request stores all information of a call.
type request struct {
h *codec.Header // header of request
argv, replyv reflect.Value // argv and replyv of request
// mtype is the method descriptor resolved from the header's ServiceMethod.
mtype *methodType
// svc is the service resolved together with mtype.
svc *service
}
// readRequest reads one request from cc: the header, then the body decoded
// into a freshly built argument value for the method named in the header.
//
// When the header cannot be read it returns (nil, err). When the lookup of the
// service/method or the decoding of the body fails, it returns the partially
// filled request together with the error.
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
	header, err := server.readRequestHeader(cc)
	if err != nil {
		return nil, err
	}

	req := &request{h: header}
	if req.svc, req.mtype, err = server.findService(header.ServiceMethod); err != nil {
		return req, err
	}
	req.argv, req.replyv = req.mtype.newArgv(), req.mtype.newReplyv()

	// Hand ReadBody a pointer: use argv as-is when it already is one,
	// otherwise take its address.
	var body any
	if req.argv.Kind() == reflect.Ptr {
		body = req.argv.Interface()
	} else {
		body = req.argv.Addr().Interface()
	}

	if err = cc.ReadBody(body); err != nil {
		log.Println("rpc server: read body err:", err)
		return req, err
	}
	return req, nil
}
readRequest 流程:
- 通过 serviceMethod 调用 findService 获取 service 和 methodType 实例
- 构建入参:argv 和 replyv
- 确保传给 ReadBody 的 argv 为指针类型,将消息体反序列化到 argv 中
// handleRequest calls the method resolved for req and sends the response on
// cc. If the method returns an error, its text is stored in the header's
// Error field and invalidRequest is sent as the body; otherwise the reply
// value is sent. wg.Done is called when the function returns.
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
	defer wg.Done()
	// argv is a pointer only when the method's argument type is a pointer.
	// reflect.Indirect handles both cases, whereas Elem panics on a
	// non-pointer value such as a struct argument.
	log.Println(req.h, reflect.Indirect(req.argv))
	if err := req.svc.call(req.mtype, req.argv, req.replyv); err != nil {
		req.h.Error = err.Error()
		server.sendResponse(cc, req.h, invalidRequest, sending)
		return
	}
	server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
}
3. Demo
package main
import (
"geerpc"
"log"
"net"
"sync"
)
// Foo is the receiver type whose methods the demo registers as an RPC service.
type Foo int

// Args carries the two operands of a sum.
type Args struct {
	Num1 int
	Num2 int
}

// Sum stores Num1+Num2 in *reply and always returns nil.
func (f Foo) Sum(args Args, reply *int) error {
	total := args.Num1 + args.Num2
	*reply = total
	return nil
}
// startServer registers a Foo receiver, listens on a free TCP port, sends the
// listener's address on addr, and then hands the listener to geerpc.Accept.
// A registration or listen failure terminates the process via log.Fatal.
func startServer(addr chan string) {
	var foo Foo
	err := geerpc.Register(&foo)
	if err != nil {
		log.Fatal("register error:", err)
	}

	// pick free port
	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		log.Fatal("network error:", err)
	}

	boundAddr := listener.Addr()
	log.Println("start rpc server on", boundAddr)
	addr <- boundAddr.String()
	geerpc.Accept(listener)
}
// main starts the server in a goroutine, dials the address it reports, and
// issues five concurrent Foo.Sum calls, logging each result.
func main() {
	log.SetFlags(0)
	addr := make(chan string)
	go startServer(addr)

	// The dial error used to be discarded. Presumably client is nil on
	// failure, so stop here instead of using it below.
	client, err := geerpc.Dial("tcp", <-addr)
	if err != nil {
		log.Fatal("dial error:", err)
	}
	defer func() {
		// Best-effort close at exit; there is nothing useful to do on failure.
		_ = client.Close()
	}()

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			args := &Args{Num1: i, Num2: 2 * i}
			var reply int
			if err := client.Call("Foo.Sum", args, &reply); err != nil {
				log.Fatal("call Foo.Sum error:", err)
			}
			log.Printf("%d + %d = %d", args.Num1, args.Num2, reply)
		}(i)
	}
	wg.Wait()
}
rpc server: register Foo.Sum
start rpc server on [::]:5657
2 + 4 = 6
4 + 8 = 12
0 + 0 = 0
3 + 6 = 9
1 + 2 = 3
Reference
Powered by Waline v2.15.2