go-sql-driver源码分析

Posted golang算法架构leetcode技术php

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了go-sql-driver源码分析相关的知识,希望对你有一定的参考价值。

https://github.com/go-sql-driver 实现了基本的sql操作

https://github.com/jmoiron/sqlx  实现了增强版的复杂sql操作


1)database/sql 

定义了对数据库的一系列操作,只是定义了一些列的规范,但是没有提供任何官方的数据库驱动,所以我们需要第三方数据库驱动

 

2)_ "github.com/go-sql-driver/mysql"

 

3)DB: 是一个数据库(操作)句柄,代表一个具有多个底层连接的连接池,它可以安全的被多个go程同时使用,sql包会自动创建和释放连接,

root jianan test

 

4)数据库操作

   Exec: 增删改

   Query: 查询多条数据结果

   QueryRow: 最多返回一行

   Next: 一行行的从结果集中扫描

   Prepare: 预编译,预处理


go连接mysql为什么需要 import _ "github.com/go-sql-driver/mysql"


go中import _的作用只执行引入包的init函数,那么go-sql-driver/mysql 的init函数又做了什么,在database/sql 中的drivers map[string]driver.Driver注册引擎 mysql => MySQLDriver{}


// go-sql-driver/mysql/driver.go

func init() {

sql.Register("mysql", &MySQLDriver{})

}

SetMaxIdleConns/SetMaxOpenConns/SetConnMaxLifetime 设置的值有什么用呢?


db.maxLifetime 连接从创建开始存活的时间,mysql默认tcp连接的超时时间 8h

db.maxOpen 打开的连接最大数量,超过该数量后,query会被阻塞等待可用连接

db.maxIdle 空闲池维持的最大连接数量

sql.Open为什么只需要一次调用即可?


加载驱动程序 go-sql-driver/mysql

初始化DB数据结构

构造创建连接channel/重置连接channel

这里并有实际去和数据建立连接,也没有对数据库连接的参数校验,只是初始化了DB结构体,DB结构体中已经包含了连接池 freeConn []*driverConn, 没有必要多次调用open,全局维护一个DB即可,如果需要验证 账户密码/网络是否能通信,需要调用ping来检测。


type DB struct {

waitDuration int64 // 统计每次链接的等待时间


connector driver.Connector


// 已经关闭的连接数量

numClosed uint64


mu sync.Mutex


freeConn     []*driverConn               // 空闲池

connRequests map[uint64]chan connRequest // 等待连接队列,当前连接达到maxOpen时,就无法不在创建连接,创建一个请求的channel入队列并等待并阻塞当前goroutine

nextRequest  uint64                      // connRequests[] 指向下一个坑位

numOpen      int                         // 正在使用的连接数量


// channel 通过此channel来接受建立新的session连接

// DB.connectionOpener for select 接收channel

openerCh          chan struct{}

resetterCh        chan *driverConn       // 重置session以便其他query复用当前session

closed            bool                   // 标记是否已经关闭链接

dep               map[finalCloser]depSet //

lastPut           map[*driverConn]string // 调试用

maxIdle           int                    // 维持的最大空闲连接数,小于等于0使用defaultMaxIdleConns(2)

maxOpen           int                    // 维持的最大连接数,0无限制

maxLifetime       time.Duration          // 连接可以重用的最长时间

cleanerCh         chan struct{}

waitCount         int64 // 等待连接的总数

maxIdleClosed     int64 // 由于空闲而连接的总数

maxLifetimeClosed int64 // 连接存活时间超过maxLifetime而关闭的时间


stop func() // 取消接收 建立连接channel(openerCh)/重置session channel(resetterCh)

}

如果获取数据连接池的状态?当前的连接总数,SetMaxOpenConns/SetMaxIdleConns的数量是否合适。

使用db.Stats可以查看当前连接池的一些状态,这边返回了一个DBStats结构体,一起看下:

type DBStats struct {

MaxOpenConnections int // 打开的最大连接数,包含已经关闭了的连接


// 连接池状态

OpenConnections int // 当前建立连接的数量,包括正在使用和空闲的数量

InUse           int // 正在使用的连接数

Idle            int // 已建立连接,空闲中的连接数


// 统计

WaitCount         int64         // 等待连接的总数,这里需要着重关注一下

WaitDuration      time.Duration // query持续等待连接的总时长

MaxIdleClosed     int64         // 达到SetMaxIdleConns,而关闭的连接数量

MaxLifetimeClosed int64         // 达到SetConnMaxLifetime. 而关闭的连接数量

}

WaitCount 着重关注一下,看下是不是有慢查询阻塞了可用的连接数量

MaxIdleConns设置的过小,而请求数过多导致,会导致MaxIdleClosed比较大,WaitCount也会比较大

注意:db.Stats不要调用过于频繁,它会对整个DB连接池加锁,过于频繁有一定开销。


DB.Ping()做了哪些事情?DB是如何从连接池中获取一个可用的连接的?


获取连接,可以复用连接 (cachedOrNewConn)

获取一个可用的连接 driverConn

复用空闲池中freeConn已有的连接

从空闲池中移除第一个连接conn

这期间都是有锁的,freeConn是一个切片,是并发不安全的

该连接是否在生命周期内 lifetime – db.SetLifeTime设置的tcp连接存活时间

本连接过了生命周期,返回 driver.ErrBadConn

返回此conn

db.SetMaxOpenConns设置了最大打开的连接数,且当前打开的连接已经达到最大数

创建一个等待请求,放入等待队列,阻塞当前goroutine

等待超时使用context取消,或者等待直到获取可用的连接

ctx取消后还是获取到了连接,放回空闲池

获取到可用连接,统计本次阻塞时长,可以注意到如果DB.Stats().WaitDuration大了以后问题就很严重了

如果本连接过了生命周期,返回 driver.ErrBadConn

返回此Conn

封装driver返回的connect到driverConn

标记driverConn inUse使用中

记录连接创建时间createdAt

db指向连接池

真正底层的连接

源码小细节:


移除切片第一个元素:copy(db.freeConn, db.freeConn[1:])


删除map中的元素: delete(db.connRequests, reqKey)


map[key]interface: 也可以使用channel做为key


生命周期判断:createdAt.Add(lifetime).Before(time.now())


参观一下context的用法


select {

// 这里留取消的口

case <-ctx.Done():

  select {

  // 之前我们分析过select尽量不要加default,单那是for select结构,会造成自旋锁,长期占用M不释放

  // 如果这里不用default它就阻塞一直等待req channel中有connect,这里并不是为了等待,只是为了清理一下channel的connect,防止孤儿connect

  default:

case ret, ok := <-req:

//...

  }

  return nil, ctx.Err()

case ret, ok := <-req:

  // ...

  return ret.conn, ret.err

}


// 我们在外面如何控制超时呢?

fun bar(){

  t := time.After(10)

  ctx := context.Background()

  res := make(chan struct{})

  go func() {

    ci, _ := conn(ctx)

    res<-ci

  }()

  select {

  case <-t:

    ctx.Done()

  case ci := <-res:

    if ci != nil {


    }

  }

}

Buffer.go

buffer 是一个用于给 数据库连接 (net.Conn) 进行缓冲的一个数据结构,其结构为:

type buffer struct {
buf []byte // 缓冲池中的数据
nc net.Conn // 负责缓冲的数据库连接对象
idx int // 已读数据索引
length int // 缓冲池中未读数据的长度
timeout time.Duration // 数据库连接的超时设置
}

可以看到,因为  数据库连接 (net.Conn)  在通信的时候是 同步 的。而为了让其能够 同时 读/写 ,所以实现了 buffer 这个数据结构,通过该 buffer 进行数据缓冲还能实现 零拷贝 ( zero-copy-ish ) 。

其函数分别有:

  • newBuffer(nc net.Conn) buffer :创建并返回一个 buffer

  • (*buffer) readNext(need int) ([]byte, error) :读取并返回未读数据的 need 位,如果 need 大于 bufferlength ,就会调用 fill(need int) errorbuffer进行 扩容 。

  • (*buffer) fill(need int) error :对 buffer 进行 (need/defaultBufSize) 的倍数扩容,并在 timeout 时间结束前从 buffer.nc 中读取 need 长度的数据。

  • (*buffer) takeBuffer(length int) []byte :读取 bufferlength 长度的数据(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil 。如果需要读取的长度大于 buffer 的容量,则会进行扩容。

  • (*buffer) takeSmallBuffer(length int) []byte :读取保证不超过 defaultBufSize 长度的数据的快捷函数(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil

  • (*buffer) takeCompleteBuffer() []byte :读取全部的 buffer 数据(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil

Collations.go

collations 包含了 MySQL 所有支持的 字符集 格式,并支持通过 COLLATION_NAME 返回其字符集 ID

如果需要查询 MySQL 支持的 字符集 格式,可以使用 SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS 语句获取。

Dsn.go

DSN数据源名称 (Data Source Name)  ,是 驱动程序连接数据库的变量信息 ,简而言之就是根据你连接的不同数据库使用对应的连接信息。

通常,数据库的连接配置就是在这里定义的:

// Config 基本的数据库连接信息
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout

AllowAllFiles bool // 允许文件使用 LOAD DATA LOCAL INFILE 导入数据库
AllowCleartextPasswords bool // 支持明文密码客户端
AllowOldPasswords bool // 允许使用不可靠的旧密码
ClientFoundRows bool // 返回匹配的行数而不是受影响的行数
ColumnsWithAlias bool // 将表名前置在列名
InterpolateParams bool // 将占位符插入查询的SQL字符串
MultiStatements bool // 允许一条语句多次查询
ParseTime bool // 格式化时间值为 time.Time 变量
Strict bool // 将 warnings 返回 errors
}

这都是一些常见的配置项,就此略过。

该文件有两个公共函数支持 ConfigDSN 之间转换。

  • (*Config)FormatDSN() string

  • ParseDSN(dsn string) (*Config, error)

Errors.go

errors 定义了 LoggerMySQLErrorMySQLWarning 等数据结构。

Logger

复用了 Go 原生的 log 包,并将其中的输出重定向至控制台的 标准错误 。

type Logger interface {
Print(v ...interface{})
}

var errLog = Logger(log.New(os.Stderr, "[mysql]", log.Ldate|log.Ltime|log.Lshortfile))

func SetLogger(logger Logger) error { // 当然,你也可以使用自定义的错误 Logger
if logger == nil {
return errors.New("logger is nil")
}
errLog =logger
return nil
}

MySQLError

MySQLError 则简单定义了 MySQL 输出的错误的结构。

type MySQLError struct {
Number uint16
Message string
}

MySQLWarning

MySQLWarning 则有些不一样,它需要从 MySQL 中进行一次 查询 ,以获取所有的警告信息,所以该包也定义了 MySQLWarningslice 结构。

type MySQLWarning struct {
Level string
Code string
Message string
}

type MySQLWarnings []MySQLWarning

func (mc *mysqlConn) getWarnings() (err error) {
rows, err := mc.Query("SHOW WARNINGS", nil)
// handle err

// initzation MySQLWarnings

for {
err = rows.Next(values)
switch err {
case nil:
warning := MySQLWarning{}

if raw, ok := values[0].([]byte); ok {
warning.Level = string(raw)
}else {
warning.Level = fmt.Sprintf("%s", values[0])
}

if raw, ok := values[1].([]byte); ok {
warning.Code = string(raw)
} else {
warning.Code = fmt.Sprintf("%s", values[1])
}

if raw, ok := values[2].([]byte); ok {
warning.Message = string(raw)
} else {
warning.Message = fmt.Sprintf("%s", values[0])
}

warnings = append(warnings, warning)
}

case io.EOF:
return warnings

default:
rows.Close() // 值得注意的是,如果该函数没有 case 运行 default ,该 rows 就不会被默认关闭,就会占用连接池中的一个连接,是否应该使用 `defer rows.Close() ` 避免该情况?
return
}
}

Infile.go

前面也有提到 MySQL 在导入大型文件的时候,需要使用 LOAD DATA LOCAL INFILE 的形式进行导入,而该 infile.go 就是实现该协议的代码。

本包在实现的 LOAD DATA 的时候提供了两种方式进行导入:

  • 最常见的,使用服务器的文件路径,如 /data/students.csv ,下文命名其为 文件路径注册器

  • 最通用的,使用实现了 io.Reader 接口的数据结构,通过返回该数据结构的数据进行导入,如 bytes os.file 等,下文命名其为 Reader 接口注册器

在实现该功能的时候,注册器 的实现是用名字作为 Key 的 Map ,为了避免 Map读写竞态 ,需要对其配置一个读写锁。

var (
fileRegister map[string]bool // 文件路径注册器
fileRegisterLock sync.RWMutex // 文件路径注册器读写锁
readerRegister map[string]func() io.Reader // Reader 接口注册器
readerRegisterLock sync.RWMutex // Reader 接口注册器读写锁
)

除了对两个注册器的 注册 以及 注销 函数,还有一个需要分析的一个函数:

(mc *mysqlConn) handleInFileRequest(name string) (err error)

通过传入 文件路径 或者 Reader 名称 就可以将数据发往 MySQL 了。

func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
packSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
if mc.maxWriteSize < packSize { // 设置发往 MySQL 的数据块大小
packSize = mc.maxWriteSize
}

// 获取 文件 或 Reader 的数据,并将其赋值到 rdr 中
// var rdr io.Reader

// send context packets
if err != nil {
data := make([]byte, 4+packetSize) // 需要留 4 个 byte 给协议使用
var n int
for err == nil {
n, err = rdr.Read(data[4:]) // 将数据存入 data 的 [4:] 中
if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { // 将 data 数据发往 MySQL
return ioErr
}
}
}
if err == io.EOF { // rdr 中的数据读完了
err = nil
}
}

// send empty packet (termination)
if data == nil {
data = make([]byte, 4)
}
if ioErr := mc.writePacket(data[:4]); ioErr != nil { // 告诉 MySQL 文件发送完毕
return ioErr
}

// read OK packet
if err == nil { // 一切正常结束
return mc.readResultOK()
}

mc.readPacket() // 如果中途出错,将错误信息读取到 mysqlConn 中,并返回该错误
return err
}

到此,infile.go 的实现已经整理完毕了,可以看到, 作者 在实现这个功能的时候还是做了一些优化的,比如 map Lazy initsend packet size limited 等。而我们通过分析规范的源码包,能够提升自己的编码水平。

Packets.go

接下来就要深入到 MySQL 的通信协议中了,官方的 通信协议文档 非常齐全,我在这里只将一些基础的,我后面分析源码会用到的协议分析下,如果有兴趣,可以到官方文档处进行查阅。

Protocol Basics

基础数据类型

MySQL 通信的基本数据类型有两种, IntegerString

  • Integer : 分别有 1, 2, 3, 4, 8 个字节长度的类型,使用小端传输。

  • String : 分别有 固定长度字符串(协议规定),NULL结尾字符串(长度不固定),长度编码字符串(长度不固定)。

报文协议

报文分为 消息头 以及 消息体,而 消息头 由 3 字节的 消息长度 以及 1 字节的 序号  sequence (新客户端由 0 开始)组成,消息体 则由 消息长度 的字节组成。

  • 3 字节的 消息长度 最大值为 0xFFFFFF ,即为 16 MB - 1 byte ,这就意味着,如果整个消息(不包括消息头)的长度大于 16MB - 1byte - 4byte 大小时,消息就会被分包。

  • 1 字节的 序号 在每次新的客户端发起请求时,以 0 开始,依次递增 1 ,如果消息需要分包, 序号 会随着分包的数量递增。而在一次应答中, 客户端会校验服务器 返回序号 是否与 发送序号 一致,如果不一致,则返回错误异常。

协议类型

  • handshake : 发起连接

  • auth : 登录权限校验

  • ok | error : 返回结果状态 *

  • ok : 首字节为 0 (0x00

  • error : 首字节为 255 (0xff

  • resultset : 结果集

  • header

  • field

  • eof

  • row

  • command package : 命令

在整个 MySQL 发起交互的过程如下图所示:

mysql connect

在了解这些 MySQL 基础协议知识后,我们再来看 packages.go 的源码就轻松多了。

源码

先来看看 readPacket ,结合上面的知识点应该非常好理解。

func (mc *mysqlConn) readPacket() ([]byte, error) {
var payload []byte
for { // for 循环是为了读取有可能分片的数据
// Read package header
data, err := mc.buf.readNext(4) // 从 buffer 缓冲器中读取 4 字节的 header
if err != nil { // 如果读取发生异常,则关闭连接,并返回一个错误连接的异常
errLog.Print(err)
mc.Close()
return nil, driver.ErrBadConn
}

// Packet Length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) // 读取 3 字节的消息长度

if pktLen < 1 {
// 如上所示,关闭连接,并返回一个错误连接的异常
}

// Check Packet Sync [8 bit]
if data[3] != mc.sequence { // 判断服务端返回的序号是否与客户端一致
if data[3] > mc.sequence {
return nil, ErrPktSyncMul // 如果服务端返回序号大于客户端的序号,则有可能是在一次请求中做了多次操作
}
return nil, ErrPktSync // 返回序号不一致错误
}
mc.sequence++ // 本次序号匹配相符,为了匹配下一次请求,先将序号自增1

data, err := mc.buf.readNext(pktLen) // 读取 消息长度 的数据
if err != nil {
// 如上所示,关闭连接,并返回一个错误连接的异常
}

isLastPacket := (pktLen < maxPacketSize) // 如果是最后一个数据包,必然小于 maxPacketSize (16MB - 1byte)

// Zero allocations for non-splitting packets
if isLastPacket && payload == nil { // 无分包情况,立即返回
return data, nil
}

payload = append(payload, data...)

if isLastPacket { // 如果是最后一个包,读取完毕后返回
return payload, nil
}

// 还有未读数据,开始下一次循环
}
}

下面来看下结合 握手报文协议 来看下客户端向服务端发起请求的 readInitPacket

go-sql-driver源码分析

mysql handshack protocol

func (mc *mysqlConn) readInitPacket() ([]byte, error) {
data, err := mc.readPacket() // 调用上面的函数读取服务端返回的数据
if err != nil {
return nil, err
}

if data[0] == iERR { // iERR = 0xff 消息体的第一个字节返回 0xff ,则意味着 error package
return nil, mc.handleErrorPacket(data)
}

// protocol version [1 byte]
if data[0] < minProtocolVersion { // 判断是否是兼容的协议版本
return nil, fmt.Errorf(
"unsupported protocol version %d. Version %d or higher is required",
data[0],
minProtocolVersion,
)
}

// server version [null terminated string]
// connection id [4 bytes]
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // 读取 NULL (0x00)为结尾的字符串,跳过服务器线程 ID

// first part of the password cipher [8 bytes]
cipher := data[pos : pos+8] // 获取挑战随机数

// (filler) always 0x00 [1 byte]
pos += 8 + 1

// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) // 获取服务器权能标识
if mc.flags&clientProtocol41 == 0 { // 说明 MySQL 服务器不支持高于 41 版本的协议
return nil, ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { // 说明 MySQL 服务器需要 SSL 加密,但是客户端没有配置 SSL
return nil, ErrNoTLS
}
pos += 2 // 指针向后两位

if len(data) > pos {
// 指针跳过标志位
pos += 1 + 2 + 2 + 1 + 10

// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
//
// The web documentation is ambiguous about the length. However,
// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
// the 13th byte is " byte, terminating the second part of
// a scramble". So the second part of the password cipher is
// a NULL terminated string that's at least 13 bytes with the
// last byte being NULL.
//
// The official Python library uses the fixed length 12
// which seems to work but technically could have a hidden bug.
cipher = append(cipher, data[pos:pos+12]...)

// TODO: Verify string termination
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
// NUL otherwise
//
//if data[len(data)-1] == 0 {
// return
//}
//return ErrMalformPkt

// make a memory safe copy of the cipher slice
var b [20]byte
copy(b[:], cipher)
return b[:], nil
}

// make a memory safe copy of the cipher slice
var b [8]byte // 返回 8 字节的挑战随机数
copy(b[:], cipher)
return b[:], nil
}

除了上面解析的两个函数, packages.go 还有 initialisation process / result packages / prepared statements 等协议的 写入/读取 ,有兴趣的读者可以结合上面的知识点自行阅读。

Driver.go

接下来就要分析一些比较重要的代码了,比如接下来要讲的 driver.go ,它主要负责与 MySQL 数据库进行各种协议的连接,并返回该连接。可以说它才是最基础、最核心的功能。

不过首先我们需要看下 database/sql 包中的 Driver 接口需要如何实现:

// database/sql/driver/driver.go

// 数据库驱动
type Driver interface {
Open(name string) (Conn, error)
}

// ...

// 非并发安全数据库连接
type Conn interface {
// 返回一个绑定到 sql 的准备语句
Prepare(query string) (Stmt, error)

// 关闭该连接,并标记为不再使用,停止所有准备语句和事务
// 因为 database/sql 包维护了一个空闲的连接池,并且在空闲连接过多的时候会自动调用 Close ,所以驱动程序包不需要显式调用该函数
Close() error

// 开始并返回一个新的事务,而新的事务与旧的连接没有任何关联
Begin() (Tx, error)
}

根据 database/sql 提供的 Driver 接口, go-sql-driver/mysql 实现了自己的 数据库驱动 结构:

type MySQLDriver struct{}

func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc := &mysqlConn {
// set max value
}
mc.cfg = ParseDSN(dsn) // 通过解析 DSN 设置 MySQL 连接的配置

// set parseTime and strict
// ...

// connect to server
if dial, ok := dials[mc.cfg.Net]; ok { // 根据 地址 以及 协议类型,尝试连接上服务器
mc.netConn, err = dial(mc.cfg.Addr)
} else { // 连接服务器失败,尝试重连
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err := nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
if err != nil { // 重试失败,返回异常
return nil, err
}

// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.Conn); ok { // tcp 连接类型转换
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close() // 如果设置长连接失败,返回异常之前一定要记得将连接断开
mc.netConn = nil
return nil, err
}
}

mc.buff = newBuff(mc.netConn) // 生成一个带缓冲的 buffer,如上面 buffer.go 中所说

// set I/O timeout
// ...

// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket() // 发起数据库首次握手
if err != nil {
mc.cleanup() // 将当前 mysqlConn 对象销毁,后面我们会说这个函数
return nil, err
}

// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil { // 向数据库发送登录信息校验
mc.cleanup()
return nil, err
}
}

connection.go

终于要讲到这个包的核心数据结构 mysqlConn 了,可以说,驱动的所有功能几乎都围绕着这个数据结构,我们先来看看它的结构:

type mysqlConn struct {
buf buffer // buffer 缓冲器
netConn net.Conn // 网络连接
affectedRows uint64 // sql 执行成功影响行数
insertId uint64 // sql 添加成功最新的主键 ID
cfg *Config // dsn 中的 基础配置
maxPacketAllowed int // 允许的最大报文的字节长度,最大不能超过 (16MB - 1byte)
maxWriteSize int // 允许最大的写入字节长度,最大不能超过 (16MB - 1byte)
writeTimeout time.Duration // 执行 sql 的 超时时间
flags clientFlag // 客户端状态标识
status statusFlag // 服务端状态标识
sequence uint8 // 序号
parseTime bool // 是否格式化时间
strict bool // 是否使用严格模式
}

// driver.go
// 而创建一个 mysqlConn 连接需要通过 driver.go 中的 Open 函数,也说明 mysqlConn 实现了 driver.Conn 接口
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc := &mysqlConn{
// ...
}

// ...

return mc, nil
}

当一个新的客户端连接上服务器的时候 (三次握手结束,客户端进入 established 状态),需要先对 MySQL 服务器进行 会话的用户/系统环境变量 的设置。

// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.Params { // Params: map[string]string
switch param {
// Charset
case "charset": // 如果是字符集,则调用 SET NAMES 命令
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}

// System Vars
default: // 执行系统环境变量设置
err = mc.exec("SET " + param + "=" + val + "")
if err != nil {
return
}
}
}
}

conntion.go  还负责 事务 、预处理语句 、执行/查询 的管理,但是基本都是往 mysqlConn 中发送  command package ,如:

// Begin 开启事务
func (mc *mysqlConn) Begin() (driver.Tx, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
err := mc.exec("START TRANSACTION")
if err == nil {
return &mysqlTx{mc}, err // 返回成功开启的事务,重用之前的连接
}

return nil, err
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
err := mc.writeCommandPacketStr(comQurey, query)
if err != nil {
return err
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket() // 根据 data[0] 的值判断是否出错,如果没有错误,则返回消息体的长度
if err == nil && resLen > 0 { // 存在有效消息体
if err = mc.readUntilEOF(); err != nil { // 读取 columns
return err
}

err = mc.readUntilEOF() // 读取 rows
}

return err
}

我想 conntion.go 中最重要的一个函数应该是 cleanup ,它负责将 连接关闭 、 重置环境变量 等功能,但是该函数不能随意调用,它只有在 登录权限校验异常 时候才应该被调用,否则服务器在不知道客户端 被强行关闭 的情况下,依然会向该客户端发送消息,导致严重异常:

// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
// Makes cleanup idempotent 保证函数的幂等性
if mc.netConn != nil {
if err := mc.netConn.Close(); err != nil { // Close 会尝试发送 comQuit command 到服务器
errLog.Print(err)
}
mc.netConn = nil // 不管 Close 是否成功,必须将 netConn 清空
}
mc.cfg = nil
mc.buf.nc = nil // 缓冲器中的 netConn 也要关闭
}

Result.go

每当 MySQL 返回一个 OK状态报文 ,该报文协议会携带上本次执行的结果 affectedRows 以及 insertId ,而 result.go 就包含着一个数据结构,用于存储本次的执行结果。

type mysqlResult struct {
affectedRows int64
insertId int64
}

// 两个 getter
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil
}

func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil
}

接下来我们看下在 conntion.go 中是怎么生成 mysqlResult 对象的:

// connect.go
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {

// ...

err := exec(query)
if err == nil {
return &mysqlResult{ // 返回执行的结果
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, err
}

// exec 函数的解析可以返回上面 package.go 中浏览

// package.go
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket()
if err == nil {
switch data[0] {

case iOK:
return 0, mc.handleOkPacket(data) // 处理 OK 状态报文

// ...
}

func (mc *mysqlConn) handleOkPacket(data []byte) error {
var n, m int

// 0x00 [1 byte]

// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])

// Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])

// ...
}

Row.go

MySQL 执行 插入、更新、删除 等操作后,都会返回 Result ,但是 查询 返回的是 Rows ,我们先来看看 go-mysql-driver 驱动所实现的 接口  Rows 的接口描述:

// database/sql/driver/driver.go
// Rows 是执行查询返回的结果的 游标
type Rows interface {
// Columns 返回列的名称,从 slice 的长度可以判断列的长度
// 如果一个列的名称未知,则为该列返回一个空字符串
Columns() []string

// Close 关闭游标
Close() error

// Next 将下一行数据填充到 desc 切片中
// 如果读取的是最后一行数据,应该返回一个 io.EOF 错误
Next(desc []Value) error
}

type Value interface{} // Value is a value that drivers must be able to handle.

为什么我要说这是 go-mysql-driver 驱动所实现的 接口 Rows 呢?眼尖的同学应该已经看到了, Next 函数好像和我们平常见到的不一样啊!!

是的,因为我们平常使用的:

  • rows.Next()

  • rows.Scan(dest ...interface{}) error

等函数的对象 rows 并不是上面的 接口描述 Rows ,而是另一个封装的 同名数据结构 Rows ,它就在 database/sql 包中 :

// database/sql.go
type Rows struct {
dc *driverConn
releaseConn func(error)
rowsi driver.Rows // 接口描述的 Rows 藏在这!!!

// 忽略其他字段,因为我们不分析这个包...

// lastcols is only used in Scan, Next, and NextResultSet which are expected
// not not be called concurrently.
lastcols []driver.Value
}

我们跳过 database/sql  包中的 Rows 实现,其无非是提供了更多功能的一个结果集而已,让我们回到真正与数据库进行交互的 Rows 中进行源码分析。

go-sql-driver 实现的 mysqlRows 数据结构只实现了 Columns()Close() 两个行数,剩下的 Next(desc []driver.Value) 实现则交给了 MySQL 的两种结果集协议:

// rows.go

type mysqlField struct {
tableName string
name string
flags fieldFlag
fieldType byte
decimals byte
}

type mysqlRows struct {
mc *mysqlConn
columns []mysqlField
}

type binaryRows struct { // 二进制结果集协议
mysqlRows // 对于 Go 的 组合特性 应该不会陌生吧?
}

type textRows struct { // 文本结果集协议
mysqlRows
}

func (rows *mysqlRows) Columns() []string {
columns := make([]string, len(rows.columns))

// 将列名赋值到 columns ,如果有设置别名则赋值别名...

return columns
}

func (rows *mysqlRows) Close() error {
// 将连接里面的未读数据读完,然后将连接置空
}

// 接下来的 Next 函数实现就交由 binaryRows 和 textRows 了
func (rows *binaryRows) Next(desc []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return ErrInvalidConn
}

return rows.readRow(dest) // 读二进制协议结果集
}
return io.EOF
}

func (rows *testRows) Next(desc []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return ErrInvalidConn
}

return rows.readRow(dest) // 读取文本协议
}
return io.EOF
}

可以说,实现了 driver.Rows 接口的只有 binaryRowstestRows ,而他们里面的 readRow(desc) 实现由于都是和协议强相关的代码,就不再解析了。

我们跟着源码可以看到,使用 textRows 的场景在 getSystemVar 以及 Query 中,而使用 binaryRows 的场景在 statement 中,就是我们下一步需要解析的部分。

Statement.go

Prepared Statement ,即预处理语句,他有什么优势呢,为什么 MySQL 要加入它?

  • 执行性能更高:MySQL 会对 Prepared Statement 语句预先进行编译成模板,并将 占位符 替换 参数 的位置,这样如果频繁执行一条参数只有少量替换的语句时候,性能会得到大量提高。可能有同学会有疑问,为什么 MySQL 语句还需要编译?那么可以来参考下这篇 MySQL Prepare 原理 。

  • 传输协议更优:Prepare Statement 在传输时候使用的是 Binary Protocol ,比使用 Text Protocol 的查询具有 传输数据量更小 、 无需转换数据格式 等优势,缓解了 CPU 和 网络 的开销。

  • 安全性更好:由 MySQL Prepare 原理 我们可以知道,Perpare 编译之后会生成 语法树,在执行的时候才会将参数传进来,这样就避免了平常直接执行 SQL 语句 会发生的 SQL 注入 问题。

好了,先来看下 mysqlStmt 的数据结构:

type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
columns []mysqlField // cached from the first query (既然SQL已经预编译好了,返回的结果集列名已经是确定的,所以在收到 PREPARE_OK 之后解析数据后会缓存下来)
}

我们发现,它比 mysqlRows 多了两个成员变量:

  • idMySQL 预处理语句之后,会给该语句分配一个 id 并返回客户端,用于:

  • 客户端提交该 id 给服务器调用对应的预处理语句。

  • paramCount :参数数量,等于 占位符 的个数,用于:

  • 判断传入的参数个数是否与预编译语句中的占位符个数一致。

  • 判断返回的 PREPARE_OK 响应报文是否带有 参数列名 数据。

下面来看看如何创建并使用一个 Prepare Statement

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // 传入需要预编译的 SQL 语句
// 检查连接是否可用...

err = mc.writeCommandPacketStr(comStmtPrepare, query) // 将 SQL 发往数据库进行预编译
if err != nil {
return nil, err
}

stmt := &mysqlStmt{ // 预编译成功,先创建 stmt 对象
mc: mc,
}

// Read Result
columnCount, err := stmt.readPrepareResultPacket() // 从 stmt 的连接读取返回 响应报文
if err == nil {
if stmt.paramCount > 0 { // 如果预编译的 SQL 的有参数
if err = mc.readUntilEOF(); err != nil { // 读取参数列名数据
return nil, err
}
}

if columnCount > 0 { // 返回执行结果的列表个数
err = mc.readUntilEOF() // 读取执行结果的列名数据
}
}

return stmt, err
}

因为是已经预编译好的语句,所以在执行的时候只需要将参数传进去就可以了。

func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
// 检查连接是否可用...

err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}

// 读取结果集的行、列数据
}

func(stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if len(args) != stmt.paramCount { // 判断传进来的参数和预编译好的SQL参数 个数是否一致
return fmt.Errorf(
"argument count mismatch (got: %d; has: %d)",
len(args),
stmt.paramCount,
)
}

// 读取缓冲器中的数据,如果为空,则返回异常...

// command [1 byte]
data[4] = comStmtExecute

// statement_id [4 bytes] 将预编译语句的 id 转换为 4字节的二进制数据
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)

// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00

// iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00

// 将参数按照不同的类型转换为 binary protobuf 并 append 到 data 中...

return mc.writePacket(data)
}

相信看到这里,已经能对看懂源码的 70% 了,剩余的代码都是和协议相关,就留待有兴趣的读者继续研究,这里就不再展开讲了。

Transaction.go

事务是 MySQL 中很重要的一部分,但是驱动的实现却很简单,因为一切的事务控制都已经交由 MySQL 去执行了,驱动所需要做的,只要发送一个 commit 或者 rollbackcommand packet 即可。

type mysqlTx struct {
mc *mysqlConn
}

func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return ErrInvalidConn
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}

func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return ErrInvalidConn
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return
}


以上是关于go-sql-driver源码分析的主要内容,如果未能解决你的问题,请参考以下文章

Android 逆向整体加固脱壳 ( DEX 优化流程分析 | DexPrepare.cpp 中 dvmOptimizeDexFile() 方法分析 | /bin/dexopt 源码分析 )(代码片段

Android 事件分发事件分发源码分析 ( Activity 中各层级的事件传递 | Activity -> PhoneWindow -> DecorView -> ViewGroup )(代码片段

Golang database/sql源码分析

《Docker 源码分析》全球首发啦!

mysql jdbc源码分析片段 和 Tomcat's JDBC Pool

Android 逆向ART 脱壳 ( DexClassLoader 脱壳 | DexClassLoader 构造函数 | 参考 Dalvik 的 DexClassLoader 类加载流程 )(代码片段