CountDownLatch和CyclicBarrier源码详解

2024-01-19 10:53:38 浏览数 (2)

前言

我现在有个场景:现在我有50个任务,这50个任务在完成之后,才能执行下一个函数,要是你,你怎么设计?

可以用JDK给我们提供的线程工具类,CountDownLatch和CyclicBarrier都可以完成这个需求。


一、CountDownLatch和CyclicBarrier

CountDownLatch允许一个或多个线程直等待,直到这些线程完成它们的操作。

而CyclicBarrier不一样,它往往是当线程到达某状态后,暂停下来等待其他线程等到所有线程均到达以后,才继续执行。

可以发现这两者的等待主体是不一样的。

CountDownLatch调用await()通常是主线程/调用线程,而CyclicBarrier调用await()是在任务线程调用的。

所以,CyclicBarrier中的阻塞的是任务的线程,而主线程是不受影响的。


二、CountDownLatch源码分析

CountDownLatch也是基于AQS实现的,它的实现机制很简单。

Java 中 AQS(AbstractQueuedSynchronizer)的 state 一般有以下值:

  1. 0:表示当前没有任何线程占有锁或者资源。
  2. 1:表示当前有一个线程占有锁或者资源。
  3. 大于等于 2:表示当前有多个线程占有锁或者资源,即存在竞争状态。
  4. 负数:表示当前有等待获取锁或者资源的线程,这些线程处于阻塞状态。
  5. 特殊值:AQS 实现类可以根据自己的需求自定义特殊的 state 值,例如 ReentrantLock 就使用了一个表示重入次数的 state 值。

当我们在构建CountDownLatch对象时,传入的值其实就会赋值给AQS的关键变量state。

如下:

代码语言:javascript复制
    Sync(int count) {
        setState(count);
    }
     /**
     * Constructs a {@code CountDownLatch} initialized with the given count.
     *
     * @param count the number of times {@link #countDown} must be invoked
     *        before threads can pass through {@link #await}
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

执行countDown方法时,其实就是利用CAS将state -1。

如下:

代码语言:javascript复制
    /**
     * Decrements the count of the latch, releasing all waiting threads if
     * the count reaches zero.
     *
     * <p>If the current count is greater than zero then it is decremented.
     * If the new count is zero then all waiting threads are re-enabled for
     * thread scheduling purposes.
     *
     * <p>If the current count equals zero then nothing happens.
     */
    public void countDown() {
        sync.releaseShared(1);
    }

执行await方法时,其实就是判断state是否为0,不为0则加入到队列中,将该线程阻塞掉(除了头结点)。

如下:

代码语言:javascript复制
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
  
    public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        return tryAcquireShared(arg) >= 0 ||
            doAcquireSharedNanos(arg, nanosTimeout);
    }

因为头节点会一直自旋等待state为0,当state为0时,头节点把剩余的在队列中阻塞的节点也一并唤醒。

如下:

代码语言:javascript复制
    private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (nanosTimeout <= 0L)
            return false;
        final long deadline = System.nanoTime()   nanosTimeout;
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return true;
                    }
                }
                nanosTimeout = deadline - System.nanoTime();
                if (nanosTimeout <= 0L)
                    return false;
                if (shouldParkAfterFailedAcquire(p, node) &&
                    nanosTimeout > spinForTimeoutThreshold)
                    LockSupport.parkNanos(this, nanosTimeout);
                if (Thread.interrupted())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

三、CyclicBarrier源码分析

回到CyclicBarrier代码也不难,重点也就是await方法。

从源码不难发现的是,它没有像CountDownLatch和ReentrantLock使用AQS的state变量,而CyclicBarrier是直接借助ReentrantLock加上Condition 等待唤醒的功能进而实现的。

如下:

代码语言:javascript复制
    /** The lock for guarding barrier entry */
    private final ReentrantLock lock = new ReentrantLock();
    /** Condition to wait on until tripped */
    private final Condition trip = lock.newCondition();

在构建CyclicBarrier时,传入的值会赋值给CyclicBarrier内部维护count变量,也会赋值给parties变量(这是可以复用的关键)。

如下:

代码语言:javascript复制
    /**
     * Creates a new {@code CyclicBarrier} that will trip when the
     * given number of parties (threads) are waiting upon it, and which
     * will execute the given barrier action when the barrier is tripped,
     * performed by the last thread entering the barrier.
     *
     * @param parties the number of threads that must invoke {@link #await}
     *        before the barrier is tripped
     * @param barrierAction the command to execute when the barrier is
     *        tripped, or {@code null} if there is no action
     * @throws IllegalArgumentException if {@code parties} is less than 1
     */
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

每次调用await时,会将count -1,操作count值是直接使用ReentrantLock来保证线程安全性。

如下:

代码语言:javascript复制
    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
            ...
            int index = --count;   
            ...          
    }

如果count不为0,则添加则condition队列中。

如下:

代码语言:javascript复制
            for (;;) {
                try {
                    if (!timed)
                        trip.await();
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        Thread.currentThread().interrupt();
                    }
                }
               ...
            }

如果count等于0时,则把节点从condition队列添加至AQS的队列中进行全部唤醒并且将parties的值重新赋值为count的值(实现复用)。

如下:

代码语言:javascript复制
            if (index == 0) {  // tripped
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    nextGeneration();
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier();
                }
            }

    /**
     * Updates state on barrier trip and wakes up everyone.
     * Called only while holding lock.
     */
    private void nextGeneration() {
        // signal completion of last generation
        trip.signalAll();
        // set up next generation
        count = parties;
        generation = new Generation();
    }

四、总结

CountDownlatch基于AQS实现,会将构造CountDownLatch的入参传递至statecountDown()就是在利用CAS将state减1,await)实际就是让头节点一直在等待state为0时,释放所有等待的线程。

CyclicBarrier则利用ReentrantLock和Condition,自身维护了count和parties变量。每次调用await将count-1,并将线程加入到condition队列上。等到count为0时,则将condition队列的节点移交至AQS队列,并全部释放。

0 人点赞