用GO写一个RPC框架 s05 (客户端编写)
Posted dollarkillerx
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用GO写一个RPC框架 s05 (客户端编写)相关的知识,希望对你有一定的参考价值。
前言
前面几章我们完成了服务端的编写,现在开始编写客户端。
https://github.com/dollarkill...
Client
// Client is the entry point of the RPC client. It only holds the Options
// that are shared by every Connect created from it.
type Client struct {
options *Options
}
// NewClient builds a Client. It starts from the default options, installs
// discover as the service-discovery plugin and then applies every supplied
// Option in order, so an Option may override any default (including the
// discovery plugin itself).
func NewClient(discover discovery.Discovery, options ...Option) *Client {
	opts := defaultOptions()
	opts.Discovery = discover
	for i := range options {
		options[i](opts)
	}
	return &Client{options: opts}
}
option
// Options holds the client configuration shared by all connections of a Client.
type Options struct {
Discovery discovery.Discovery // service-discovery plugin
loadBalancing load_banlancing.LoadBalancing // load-balancing plugin
serializationType codes.SerializationType // serialization plugin
compressorType codes.CompressorType // compression plugin
pool int // connection pool size
// cryptology selects the encryption scheme (defaults to cryptology.AES).
cryptology cryptology.Cryptology
// rsaPublicKey is used during the handshake to RSA-encrypt the AUTH token
// and the per-connection AES key.
rsaPublicKey []byte
// writeTimeout is the default deadline applied before writing to a connection.
writeTimeout time.Duration
// readTimeout is the deadline applied before each read of a response frame.
readTimeout time.Duration
// heartBeat is the interval between heartbeat frames.
heartBeat time.Duration
// Trace enables extra diagnostic logging.
Trace bool
AUTH string // AUTH TOKEN
}
// defaultOptions returns the baseline configuration: MsgPack serialization,
// Snappy compression, polling load balancing, AES encryption, a placeholder
// RSA public key, one-minute write timeout and heartbeat, three-minute read
// timeout, tracing off and an empty AUTH token. The pool size is four
// connections per CPU with a floor of 20.
func defaultOptions() *Options {
	opts := &Options{
		serializationType: codes.MsgPack,
		compressorType:    codes.Snappy,
		loadBalancing:     load_banlancing.NewPolling(),
		cryptology:        cryptology.AES,
		rsaPublicKey: []byte(`
-----BEGIN PUBLIC KEY-----
-----END PUBLIC KEY-----`),
		writeTimeout: time.Minute,
		readTimeout:  time.Minute * 3,
		heartBeat:    time.Minute,
		Trace:        false,
		AUTH:         "",
	}

	// Scale the pool with the machine, but never go below 20 connections.
	opts.pool = runtime.NumCPU() * 4
	if opts.pool < 20 {
		opts.pool = 20
	}
	return opts
}
每个具体的连接
// Connect is a handle for calling one named service through a pool of
// connections.
type Connect struct {
// Client is the owning Client; its options configure every pooled connection.
Client *Client
// pool holds the connections used by Call.
pool *connectPool
// close is created by NewConnect.
// NOTE(review): nothing in the visible code closes or reads it — confirm its use.
close chan struct{}
// serverName is the service name passed to discovery and sent with each request.
serverName string
}
// NewConnect creates a Connect bound to serverName and fills its connection
// pool. The Connect is returned together with any pool initialisation error,
// so callers must check err before using it.
func (c *Client) NewConnect(serverName string) (conn *Connect, err error) {
	conn = &Connect{
		Client:     c,
		serverName: serverName,
		close:      make(chan struct{}),
	}
	conn.pool, err = initPool(conn)
	return conn, err
}
初始化连接池
// initPool allocates a connectPool for c whose channel capacity equals the
// configured pool size and then fills it. The pool is returned even when
// filling it fails; the error reports that failure.
func initPool(c *Connect) (*connectPool, error) {
	cp := new(connectPool)
	cp.connect = c
	cp.pool = make(chan LightClient, c.Client.options.pool)

	err := cp.initPool()
	return cp, err
}
// initPool resolves the service through the discovery plugin, feeds the
// discovered hosts to the load balancer and then opens connections until the
// pool holds the configured number of them.
//
// It returns the discovery error, a "<name> server 404" error when no host
// is found, or the first connection error (with stack).
// NOTE(review): connections opened before a failing one stay in the pool and
// are not closed here.
func (c *connectPool) initPool() error {
	opts := c.connect.Client.options
	name := c.connect.serverName

	// Ask service discovery where the service lives.
	hosts, err := opts.Discovery.Discovery(name)
	if err != nil {
		return err
	}
	if len(hosts) == 0 {
		return errors.New(fmt.Sprintf("%s server 404", name))
	}

	// Prime the load balancer with the discovered hosts.
	opts.loadBalancing.InitBalancing(hosts)

	// Fill the pool.
	for opened := 0; opened < opts.pool; opened++ {
		client, err := newBaseClient(name, opts)
		if err != nil {
			return errors.WithStack(err)
		}
		c.pool <- client
	}
	return nil
}
// Get takes one connection out of the pool. It blocks until a connection is
// available or ctx is done; in the latter case it returns a
// "pool get timeout" error.
func (c *connectPool) Get(ctx context.Context) (LightClient, error) {
	select {
	case lc := <-c.pool:
		return lc, nil
	case <-ctx.Done():
		return nil, errors.New("pool get timeout")
	}
}
// Put returns a connection to the pool. A healthy connection (Error() == nil)
// goes straight back. A failed one is dropped, and a background goroutine
// retries once per second — rediscovering the service, re-initialising the
// load balancer and dialing — until a replacement connection has been put
// into the pool.
func (c *connectPool) Put(client LightClient) {
	if client.Error() != nil {
		go func() {
			fmt.Println("The server starts to restore")
			for {
				time.Sleep(time.Second)

				opts := c.connect.Client.options
				name := c.connect.serverName

				hosts, err := opts.Discovery.Discovery(name)
				if err != nil {
					log.Println(err)
					continue
				}
				if len(hosts) == 0 {
					log.Println(errors.New(fmt.Sprintf("%s server 404", name)))
					continue
				}
				opts.loadBalancing.InitBalancing(hosts)

				replacement, err := newBaseClient(name, opts)
				if err != nil {
					log.Println(err)
					continue
				}

				c.pool <- replacement
				fmt.Println("Service recovery success")
				return
			}
		}()
		return
	}

	c.pool <- client
}
Connect 调用具体服务
// Call invokes serviceMethod on the service this Connect is bound to.
//
// It waits up to six seconds for a pooled connection, stores the configured
// AUTH token in ctx under "Light_AUTH", performs the call on that connection
// and always hands the connection back to the pool afterwards. request is
// the value to send; response is the destination the reply is decoded into.
// Any error from acquiring the connection or from the call itself is
// returned with a stack attached.
func (c *Connect) Call(ctx *light.Context, serviceMethod string, request interface{}, response interface{}) error {
	// Take a connection from the pool, bounded by a six-second wait.
	poolCtx, cancel := context.WithTimeout(context.Background(), time.Second*6)
	client, err := c.pool.Get(poolCtx)
	// The timeout only guards the pool wait; release its timer right away
	// instead of leaking it until the deadline fires.
	cancel()
	if err != nil {
		return errors.WithStack(err)
	}
	// Always give the connection back once the call is finished.
	defer c.pool.Put(client)

	// Attach the auth token.
	ctx.SetValue("Light_AUTH", c.Client.options.AUTH)

	// Perform the actual call.
	if err := client.Call(ctx, serviceMethod, request, response); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
调用核心 重点
复习 s03 协议设计
/**
协议设计
起始符 : 版本号 : crc32校验 : magicNumberSize: serverNameSize : serverMethodSize : metaDataSize : payloadSize: respType : compressorType : serializationType : magicNumber : serverName : serverMethod : metaData : payload
0x05 : 0x01 : 4 : 4 : 4 : 4 : 4 : 4 : 1 : 1 : 1 : xxx : xxx : xxx : xxx : xxx
*/
注意:每一个请求都有一个 magicNumber,它就是该请求的请求 ID,客户端用它把响应与对应的请求匹配起来。
单个链接定义
// BaseClient is a single connection to a server together with the state
// needed to match responses to in-flight requests.
type BaseClient struct {
conn net.Conn
options *Options
serverName string
// aesKey is the per-connection key sent to the server during the handshake
// and used to encrypt requests and decrypt responses.
aesKey []byte
serialization codes.Serialization
compressor codes.Compressor
// respInterMap maps a request's magic number to its pending response slot.
respInterMap map[string]*respMessage
respInterRM sync.RWMutex // guards respInterMap
writeMu sync.Mutex // serializes writes to conn
err error // last I/O error recorded on this connection
close chan struct{} // closed when reading fails; stops the heartbeat loop
}
// respMessage is the bookkeeping entry for one in-flight request.
type respMessage struct {
// response is the destination the reply payload is decoded into.
response interface{}
// ctx receives the metadata carried by the reply.
ctx *light.Context
// respChan is sent the error on failure and closed on success.
respChan chan error
}
初始化单个链接
// newBaseClient dials one server chosen by the load balancer, performs the
// handshake and starts the connection's heartbeat and response-reader
// goroutines.
//
// The handshake sends the AUTH token and a freshly generated AES key, both
// RSA-encrypted with options.rsaPublicKey, and then reads the server's
// handshake reply; a non-empty error field in that reply is returned as the
// error. On every failure after the dial the connection is closed before
// returning, so a failed attempt never leaks a socket.
func newBaseClient(serverName string, options *Options) (*BaseClient, error) {
	// Pick a server instance.
	service, err := options.loadBalancing.GetService()
	if err != nil {
		return nil, err
	}
	con, err := transport.Client.Gen(service.Protocol, service.Addr)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// From here on the connection is open: close it on any early return.
	ready := false
	defer func() {
		if !ready {
			con.Close()
		}
	}()

	serialization, ok := codes.SerializationManager.Get(options.serializationType)
	if !ok {
		return nil, pkg.ErrSerialization404
	}
	compressor, ok := codes.CompressorManager.Get(options.compressorType)
	if !ok {
		return nil, pkg.ErrCompressor404
	}

	// Handshake: encrypt the auth token.
	encrypt, err := cryptology.RsaEncrypt([]byte(options.AUTH), options.rsaPublicKey)
	if err != nil {
		return nil, err
	}
	// Key exchange: generate the session key and encrypt it for the server.
	aesKey := []byte(strings.ReplaceAll(uuid.New().String(), "-", ""))
	aesKey2, err := cryptology.RsaEncrypt(aesKey, options.rsaPublicKey)
	if err != nil {
		return nil, err
	}

	handshake := protocol.EncodeHandshake(aesKey2, encrypt, []byte(""))
	if _, err = con.Write(handshake); err != nil {
		return nil, err
	}

	// Read the server's verdict.
	hsk := &protocol.Handshake{}
	if err = hsk.Handshake(con); err != nil {
		return nil, err
	}
	if len(hsk.Error) > 0 {
		return nil, errors.New(string(hsk.Error))
	}

	bc := &BaseClient{
		serverName:    serverName,
		conn:          con,
		options:       options,
		serialization: serialization,
		compressor:    compressor,
		respInterMap:  map[string]*respMessage{},
		aesKey:        aesKey,
		close:         make(chan struct{}),
	}
	ready = true

	go bc.heartBeat()             // heartbeat loop
	go bc.processMessageManager() // response reader loop
	return bc, nil
}
heartBeat 心跳服务
// heartBeat sends a heartbeat frame every options.heartBeat until the close
// channel is closed or a write fails. A write failure is recorded in b.err.
// A frame that cannot be encoded is logged and skipped; the loop keeps
// running.
//
// NOTE(review): SetDeadline moves the connection's read deadline as well as
// the write deadline, so it also affects the reader goroutine — confirm this
// is intended.
func (b *BaseClient) heartBeat() {
	defer fmt.Println("heartBeat Close")

	for {
		// Wait for the next tick or for shutdown.
		select {
		case <-b.close:
			return
		case <-time.After(b.options.heartBeat):
		}

		_, frame, err := protocol.EncodeMessage("x", []byte(""), []byte(""), []byte(""), byte(protocol.HeartBeat), byte(b.options.compressorType), byte(b.options.serializationType), []byte(""))
		if err != nil {
			log.Println(err)
			continue
		}

		deadline := time.Now().Add(b.options.writeTimeout)
		b.conn.SetDeadline(deadline)
		b.conn.SetWriteDeadline(deadline)

		b.writeMu.Lock()
		_, err = b.conn.Write(frame)
		b.writeMu.Unlock()
		if err != nil {
			b.err = err
			return
		}
	}
}
processMessageManager 返回消息的处理服务 (注意这里可以并发的来)
// processMessageManager is the connection's reader loop. It repeatedly
// processes one incoming frame and reports the outcome to the waiting
// caller: the error is sent on the request's channel on failure, and the
// channel is closed on success. Frames that belong to no request are
// skipped, and an error that belongs to no request ends the loop.
func (b *BaseClient) processMessageManager() {
	defer fmt.Println("processMessageManager Close")

	for {
		magic, respChan, err := b.processMessage()

		// Nothing to deliver: either carry on or stop, depending on err.
		if magic == "" {
			if err != nil {
				return
			}
			continue
		}
		if respChan == nil {
			continue
		}

		if err != nil {
			respChan <- err
		} else {
			close(respChan)
		}
	}
}
// processMessage reads one frame from the connection and handles it.
//
// Returns:
//   - ("", nil, err) when reading the frame fails; b.err is set and b.close
//     is closed, which marks the connection as dead.
//   - ("", nil, nil) for heartbeat frames and for replies whose magic number
//     has no pending request (the request is no longer waiting).
//   - (magic, respChan, err) for a reply to a pending request. err is nil
//     when the payload was decoded into the caller's response value, and
//     non-nil when the reply could not be decompressed, decrypted or decoded,
//     or when the server reported an error through the "RespError" metadata.
//
// Failures that concern a single reply are attributed to that request
// instead of being returned with an empty magic number, so the waiting
// caller is notified and the reader loop keeps serving the connection.
func (b *BaseClient) processMessage() (magic string, respChan chan error, err error) {
	// Read the next frame, bounded by the read timeout.
	b.conn.SetReadDeadline(time.Now().Add(b.options.readTimeout))
	msg, err := protocol.NewProtocol().IODecode(b.conn)
	if err != nil {
		b.err = err
		close(b.close)
		return "", nil, err
	}

	// Heartbeat replies carry nothing to deliver.
	if msg.Header.RespType == byte(protocol.HeartBeat) {
		if b.options.Trace {
			log.Println("is HeartBeat")
		}
		return "", nil, nil
	}

	b.respInterRM.RLock()
	message, ok := b.respInterMap[msg.MagicNumber]
	b.respInterRM.RUnlock()
	if !ok { // no pending request: the reply is stale
		if b.options.Trace {
			log.Println("Not Ex", msg.MagicNumber)
		}
		return "", nil, nil
	}

	// Every failure below belongs to this one request.
	magic, respChan = msg.MagicNumber, message.respChan

	comp, ok := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))
	if !ok {
		return magic, respChan, errors.New("response uses an unknown compressor type")
	}

	// 1. Decompress.
	if msg.MetaData, err = comp.Unzip(msg.MetaData); err != nil {
		return magic, respChan, err
	}
	if msg.Payload, err = comp.Unzip(msg.Payload); err != nil {
		return magic, respChan, err
	}

	// 2. Decrypt. A decrypt error is tolerated only when there was nothing
	// to decrypt; the length is taken before the field is overwritten.
	encMetaLen := len(msg.MetaData)
	msg.MetaData, err = cryptology.AESDecrypt(b.aesKey, msg.MetaData)
	if err != nil {
		if encMetaLen != 0 {
			return magic, respChan, err
		}
		msg.MetaData = []byte("")
	}
	encPayloadLen := len(msg.Payload)
	msg.Payload, err = cryptology.AESDecrypt(b.aesKey, msg.Payload)
	if err != nil {
		if encPayloadLen != 0 {
			return magic, respChan, err
		}
		msg.Payload = []byte("")
	}

	// 3. Decode the metadata and surface a server-side error, if any.
	mtData := make(map[string]string)
	if err = b.serialization.Decode(msg.MetaData, &mtData); err != nil {
		return magic, respChan, err
	}
	message.ctx.SetMetaData(mtData)
	if respErr := message.ctx.Value("RespError"); respErr != "" {
		return magic, respChan, errors.New(respErr)
	}

	// 4. Decode the payload into the caller's response value.
	return magic, respChan, b.serialization.Decode(msg.Payload, message.response)
}
服务调用
// call encodes one request and writes it to the connection.
//
// The metadata from ctx and the request value are serialized, AES-encrypted
// with the connection key and, when the encrypted metadata size lies
// strictly between compressorMin and compressorMax, compressed; otherwise
// the frame is marked as raw data. Before writing, the pending-response
// entry (response destination, ctx and respChan) is registered under the
// frame's magic number so the reader loop can deliver the reply.
//
// It returns the magic number of the request. On a write failure the error
// is recorded in b.err, the pending entry is removed again (the caller gets
// an empty magic number and could not remove it itself) and the error is
// returned with a stack attached.
//
// NOTE(review): the compression decision looks only at the metadata size,
// yet it applies to the request payload too — confirm this is intended.
func (b *BaseClient) call(ctx *light.Context, serviceMethod string, request interface{}, response interface{}, respChan chan error) (magic string, err error) {
	metaData := ctx.GetMetaData()

	// 1. Build the request.
	// 1.1 Serialize.
	serviceNameByte := []byte(b.serverName)
	serviceMethodByte := []byte(serviceMethod)

	var metaDataBytes, requestBytes []byte
	if metaDataBytes, err = b.serialization.Encode(metaData); err != nil {
		return "", err
	}
	if requestBytes, err = b.serialization.Encode(request); err != nil {
		return "", err
	}

	// 1.2 Encrypt.
	if metaDataBytes, err = cryptology.AESEncrypt(b.aesKey, metaDataBytes); err != nil {
		return "", err
	}
	if requestBytes, err = cryptology.AESEncrypt(b.aesKey, requestBytes); err != nil {
		return "", err
	}

	// 1.3 Compress, or mark the frame as raw data.
	compressorType := b.options.compressorType
	if len(metaDataBytes) > compressorMin && len(metaDataBytes) < compressorMax {
		if metaDataBytes, err = b.compressor.Zip(metaDataBytes); err != nil {
			return "", err
		}
		if requestBytes, err = b.compressor.Zip(requestBytes); err != nil {
			return "", err
		}
	} else {
		compressorType = codes.RawData
	}

	// 1.4 Frame the message.
	magic, message, err := protocol.EncodeMessage("", serviceNameByte, serviceMethodByte, metaDataBytes, byte(protocol.Request), byte(compressorType), byte(b.options.serializationType), requestBytes)
	if err != nil {
		return "", err
	}

	// 2. Send. A timeout carried by ctx wins over the configured default.
	if b.options.writeTimeout > 0 {
		timeout := ctx.GetTimeout()
		if timeout <= 0 {
			timeout = b.options.writeTimeout
		}
		deadline := time.Now().Add(timeout)
		b.conn.SetDeadline(deadline)
		b.conn.SetWriteDeadline(deadline)
	}

	// Register the pending response before writing so that a fast reply
	// always finds its entry.
	b.respInterRM.Lock()
	b.respInterMap[magic] = &respMessage{
		response: response,
		ctx:      ctx,
		respChan: respChan,
	}
	b.respInterRM.Unlock()

	// Writes from concurrent goroutines must not interleave on the wire.
	b.writeMu.Lock()
	_, err = b.conn.Write(message)
	b.writeMu.Unlock()
	if err != nil {
		// The request never reached the server: drop its pending entry so
		// it does not linger in the map.
		b.respInterRM.Lock()
		delete(b.respInterMap, magic)
		b.respInterRM.Unlock()

		if b.options.Trace {
			log.Println(err)
		}
		b.err = err
		return "", errors.WithStack(err)
	}
	return magic, nil
}
以上是关于用GO写一个RPC框架 s05 (客户端编写)的主要内容,如果未能解决你的问题,请参考以下文章