用GO写一个RPC框架 s04 (编写服务端核心)
Posted dollarkillerx
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用GO写一个RPC框架 s04 (编写服务端核心)相关的知识,希望对你有一定的参考价值。
前言
通过上两篇的学习 我们已经了解了 服务端本地服务的注册, 服务端配置,协议 现在我们开始写服务端的核心逻辑
https://github.com/dollarkill...
默认配置
我们先看下默认的配置
func defaultOptions() *Options {
return &Options{
Protocol: transport.TCP, // default TCP
Uri: "0.0.0.0:8397",
UseHttp: false,
readTimeout: time.Minute * 3, // 心跳包 默认 3min
writeTimeout: time.Second * 30,
ctx: context.Background(), // ctx 是控制服务退出的
options: map[string]interface{}{
"TCPKeepAlivePeriod": time.Minute * 3,
},
processChanSize: 1000,
Trace: false,
RSAPublicKey: []byte(`-----BEGIN PUBLIC KEY-----
-----END PUBLIC KEY-----`),
RSAPrivateKey: []byte(`-----BEGIN RSA PRIVATE KEY-----
-----END RSA PRIVATE KEY-----`),
Discovery: &discovery.SimplePeerToPeer{},
}
}
run
服务注册完毕之后 调用Run方法 启动服务
func (s *Server) Run(options ...Option) error {
// 初始化 服务端配置
for _, fn := range options {
fn(s.options)
}
var err error
// 更具配置传入的protocol 获取到 网络插件 (KCP UDP TCP) 我们等下细讲
s.options.nl, err = transport.Transport.Gen(s.options.Protocol, s.options.Uri)
if err != nil {
return err
}
log.Printf("LightRPC: %s %s \\n", s.options.Protocol, s.options.Uri)
// 这里是服务注册 我们这里先跳过
if s.options.Discovery != nil {
// 读取服务配置文件
sIdb, err := ioutil.ReadFile("./light.conf")
if err != nil {
// 如果没有 就生成 分布式ID
id, err := utils.DistributedID()
if err != nil {
return err
}
sIdb = []byte(id)
}
// 进行服务注册
sId := string(sIdb)
for k := range s.serviceMap { // 进行服务注册
err := s.options.Discovery.Registry(k, s.options.registryAddr, s.options.weights, s.options.Protocol, s.options.MaximumLoad, &sId)
if err != nil {
return err
}
log.Printf("Discovery Registry: %s addr: %s SUCCESS", k, s.options.registryAddr)
}
ioutil.WriteFile("./light.conf", sIdb, 00666)
}
// 启动服务
return s.run()
}
func (s *Server) run() error {
loop:
for {
select {
case <-s.options.ctx.Done(): // 检查是否需要退出服务
break loop
default:
accept, err := s.options.nl.Accept() // 获取一个链接
if err != nil {
log.Println(err)
continue
}
if s.options.Trace {
log.Println("connect: ", accept.RemoteAddr())
}
go s.process(accept) // 开一个协程去处理 该 链接
}
}
return nil
}
我们先回顾一下 上章讲的 握手逻辑
- 建立链接 通过非对称加密 传输 aes 密钥给服务端 (携带token)
- 服务端 验证 token 并记录 aes 密钥 后面与客户端交互 都采用对称加密
具体处理 链接 process (重点!!!)
func (s *Server) process(conn net.Conn) {
defer func() {
// 网络不可靠
if err := recover(); err != nil {
utils.PrintStack()
log.Println("Recover Err: ", err)
}
}()
// 每进来一个请求这里就ADD
s.options.Discovery.Add(1)
defer func() {
s.options.Discovery.Less(1) // 处理完 请求就退出
// 退出 回收句柄
err := conn.Close()
if err != nil {
log.Println(err)
return
}
if s.options.Trace {
log.Println("close connect: ", conn.RemoteAddr())
}
}()
// 这里定义一个xChannel 用于分离 请求和返回
xChannel := utils.NewXChannel(s.options.processChanSize)
// 握手
handshake := protocol.Handshake{}
err := handshake.Handshake(conn)
if err != nil {
return
}
// 非对称加密 解密 AES KEY
aesKey, err := cryptology.RsaDecrypt(handshake.Key, s.options.RSAPrivateKey)
if err != nil {
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
conn.Write(encodeHandshake)
return
}
// 检测 AES KEY 是否正确
if len(aesKey) != 32 && len(aesKey) != 16 {
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte("aes key != 32 && key != 16"))
conn.Write(encodeHandshake)
return
}
// 解密 TOKEN
token, err := cryptology.RsaDecrypt(handshake.Token, s.options.RSAPrivateKey)
if err != nil {
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
conn.Write(encodeHandshake)
return
}
// 对TOKEN进行校验
if s.options.AuthFunc != nil {
err := s.options.AuthFunc(light.DefaultCtx(), string(token))
if err != nil {
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
conn.Write(encodeHandshake)
return
}
}
// limit 限流
if s.options.Discovery.Limit() {
// 熔断
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(pkg.ErrCircuitBreaker.Error()))
conn.Write(encodeHandshake)
log.Println(s.options.Discovery.Limit())
return
}
// 如果握手没有问题 则返回握手成功
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(""))
_, err = conn.Write(encodeHandshake)
if err != nil {
return
}
// send
go func() {
loop:
for {
select {
// 这就是刚刚的xChannel 对读写进行分离
case msg, ex := <-xChannel.Ch:
if !ex {
if s.options.Trace {
log.Printf("ip: %s close send server", conn.RemoteAddr())
}
break loop
}
now := time.Now()
if s.options.writeTimeout > 0 {
conn.SetWriteDeadline(now.Add(s.options.writeTimeout))
}
// send message
_, err := conn.Write(msg)
if err != nil {
if s.options.Trace {
log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)
}
break loop
}
}
}
}()
defer func() {
xChannel.Close()
}()
loop:
for { // 具体消息获取
now := time.Now()
if s.options.readTimeout > 0 {
conn.SetReadDeadline(now.Add(s.options.readTimeout))
}
proto := protocol.NewProtocol()
msg, err := proto.IODecode(conn) // 获取一个消息
if err != nil {
if err == io.EOF {
if s.options.Trace {
log.Printf("ip: %s close", conn.RemoteAddr())
}
break loop
}
// 遇到错误关闭链接
if s.options.Trace {
log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)
}
break loop
}
go s.processResponse(xChannel, msg, conn.RemoteAddr().String(), aesKey)
}
}
具体处理 (重点!!!)
注意此RPC传输消息都是编码过的 要进行转码
- 第一层 为压缩编码
- 第二层 为加密编码
- 第三层 为序列化
func (s *Server) processResponse(xChannel *utils.XChannel, msg *protocol.Message, addr string, aesKey []byte) {
var err error
s.options.Discovery.Add(1)
defer func() {
s.options.Discovery.Less(1)
if err != nil {
if s.options.Trace {
log.Println("ProcessResponse Error: ", err, " ID: ", addr)
}
xChannel.Close()
}
}()
// heartBeat 判断
if msg.Header.RespType == byte(protocol.HeartBeat) {
// 心跳返回
if s.options.Trace {
log.Println("HeartBeat: ", addr)
}
// 4. 打包
_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), []byte(""), byte(protocol.HeartBeat), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))
if err != nil {
return
}
// 5. 回写
err = xChannel.Send(message)
if err != nil {
return
}
return
}
// 限流
if s.options.Discovery.Limit() {
serialization, _ := codes.SerializationManager.Get(codes.MsgPack)
metaData := make(map[string]string)
metaData["RespError"] = pkg.ErrCircuitBreaker.Error()
meta, err := serialization.Encode(metaData)
if err != nil {
return
}
decrypt, err := cryptology.AESDecrypt(aesKey, meta)
if err != nil {
return
}
_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), decrypt, byte(protocol.Response), byte(codes.RawData), byte(codes.MsgPack), []byte(""))
if err != nil {
return
}
// 5. 回写
err = xChannel.Send(message)
if err != nil {
return
}
log.Println(s.options.Discovery.Limit())
log.Println("限流/////////////")
return
}
// 1. 解压缩
compressor, ex := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))
if !ex {
err = errors.New("compressor 404")
return
}
msg.MetaData, err = compressor.Unzip(msg.MetaData)
if err != nil {
return
}
msg.Payload, err = compressor.Unzip(msg.Payload)
if err != nil {
return
}
// 2. 解密
msg.MetaData, err = cryptology.AESDecrypt(aesKey, msg.MetaData)
if err != nil {
return
}
msg.Payload, err = cryptology.AESDecrypt(aesKey, msg.Payload)
if err != nil {
return
}
// 3. 反序列化
serialization, ex := codes.SerializationManager.Get(codes.SerializationType(msg.Header.SerializationType))
if !ex {
err = errors.New("serialization 404")
return
}
metaData := make(map[string]string)
err = serialization.Decode(msg.MetaData, &metaData)
if err != nil {
return
}
// 初始化context
ctx := light.DefaultCtx()
ctx.SetMetaData(metaData)
// 1.3 auth
if s.options.AuthFunc != nil {
auth := metaData["Light_AUTH"]
err := s.options.AuthFunc(ctx, auth)
if err != nil {
ctx.SetValue("RespError", err.Error())
var metaDataByte []byte
metaDataByte, _ = serialization.Encode(ctx.GetMetaData())
metaDataByte, _ = cryptology.AESEncrypt(aesKey, metaDataByte)
metaDataByte, _ = compressor.Zip(metaDataByte)
// 4. 打包
_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))
if err != nil {
return
}
// 5. 回写
err = xChannel.Send(message)
if err != nil {
return
}
return
}
}
// 找到具体调用的服务
ser, ex := s.serviceMap[msg.ServiceName]
if !ex {
err = errors.New("service does not exist")
return
}
// 找到具体调用的方法
method, ex := ser.methodType[msg.ServiceMethod]
if !ex {
err = errors.New("method does not exist")
return
}
// 初始化 req, resp
req := utils.RefNew(method.RequestType)
resp := utils.RefNew(method.ResponseType)
err = serialization.Decode(msg.Payload, req)
if err != nil {
return
}
// 定义ctx paht 为 服务名称.服务方法
path := fmt.Sprintf("%s.%s", msg.ServiceName, msg.ServiceMethod)
ctx.SetPath(path)
// 前置middleware
if len(s.beforeMiddleware) != 0 {
for idx := range s.beforeMiddleware {
err := s.beforeMiddleware[idx](ctx, req, resp)
if err != nil {
return
}
}
}
funcs, ex := s.beforeMiddlewarePath[path]
if ex {
if len(funcs) != 0 {
for idx := range funcs {
err := funcs[idx](ctx, req, resp)
if err != nil {
return
}
}
}
}
// 核心调用
callErr := ser.call(ctx, method, reflect.ValueOf(req), reflect.ValueOf(resp))
if callErr != nil {
ctx.SetValue("RespError", callErr.Error())
}
// 后置middleware
if len(s.afterMiddleware) != 0 {
for idx := range s.afterMiddleware {
err := s.afterMiddleware[idx](ctx, req, resp)
if err != nil {
return
}
}
}
funcs, ex = s.afterMiddlewarePath[path]
if ex {
if len(funcs) != 0 {
for idx := range funcs {
err := funcs[idx](ctx, req, resp)
if err != nil {
return
}
}
}
}
// response
// 1. 序列化
var respBody []byte
respBody, err = serialization.Encode(resp)
var metaDataByte []byte
metaDataByte, _ = serialization.Encode(ctx.GetMetaData())
// 2. 加密
metaDataByte, err = cryptology.AESEncrypt(aesKey, metaDataByte)
if err != nil {
return
}
respBody, err = cryptology.AESEncrypt(aesKey, respBody)
if err != nil {
return
}
// 3. 压缩
metaDataByte, err = compressor.Zip(metaDataByte)
if err != nil {
return
}
respBody, err = compressor.Zip(respBody)
if err != nil {
return
}
// 4. 打包
_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, respBody)
if err != nil {
return
}
// 5. 回写
err = xChannel.Send(message)
if err != nil {
return
}
}
调用具体方法
func (s *service) call(ctx *light.Context, mType *methodType, request, response reflect.Value) (err error) {
// recover 捕获堆栈消息
defer func() {
if r := recover(); r != nil {
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
buf = buf[:n]
err = fmt.Errorf("[painc service internal error]: %v, method: %s, argv: %+v, stack: %s",
r, mType.method.Name, request.Interface(), buf)
log.Println(err)
}
}()
fn := mType.method.Func
returnValue := fn.Call([]reflect.Value{s.refVal, reflect.ValueOf(ctx), request, response})
errInterface := returnValue[0].Interface()
if errInterface != nil {
return errInterface.(error)
}
return nil
}
这里就完成了服务端的基础逻辑了
以上是关于用GO写一个RPC框架 s04 (编写服务端核心)的主要内容,如果未能解决你的问题,请参考以下文章