用GO写一个RPC框架 s04 (编写服务端核心)

Posted dollarkillerx

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用GO写一个RPC框架 s04 (编写服务端核心)相关的知识,希望对你有一定的参考价值。

前言

通过上两篇的学习 我们已经了解了 服务端本地服务的注册, 服务端配置,协议 现在我们开始写服务端的核心逻辑

https://github.com/dollarkill...

默认配置

我们先看下默认的配置

func defaultOptions() *Options {
    return &Options{
        Protocol:     transport.TCP, // default TCP
        Uri:          "0.0.0.0:8397",
        UseHttp:      false,
        readTimeout:  time.Minute * 3, // 心跳包 默认 3min
        writeTimeout: time.Second * 30,
        ctx:          context.Background(), // ctx 是控制服务退出的
        options: map[string]interface{}{
            "TCPKeepAlivePeriod": time.Minute * 3,
        },
        processChanSize: 1000,    
        Trace:           false,
        RSAPublicKey: []byte(`-----BEGIN PUBLIC KEY-----
-----END PUBLIC KEY-----`),
        RSAPrivateKey: []byte(`-----BEGIN RSA PRIVATE KEY-----
-----END RSA PRIVATE KEY-----`),
        Discovery: &discovery.SimplePeerToPeer{},
    }
}

run

服务注册完毕之后 调用Run方法 启动服务

func (s *Server) Run(options ...Option) error {
        // 初始化 服务端配置
    for _, fn := range options {
        fn(s.options)
    }

    var err error
        // 更具配置传入的protocol 获取到 网络插件 (KCP UDP TCP) 我们等下细讲
    s.options.nl, err = transport.Transport.Gen(s.options.Protocol, s.options.Uri)
    if err != nil {
        return err
    }

    log.Printf("LightRPC: %s  %s \\n", s.options.Protocol, s.options.Uri)

        // 这里是服务注册 我们这里先跳过  
    if s.options.Discovery != nil {
                // 读取服务配置文件
        sIdb, err := ioutil.ReadFile("./light.conf")
        if err != nil {
                        // 如果没有 就生成 分布式ID
            id, err := utils.DistributedID()
            if err != nil {
                return err
            }
            sIdb = []byte(id)
        }
        // 进行服务注册
        sId := string(sIdb)
        for k := range s.serviceMap {   // 进行服务注册 
            err := s.options.Discovery.Registry(k, s.options.registryAddr, s.options.weights, s.options.Protocol, s.options.MaximumLoad, &sId)
            if err != nil {
                return err
            }
            log.Printf("Discovery Registry: %s addr: %s SUCCESS", k, s.options.registryAddr)
        }

        ioutil.WriteFile("./light.conf", sIdb, 00666)
    }
        
        // 启动服务
    return s.run()
}



func (s *Server) run() error {
loop:
    for {
        select {
        case <-s.options.ctx.Done():  // 检查是否需要退出服务
            break loop
        default:
            accept, err := s.options.nl.Accept() // 获取一个链接
            if err != nil {
                log.Println(err)
                continue
            }
            if s.options.Trace {
                log.Println("connect: ", accept.RemoteAddr())
            }

            go s.process(accept) // 开一个协程去处理 该 链接
        }

    }

    return nil
}

我们先回顾一下 上章讲的 握手逻辑

  1. 建立链接 通过非对称加密 传输 aes 密钥给服务端 (携带token)
  2. 服务端 验证 token 并记录 aes 密钥 后面与客户端交互 都采用对称加密

具体处理 链接 process (重点!!!)

func (s *Server) process(conn net.Conn) {

    defer func() {
        // 网络不可靠
        if err := recover(); err != nil {
            utils.PrintStack()
            log.Println("Recover Err: ", err)
        }
    }()

        // 每进来一个请求这里就ADD
    s.options.Discovery.Add(1)
    defer func() {
        s.options.Discovery.Less(1) // 处理完 请求就退出
        // 退出 回收句柄
        err := conn.Close()  
        if err != nil {
            log.Println(err)
            return
        }

        if s.options.Trace {
            log.Println("close connect: ", conn.RemoteAddr())
        }
    }()

        // 这里定义一个xChannel 用于分离 请求和返回
    xChannel := utils.NewXChannel(s.options.processChanSize)

    // 握手
    handshake := protocol.Handshake{}
    err := handshake.Handshake(conn)
    if err != nil {
        return
    }
            
        // 非对称加密  解密 AES KEY
    aesKey, err := cryptology.RsaDecrypt(handshake.Key, s.options.RSAPrivateKey)
    if err != nil {
        encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
        conn.Write(encodeHandshake)
        return
    }

        // 检测 AES KEY 是否正确
    if len(aesKey) != 32 && len(aesKey) != 16 {
        encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte("aes key != 32 && key != 16"))
        conn.Write(encodeHandshake)
        return
    }
        
        // 解密 TOKEN
    token, err := cryptology.RsaDecrypt(handshake.Token, s.options.RSAPrivateKey)
    if err != nil {
        encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
        conn.Write(encodeHandshake)
        return
    }
        // 对TOKEN进行校验  
    if s.options.AuthFunc != nil {
        err := s.options.AuthFunc(light.DefaultCtx(), string(token))
        if err != nil {
            encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
            conn.Write(encodeHandshake)
            return
        }
    }

    // limit 限流
    if s.options.Discovery.Limit() {
        // 熔断
        encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(pkg.ErrCircuitBreaker.Error()))
        conn.Write(encodeHandshake)
        log.Println(s.options.Discovery.Limit())
        return
    }
        
        // 如果握手没有问题 则返回握手成功
    encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(""))
    _, err = conn.Write(encodeHandshake)
    if err != nil {
        return
    }
        
    // send
    go func() {
    loop:
        for {
            select {
                        // 这就是刚刚的xChannel 对读写进行分离
            case msg, ex := <-xChannel.Ch: 
                if !ex {
                    if s.options.Trace {
                        log.Printf("ip: %s  close send server", conn.RemoteAddr())
                    }
                    break loop
                }
                now := time.Now()
                if s.options.writeTimeout > 0 {
                    conn.SetWriteDeadline(now.Add(s.options.writeTimeout))
                }
                // send message
                _, err := conn.Write(msg)
                if err != nil {
                    if s.options.Trace {
                        log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)
                    }
                    break loop
                }
            }
        }
    }()

    defer func() {
        xChannel.Close()
    }()
loop:
    for { // 具体消息获取
        now := time.Now()
        if s.options.readTimeout > 0 {
            conn.SetReadDeadline(now.Add(s.options.readTimeout))
        }

        proto := protocol.NewProtocol()
        msg, err := proto.IODecode(conn) // 获取一个消息
        if err != nil {
            if err == io.EOF {
                if s.options.Trace {
                    log.Printf("ip: %s close", conn.RemoteAddr())
                }
                break loop
            }

            // 遇到错误关闭链接
            if s.options.Trace {
                log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)
            }
            break loop
        }

        go s.processResponse(xChannel, msg, conn.RemoteAddr().String(), aesKey)
    }
}

具体处理 (重点!!!)

注意此RPC传输消息都是编码过的 要进行转码

  • 第一层 为压缩编码
  • 第二层 为加密编码
  • 第三层 为序列化
func (s *Server) processResponse(xChannel *utils.XChannel, msg *protocol.Message, addr string, aesKey []byte) {
    var err error
    s.options.Discovery.Add(1)
    defer func() {
        s.options.Discovery.Less(1)
        if err != nil {
            if s.options.Trace {
                log.Println("ProcessResponse Error: ", err, "  ID: ", addr)
            }
            xChannel.Close()
        }
    }()

    // heartBeat 判断
    if msg.Header.RespType == byte(protocol.HeartBeat) {
        // 心跳返回
        if s.options.Trace {
            log.Println("HeartBeat: ", addr)
        }

        // 4. 打包
        _, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), []byte(""), byte(protocol.HeartBeat), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))
        if err != nil {
            return
        }
        // 5. 回写
        err = xChannel.Send(message)
        if err != nil {
            return
        }

        return
    }

    // 限流
    if s.options.Discovery.Limit() {
        serialization, _ := codes.SerializationManager.Get(codes.MsgPack)
        metaData := make(map[string]string)
        metaData["RespError"] = pkg.ErrCircuitBreaker.Error()
        meta, err := serialization.Encode(metaData)
        if err != nil {
            return
        }
        decrypt, err := cryptology.AESDecrypt(aesKey, meta)
        if err != nil {
            return
        }
        _, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), decrypt, byte(protocol.Response), byte(codes.RawData), byte(codes.MsgPack), []byte(""))
        if err != nil {
            return
        }
        // 5. 回写
        err = xChannel.Send(message)
        if err != nil {
            return
        }

        log.Println(s.options.Discovery.Limit())
        log.Println("限流/////////////")

        return
    }

    // 1. 解压缩
    compressor, ex := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))
    if !ex {
        err = errors.New("compressor 404")
        return
    }
    msg.MetaData, err = compressor.Unzip(msg.MetaData)
    if err != nil {
        return
    }

    msg.Payload, err = compressor.Unzip(msg.Payload)
    if err != nil {
        return
    }
    // 2. 解密
    msg.MetaData, err = cryptology.AESDecrypt(aesKey, msg.MetaData)
    if err != nil {
        return
    }

    msg.Payload, err = cryptology.AESDecrypt(aesKey, msg.Payload)
    if err != nil {
        return
    }

    // 3. 反序列化
    serialization, ex := codes.SerializationManager.Get(codes.SerializationType(msg.Header.SerializationType))
    if !ex {
        err = errors.New("serialization 404")
        return
    }

    metaData := make(map[string]string)
    err = serialization.Decode(msg.MetaData, &metaData)
    if err != nil {
        return
    }

        // 初始化context
    ctx := light.DefaultCtx()
    ctx.SetMetaData(metaData)

    // 1.3 auth
    if s.options.AuthFunc != nil {
        auth := metaData["Light_AUTH"]
        err := s.options.AuthFunc(ctx, auth)
        if err != nil {
            ctx.SetValue("RespError", err.Error())
            var metaDataByte []byte
            metaDataByte, _ = serialization.Encode(ctx.GetMetaData())
            metaDataByte, _ = cryptology.AESEncrypt(aesKey, metaDataByte)
            metaDataByte, _ = compressor.Zip(metaDataByte)
            // 4. 打包
            _, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))
            if err != nil {
                return
            }
            // 5. 回写
            err = xChannel.Send(message)
            if err != nil {
                return
            }
            return
        }
    }

        // 找到具体调用的服务
    ser, ex := s.serviceMap[msg.ServiceName]
    if !ex {
        err = errors.New("service does not exist")
        return
    }

        // 找到具体调用的方法
    method, ex := ser.methodType[msg.ServiceMethod]
    if !ex {
        err = errors.New("method does not exist")
        return
    }

        // 初始化 req, resp
    req := utils.RefNew(method.RequestType)
    resp := utils.RefNew(method.ResponseType)

    err = serialization.Decode(msg.Payload, req)
    if err != nil {
        return
    }

        // 定义ctx paht 为   服务名称.服务方法
    path := fmt.Sprintf("%s.%s", msg.ServiceName, msg.ServiceMethod)
    ctx.SetPath(path)

    // 前置middleware
    if len(s.beforeMiddleware) != 0 {
        for idx := range s.beforeMiddleware {
            err := s.beforeMiddleware[idx](ctx, req, resp)
            if err != nil {
                return
            }
        }
    }
    funcs, ex := s.beforeMiddlewarePath[path]
    if ex {
        if len(funcs) != 0 {
            for idx := range funcs {
                err := funcs[idx](ctx, req, resp)
                if err != nil {
                    return
                }
            }
        }
    }

    // 核心调用
    callErr := ser.call(ctx, method, reflect.ValueOf(req), reflect.ValueOf(resp))
    if callErr != nil {
        ctx.SetValue("RespError", callErr.Error())
    }

    // 后置middleware
    if len(s.afterMiddleware) != 0 {
        for idx := range s.afterMiddleware {
            err := s.afterMiddleware[idx](ctx, req, resp)
            if err != nil {
                return
            }
        }
    }
    funcs, ex = s.afterMiddlewarePath[path]
    if ex {
        if len(funcs) != 0 {
            for idx := range funcs {
                err := funcs[idx](ctx, req, resp)
                if err != nil {
                    return
                }
            }
        }
    }
    // response

    // 1. 序列化
    var respBody []byte
    respBody, err = serialization.Encode(resp)

    var metaDataByte []byte
    metaDataByte, _ = serialization.Encode(ctx.GetMetaData())
    // 2. 加密
    metaDataByte, err = cryptology.AESEncrypt(aesKey, metaDataByte)
    if err != nil {
        return
    }
    respBody, err = cryptology.AESEncrypt(aesKey, respBody)
    if err != nil {
        return
    }
    // 3. 压缩
    metaDataByte, err = compressor.Zip(metaDataByte)
    if err != nil {
        return
    }
    respBody, err = compressor.Zip(respBody)
    if err != nil {
        return
    }
    // 4. 打包
    _, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, respBody)
    if err != nil {
        return
    }
    // 5. 回写
    err = xChannel.Send(message)
    if err != nil {
        return
    }
}

调用具体方法

func (s *service) call(ctx *light.Context, mType *methodType, request, response reflect.Value) (err error) {
    // recover 捕获堆栈消息
    defer func() {
        if r := recover(); r != nil {
            buf := make([]byte, 4096)
            n := runtime.Stack(buf, false)
            buf = buf[:n]

            err = fmt.Errorf("[painc service internal error]: %v, method: %s, argv: %+v, stack: %s",
                r, mType.method.Name, request.Interface(), buf)
            log.Println(err)
        }
    }()

    fn := mType.method.Func
    returnValue := fn.Call([]reflect.Value{s.refVal, reflect.ValueOf(ctx), request, response})
    errInterface := returnValue[0].Interface()
    if errInterface != nil {
        return errInterface.(error)
    }

    return nil
}

这里就完成了服务端的基础逻辑了

以上是关于用GO写一个RPC框架 s04 (编写服务端核心)的主要内容,如果未能解决你的问题,请参考以下文章

用GO写一个RPC框架 s05 (客户端编写)

用GO写一个RPC框架 s03 (协议设计)

用GO写一个RPC框架 s01(服务内部注册实现)

Go 每日一库之 rpc

一个 Go 语言写的微服务后端管理系统

摸清 Go RPC 原理的第一步!