深入源码分析golang之WaitGroup
Posted 互联网打字员
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深入源码分析golang之WaitGroup相关的知识,希望对你有一定的参考价值。
什么是sync.WaitGroup
官方文档对其的描述是:WaitGroup等待一组goroutine的任务完成。主goroutine调用添加以设置要等待的goroutine的数量。然后,每个goroutine都会运行并在完成后调用Done。同时,可以使用Wait来阻塞,直到所有goroutine完成。我们来看官网给的一个例子:
1package main
2
3import (
4 "sync"
5)
6
7type httpPkg struct{}
8
9func (httpPkg) Get(url string) {}
10
11var http httpPkg
12
13func main() {
14 var wg sync.WaitGroup
15 var urls = []string{
16 "http://www.golang.org/",
17 "http://www.google.com/",
18 "http://www.somestupidname.com/",
19 }
20 for _, url := range urls {
21 // 增加waitGroup计数
22 wg.Add(1)
23 // 启动goroutine获取url
24 go func(url string) {
25 //等获取url的goroutine完成,将waitGroup计数减1
26 defer wg.Done()
27 // 获取url
28 http.Get(url)
29 }(url)
30 }
31 // 等待所有goroutine完成
32 wg.Wait()
33}
源码剖析
WaitGroup的实现:WaitGroup的数据结构主要包括一个noCopy的辅助字段,以及一个具有复合含义的state1字段。接下来分别来了解下这两个字段的内部逻辑。
noCopy机制:Go中没有原生的禁止拷贝的方式,所以如果有的结构体,你希望使用者无法拷贝,只能指针传递保证全局唯一的话,可以这么干,定义一个结构体叫noCopy,要实现sync.Locker 这个接口。
1type noCopy struct{}
2
3// nocopy 只有在使用 go vet 检查时才能显示错误,编译正常
4func (*noCopy) Lock() {}
5func (*noCopy) UnLock() {}
state1处理:总共分配了12个字节,在这里被设计成三种状态。其中对齐的8个字节作为状态位(state),高32位为记录计数器的数量,低32位为等待goroutine的数量值。其余的4个字节作为信号量存储(sema)。由于操作系统分为32位和64位,64位的原子操作需要64位对齐,但是32位编译器保证不了,于是这里就采用了动态识别当前我们操作的64位数到底是不是在8字节对齐的位置上面。具体见源码state方法:
1// 得到state的地址和信号量的地址
2func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
3 if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
4 // 如果地址是64bit对齐的,数组前两个元素做state,后一个元素做信号量
5 return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
6 } else {
7 // 如果地址是32bit对齐的,数组后两个元素用来做state,它可以用来做64bit的原子操作,第一个元素32bit用来做信号量
8 return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
9 }
10}
Add方法实现:主要操作的state1字段中计数值部分,计数器部分的逻辑主要是通过state(),在上面有提及。每次调用Add方法就会增加相应数量的计数器。如果计数器为零,则释放等待时阻塞的所有goroutine。如果计数器变为负数,请添加恐慌。如果计数器值大于0,说明此时还有任务没有完成,那么调用者就变成等待者,需要加入wait队列,并且阻塞自己。参数可正可负数。
Add方法源码分析如下:
1func (wg *WaitGroup) Add(delta int) {
Add方法调用流程图如下:
2 //获取state1中的状态位和信号量位
3 statep, semap := wg.state()
4 //用来goroutine的竞争检测,可忽略。
5 if race.Enabled {
6 _ = *statep
7 if delta < 0 {
8 race.ReleaseMerge(unsafe.Pointer(wg))
9 }
10 race.Disable()
11 defer race.Enable()
12 }
13 // uint64(delta)<<32 将delta左移32
14 // 因为高32位表示计数器,所以delta左移32位,
15 // 增加到计数位。
16 state := atomic.AddUint64(statep, uint64(delta)<<32)
17 // 当前计数器的值
18 v := int32(state >> 32)
19 // 阻塞的wait goroutine数量
20 w := uint32(state)
21 if race.Enabled && delta > 0 && v == int32(delta) {
22 race.Read(unsafe.Pointer(semap))
23 }
24 // 计数器的值<0,panic
25 if v < 0 {
26 panic("sync: negative WaitGroup counter")
27 }
28 // 当wait goroutine数量不为0时,累加后的counter值和delta相等,
29 // 说明Add()和Wait()同时调用了,所以发生panic,
30 // 因为正确的做法是先Add()后Wait(),
31 // 也就是已经调用了wait()就不允许再添加任务了
32 if w != 0 && delta > 0 && v == int32(delta) {
33 panic("sync: WaitGroup misuse: Add called concurrently with Wait")
34 }
35 // add调用结束
36 if v > 0 || w == 0 {
37 return
38 }
39 // 能走到这里说明当前Goroutine Counter计数器为0,
40 // Waiter Counter计数器大于0,
41 // 到这里数据也就是允许发生变动了,如果发生变动了,则出发panic
42 if *statep != state {
43 panic("sync: WaitGroup misuse: Add called concurrently with Wait")
44 }
45 // 所有的状态位清0
46 *statep = 0
47 for ; w != 0; w-- {
48 // 首先让信号量加一,然后检查是否有正在等待的Goroutine,如果没有,直接返回;
49 // 如果有,调用goready函数唤醒一个Goroutine。
50 runtime_Semrelease(semap, false, 0)
51 }
52
Done方法实现:内部调用了Add(-1)的方法。详情看Add方法
1//Done方法其实就是Add(-1)
2func (wg *WaitGroup) Done() {
3 wg.Add(-1)
4}
Wait方法实现:阻塞主goroutine直到WaitGroup计数器变为0。
Wait方法源码分析如下:
1// 等待并阻塞,直到WaitGroup计数器为0
2func (wg *WaitGroup) Wait() {
3 // 获取waitgroup状态位和信号量
4 statep, semap := wg.state()
5 if race.Enabled {
6 _ = *statep
7 race.Disable()
8 }
9 for {
10 // 使用原子操作读取state,是为了保证Add中的写入操作已经完成
11 state := atomic.LoadUint64(statep)
12 v := int32(state >> 32) //获取计数器(高32位)
13 w := uint32(state) //获取wait goroutine数量(低32位)
14 if v == 0 { // 计数器为0,跳出死循环,不用阻塞
15 if race.Enabled {
16 race.Enable()
17 race.Acquire(unsafe.Pointer(wg))
18 }
19 return
20 }
21 // 使用CAS操作对`waiter Counter`计数器进行+1操作,
22 // 外面有for循环保证这里可以进行重试操作
23 if atomic.CompareAndSwapUint64(statep, state, state+1) {
24 if race.Enabled && w == 0 {
25 race.Write(unsafe.Pointer(semap))
26 }
27 // 在这里获取信号量,使线程进入睡眠状态,
28 // 与Add方法中runtime_Semrelease增加信号量相对应,
29 // 也就是当最后一个任务调用Done方法
30 // 后会调用Add方法对goroutine counter的值减到0,
31 // 就会走到最后的增加信号量
32 runtime_Semacquire(semap)
33 // 在Add方法中增加信号量时已经将statep的值设为0了,
34 // 如果这里不是0,说明在wait之后又调用了Add方法,
35 // 使用时机不对,触发panic
36 if *statep != 0 {
37 panic("sync: WaitGroup is reused before previous Wait has returned")
38 }
39 if race.Enabled {
40 race.Enable()
41 race.Acquire(unsafe.Pointer(wg))
42 }
43 return
44 }
45 }
46}
Wait方法调用流程图如下:
扫码关注
获取更多干货内容
以上是关于深入源码分析golang之WaitGroup的主要内容,如果未能解决你的问题,请参考以下文章
Golang的sync.WaitGroup 实现逻辑和源码解析