AQS源码探究_07 CountDownLatch源码分析

Posted 兴趣使然の草帽路飞

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了AQS源码探究_07 CountDownLatch源码分析相关的知识,希望对你有一定的参考价值。

  • 在学习CountDownLatch之前,最好仔细理解下前面AQS相关的几篇文章,配合着自己搭建的源码环境进行逐行跟踪,防止迷路~

1、CountDownLatch简介

CountDownLatch,是一个简单的同步器,它的含义是允许一个或多个线程等待其它线程的操作执行完毕后再执行后续的操作

CountDownLatch的通常用法和Thread.join()有点类似,等待其它线程都完成后再执行主任务。


2、入门案例分析

案例1:

  • 对于像我一样的学生来说,CountDwonLatch的实际开发应用很少,甚至有同学没有接触过它。但是在并发条件下,这个类的使用还是很常见的,所以先引入2个案例去了解下它的用途:
  • 借助CountDownLatch,控制主线程等待子线程完成再执行
/**
 * date: 2021/5/7 10:01
 * @author csp
 */
public class CountDownLatchTest01 {
    private static final int TASK_COUNT = 8;
    private static final int THREAD_CORE_SIZE = 10;

    public static void main(String[] args) throws InterruptedException {
        // 实例化CountDownLatch,指定初始计数值为TASK_COUNT(8)
        CountDownLatch latch = new CountDownLatch(TASK_COUNT);
        // 通过Executors创建一个初始容量为THREAD_CORE_SIZE(10)的线程池
        // (注意:在阿里巴巴开发手册中,建议不要使用Executors直接去创建线程池,
        // 要使用其内部调用的ThreadPoolExecutor去手动设置线程池的相关参数,并创建线程池)
        Executor executor = Executors.newFixedThreadPool(THREAD_CORE_SIZE);

        // 依次向线程池中投入8个执行的线程
        for(int i = 0; i < 8; i++) {
            // i -> taskId 任务id
            // latch -> 同步计数器的值
            executor.execute(new WorkerRunnable(i, latch));
        }

        System.out.println("主线程等待所有子任务完成....");
        long mainWaitStartTimeMillis = System.currentTimeMillis();
        latch.await();
        long mainWaitEndTimeMillis = System.currentTimeMillis();
        System.out.println("主线程等待时长:" + (mainWaitEndTimeMillis - mainWaitStartTimeMillis));
    }
    
    /**
     * 工作线程
     */
    static class WorkerRunnable implements Runnable {
        /**
         * 任务id
         */
        private int taskId;

        /**
         * CountDownLatch同步计数器
         */
        private CountDownLatch latch;

        @Override
        public void run() {
            doWorker();
        }

        /**
         * 工作方法
         */
        public void doWorker() {
            System.out.println("任务ID:" + taskId + ",正在执行任务中....");
            try {
                // 休眠5s,模拟正在处理任务
                TimeUnit.SECONDS.sleep(5);
            } catch (InterruptedException e) {
            } finally {
                // latch = latch-1 :
                // 计数器的值latch开始是TASK_COUNT,每执行完一个doWorker方法就-1
                // 直到latch值减小为0,才能继续执行latch.await();之后的方法
                latch.countDown();
            }
            System.out.println("任务ID:" + taskId + ",任务执行结束!");
        }

        public WorkerRunnable(int taskId, CountDownLatch latch) {
            this.taskId = taskId;
            this.latch = latch;
        }
    }
}

运行结果如下:

主线程等待所有子任务完成....
任务ID:0,正在执行任务中....
任务ID:1,正在执行任务中....
任务ID:2,正在执行任务中....
任务ID:4,正在执行任务中....
任务ID:5,正在执行任务中....
任务ID:3,正在执行任务中....
任务ID:6,正在执行任务中....
任务ID:7,正在执行任务中....
任务ID:0,任务执行结束!
任务ID:5,任务执行结束!
任务ID:3,任务执行结束!
任务ID:4,任务执行结束!
任务ID:2,任务执行结束!
任务ID:1,任务执行结束!
任务ID:7,任务执行结束!
任务ID:6,任务执行结束!
主线程等待时长:5000

案例2:

  • 执行任务的线程,也可能是多对多的关系:本案例就来了解一下,借助CountDownLatch,使主线程控制子线程同时开启,主线程再去阻塞等待子线程结束!
/**
 * date: 2021/5/7 10:01
 * @author csp
 */
public class CountDownLatchTest02 {

    // 主线程
    public static void main(String[] args) throws InterruptedException {
        // 开始信号:CountDownLatch初始值为1
        CountDownLatch startSignal = new CountDownLatch(1);
        // 结束信号:CountDownLatch初始值为10
        CountDownLatch doneSignal = new CountDownLatch(10);

        // 开启10个线程,
        for(int i = 0; i < 10; i++) {
            new Thread(new Worker(i, startSignal, doneSignal)).start();
        }

        // 这里让主线程休眠500毫秒,确保所有子线程已经启动,并且阻塞在startSignal栅栏处
        TimeUnit.MILLISECONDS.sleep(500);

        // 因为startSignal 栅栏值为1,所以主线程只要调用一次countDown()方法
        // 那么所有调用startSignal.await()阻塞的子线程,就都可以通过栅栏了
        System.out.println("子任务栅栏已开启...");
        startSignal.countDown();


        System.out.println("等待子任务结束...");
        long startTime = System.currentTimeMillis();
        // 等待所有子任务结束,主线程再继续往下执行
        doneSignal.await();
        long endTime = System.currentTimeMillis();
        System.out.println("所有子任务已经运行结束,耗时:" + (endTime - startTime));
    }

    /**
     * 工作线程:子线程
     */
    static class Worker implements Runnable {
        /**
         * 开启信号
         */
        private final CountDownLatch startSignal;
        /**
         * 结束信号
         */
        private final CountDownLatch doneSignal;
        /**
         * 任务id
         */
        private int id;

        @Override
        public void run() {
            try {
                // 为了让所有线程同时开始任务,我们让所有线程先阻塞在这里(相当于一个栅栏)
                // 等到startSignal值被countDown为0时才往下继续执行:等大家都准备好了,再打开这个门栓
                startSignal.await();
                System.out.println("子任务-" + id + ",开启时间:" + System.currentTimeMillis());
                // sleep 5秒,模拟线程处理任务
                doWork();
            } catch (InterruptedException e) {
            }finally {
                doneSignal.countDown();
            }
        }

        private void doWork() throws InterruptedException {
            TimeUnit.SECONDS.sleep(5);
        }

        public Worker(int id, CountDownLatch startSignal, CountDownLatch doneSignal) {
            this.id = id;
            this.startSignal = startSignal;
            this.doneSignal = doneSignal;
        }
    }
}

执行结果:

子任务栅栏已开启...
等待子任务结束...
子任务-9,开启时间:1620432037554
子任务-8,开启时间:1620432037554
子任务-3,开启时间:1620432037554
子任务-4,开启时间:1620432037554
子任务-1,开启时间:1620432037554
子任务-0,开启时间:1620432037554
子任务-5,开启时间:1620432037554
子任务-7,开启时间:1620432037554
子任务-2,开启时间:1620432037554
子任务-6,开启时间:1620432037554
所有子任务已经运行结束,耗时:5002

上面代码中startSignal.await();就相当于一个栅栏,把所有子线程都抵挡在他们的run方法,等待主线程执行startSignal.countDown();,即关闭栅栏之后,所有子线程在同时继续执行他们自己的run方法,如下图:

请添加图片描述


案例3:

/**
 * date: 2021/5/8
 *
 * @author csp
 */
public class CountDownLatchTest03 {
    public static void main(String[] args) {
        // 声明CountDownLatch计数器,初始值为2
        CountDownLatch latch = new CountDownLatch(2);

        // 任务线程1:
        Thread t1 = new Thread(() -> {
            try {
                Thread.sleep(5000);
            } catch (InterruptedException ignore) {
            }
            // 休息 5 秒后(模拟线程工作了 5 秒),调用 countDown()
            latch.countDown();
        }, "t1");

        // 任务线程2:
        Thread t2 = new Thread(() -> {
            try {
                Thread.sleep(10000);
            } catch (InterruptedException ignore) {
            }
            // 休息 10 秒后(模拟线程工作了 10 秒),调用 countDown()
            latch.countDown();
        }, "t2");

        // 线程1、2开始执行
        t1.start();
        t2.start();

        // 任务线程3:
        Thread t3 = new Thread(() -> {
            try {
                // 阻塞,等待 state 减为 0
                latch.await();
                System.out.println("线程 t3 从 await 中返回了");
            } catch (InterruptedException e) {
                System.out.println("线程 t3 await 被中断");
                Thread.currentThread().interrupt();
            }
        }, "t3");

        // 任务线程4:
        Thread t4 = new Thread(() -> {
            try {
                // 阻塞,等待 state 减为 0
                latch.await();
                System.out.println("线程 t4 从 await 中返回了");
            } catch (InterruptedException e) {
                System.out.println("线程 t4 await 被中断");
                Thread.currentThread().interrupt();
            }
        }, "t4");

        // 线程3、4开始执行
        t3.start();
        t4.start();
    }
}

执行结果:

线程 t4 从 await 中返回了
线程 t3 从 await 中返回了

3、源码分析

Sync内部类

  • CountDownLatch的Sync内部类继承AQS
private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;
    
    // 传入初始count次数
    Sync(int count) {
        setState(count);
    }
    
    // 获取还剩的count次数
    int getCount() {
        return getState();
    }
    
    // 尝试获取共享锁
    protected int tryAcquireShared(int acquires) {
        // 注意,这里state等于0的时候返回的是1,也就是说count减为0的时候获取锁总是成功
        // state不等于0的时候返回的是-1,也就是count不为0的时候总是要排队
        return (getState() == 0) ? 1 : -1;
    }
    
	// 尝试释放锁:
    // 更新 AQS.state 值,每调用一次,state值减一,当state -1 正好为0时,返回true
    protected boolean tryReleaseShared(int releases) {
        for (;;) {
            // 获取当前state的值(AQS.state)
            int c = getState();
            // 如果state等于0了,说明已释放锁,无法再释放了,这里返回false
            if (c == 0)
                return false;
            
            //执行到这里,说明 state > 0
            // 如果count>0,则将count的值减1
            int nextc = c-1;
            
            // 原子更新state的值:
            // cas成功,说明当前线程执行 tryReleaseShared 方法 c-1之前,没有其它线程 修改过 state。
            if (compareAndSetState(c, nextc))
                // 减为0的时候返回true,这时会唤醒后面排队的线程
                // 说明当前调用 countDown() 方法的线程就是需要触发 唤醒操作的线程!
                return nextc == 0;
        }
    }
}

Sync重写了tryAcquireShared()tryReleaseShared()方法,并把count存到state变量中去。这里要注意一下,上面两个方法的参数并没有被用到。


构造方法

// 构造方法需要传入一个count,也就是初始次数。
public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    // 初始化Sync内部类: 
    this.sync = new Sync(count);
}

await()方法

await()方法是等待其它线程完成的方法,它会先尝试获取一下共享锁,如果失败则进入AQS的队列中排队等待被唤醒。

根据上面Sync的源码,我们知道,state不等于0的时候tryAcquireShared()返回的是-1,也就是说count未减到0的时候,所有调用await()方法的线程都要排队。

public void await() throws InterruptedException {
	// 调用AQS的acquireSharedInterruptibly()方法: 
    sync.acquireSharedInterruptibly(1);
}

AQS中的acquireInterruptibly方法:

// 位于AQS中:可以响应中断获取共享锁的方法
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    // CASE1: Thread.interrupted()
    // 条件成立:说明当前调用await方法的线程已经是中断状态,直接抛出异常即可~
    if (Thread.interrupted())
        throw new InterruptedException();
    
    // CASE2: tryAcquireShared(arg) < 0 注意:-1表示获取到了共享锁,1表示没有获取共享锁
    // 条件成立:说明当前AQS的state是大于0的,此时将线程入队,然后等待被唤醒
    // 条件不成立:说明AQS的state = 0,此时就不会阻塞线程:
    // 即,对应业务层面来说,执行任务的线程这时已经将latch打破了。然后其他再调用latch.await的线程,就不会在这里阻塞了
    if (tryAcquireShared(arg) < 0)
        // 采用共享中断模式
        doAcquireSharedInterruptibly(arg);
}

// 位于AQS中:采用共享中断模式
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 将调用latch.await()方法的线程包装成node加入到AQS的阻塞队列当中
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            // 获取当前线程节点的前驱节点
            final Node p = node.predecessor();
            // 条件成立,说明当前线程对应的节点为head.next节点
            if (p == head) {
                // head.next节点就有权利获取共享锁了..
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            // shouldParkAfterFailedAcquire 会给当前线程找一个好爸爸,最终给爸爸节点设置状态为 signal(-1),返回true
            // parkAndCheckInterrupt 挂起当前节点对应的线程...
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

图解分析:

请添加图片描述


countDown()方法

countDown()方法,会释放一个共享锁,也就是count的次数会减1。

根据上面Sync的源码,我们知道,tryReleaseShared()每次会把count的次数减1,当其减为0的时候返回true,这时候才会唤醒等待的线程。

注意,doReleaseShared()是唤醒等待的线程,这个方法我们在前面的章节中分析过了。

public void countDown() {
    // 释放共享锁
    sync.releaseShared(1);
}

// java.util.concurrent.locks.AbstractQueuedSynchronizer.releaseShared()
public final boolean releaseShared(int arg) {
    // tryReleaseShared(arg) 尝试释放共享锁,如果成功了,就唤醒排队的线程:
    // 条件成立:说明当前调用latch.countDown() 方法线程正好是 state - 1 == 0 的这个线程,需要做触发唤醒await状态的线程。
    if (tryReleaseShared(arg)) {// Sync的内部成员方法
        // 唤醒等待的线程:
        // 调用countDown()方法的线程 只有一个线程会进入到这个 if块 里面,去调用 doReleaseShared() 唤醒 阻塞状态的线程的逻辑。
        doReleaseShared();
        return true;
    }
    return false;
}

/**
  * 都有哪几种路径会调用到doReleaseShared方法呢?
  * 1.latch.countDown() -> AQS.state == 0 -> doReleaseShared() 唤醒当前阻塞队列内的 head.next 对应的线程。
  * 2.被唤醒的线程 -> doAcquireSharedInterruptibly parkAndCheckInterrupt() 唤醒 -> setHeadAndPropagate() -> doReleaseShared()
  */
// AQS.doReleaseShared
private void doReleaseShared() {
    for (;;) {
        // 获取当前AQS 内的 头结点
        Node h = head;
        // 条件一:h != null 成立,说明阻塞队列不为空..
        // 不成立:h == null 什么时候会是这样呢?
        // latch创建出来后,没有任何线程调用过 await() 方法之前,有线程调用latch.countDown()操作 且触发了 唤醒阻塞节点的逻辑..

        // 条件二:h != tail 成立,说明当前阻塞队列内,除了head节点以外  还有其他节点。
        // h == tail  -> head 和 tail 指向的是同一个node对象。 什么时候会有这种情况呢?
        // 1. 正常唤醒情况下,依次获取到 共享锁,当前线程执行到这里时 (这个线程就是 tail 节点。)
        // 2. 第一个调用await()方法的线程 与 调用countDown()且触发唤醒阻塞节点的线程 出现并发了..
        //   因为await()线程是第一个调用 latch.await()的线程,此时队列内什么也没有,它需要补充创建一个Head节点,然后再次自旋时入队
        //   在await()线程入队完成之前,假设当前队列内 只有 刚刚补充创建的空元素 head 。
        //   同期,外部有一个调用countDown()的线程,将state 值从1,修改为0了,那么这个线程需要做 唤醒 阻塞队列内元素的逻辑..
        //   注意:调用await()的线程 因为完全入队完成之后,再次回到上层方法 doAcquireSharedInterruptibly 会进入到自旋中,
        //   获取当前元素的前驱,判断自己是head.next, 所以接下来该线程又会将自己设置为 head,然后该线程就从await()方法返回了...
        if (h != null && h != tail) {
            // 执行到if里面,说明当前head 一定有 后继节点!

            int ws = h.waitStatus;
            // 当前head状态 为 signal 说明 后继节点并没有被唤醒过

以上是关于AQS源码探究_07 CountDownLatch源码分析的主要内容,如果未能解决你的问题,请参考以下文章

AQS源码探究_08 CyclicBarrier源码分析

AQS源码探究_08 CyclicBarrier源码分析

AQS源码探究_02 AQS简介及属性分析

AQS源码探究_02 AQS简介及属性分析

AQS源码探究_05 Conditon条件队列(手写一个入门的BrokingQueue)

AQS源码探究_05 Conditon条件队列(手写一个入门的BrokingQueue)