CountDownLatch

2021-01-14 15:40:11 浏览数 (1)

CountDownLatch简介

CountDownLatch是一个同步工具类,它使得一个或多个线程一直等待,直至其他线程的操作执行完成后再接着执行。在Java并发中,CountDownLatch是一个常见的面试题。

可以类比考场的考生交卷,考生交一份试卷,计数器就减1,直到考生都交了试卷(计数器为0),监考老师(一个或多个)才能离开考场。至于考生是否做完试卷,监考老师并不关注。只要都交了试卷,他就可以做接下来的工作了。

有任务A和任务B,任务B必须在任务A完成之后再做。而任务A还能被分为n部分,并且这n部分之间的任务互不影响。为了加快任务完成进度,把这n部分任务分给不同的线程,当A任务完成了,然后通知做B任务的线程接着完成任务,至于完成B任务的线程,可以是一个,也可以是多个。

CountDownLatch是什么?

CountDownLatch是在JDK 1.5中引入的,存在于java.util.concurrent包下。CountDownLatch这个类能够使一个线程等待其他线程完成各自的工作后再执行。例如,应用程序的主线程希望在负责启动框架服务的线程已经启动所有的框架服务之后再执行。

CountDownLatch是通过一个计数器来实现的,计数器的初始值为线程的数量。每当一个线程完成了自己的任务后,计数器的值就会减1。当计数器值到达0时,它表示所有的线程已经完成了任务,然后在闭锁上等待的线程就可以恢复执行任务。

CountDownLatch的原理的伪代码如下:

1.Main thread start

2.Create CountDownLatch for N threads

3.Create and start N threads

4.Main thread wait on latch

5.N threads completes there tasks are returns

6.Main thread resume execution

CountDownLatch如何工作

CountDownLatch的构造函数如下:

代码语言:javascript复制
/**
 * Constructs a {@code CountDownLatch} initialized with the given count.
 *
 * @param count the number of times {@link #countDown} must be invoked
 *        before threads can pass through {@link #await}
 * @throws IllegalArgumentException if {@code count} is negative
 */
public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

在CountDownLatch的构造器中的count代表的是CountDownLatch需要等待的线程数量。这个值只能被设置一次,CountDownLatch没有提供任何机制重置这个值。

CountDownLatch构造器实例化一个内部类Sync类的对象sync,Sync类的源码如下:

代码语言:javascript复制
/**
 * Synchronization control For CountDownLatch.
 * Uses AQS state to represent count.
 */
private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

从Sync类的源码可以看出,Sync实现了AbstractQueuedSynchronizer(简称AQS)。即CountDownLatch通过AQS实现功能。

在CountDownLatch的构造函数中使用count实例化一个Sync对象,即设置AQS的state=count.接下来看下CountDownLatch的等待方法。

CountDownLatch类的不带超时间等待方法如下:

代码语言:javascript复制
/**
 * Causes the current thread to wait until the latch has counted down to
 * zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
 *
 * <p>If the current count is zero then this method returns immediately.
 *
 * <p>If the current count is greater than zero then the current
 * thread becomes disabled for thread scheduling purposes and lies
 * dormant until one of two things happen:
 * <ul>
 * <li>The count reaches zero due to invocations of the
 * {@link #countDown} method; or
 * <li>Some other thread {@linkplain Thread#interrupt interrupts}
 * the current thread.
 * </ul>
 *
 * <p>If the current thread:
 * <ul>
 * <li>has its interrupted status set on entry to this method; or
 * <li>is {@linkplain Thread#interrupt interrupted} while waiting,
 * </ul>
 * then {@link InterruptedException} is thrown and the current thread's
 * interrupted status is cleared.
 *
 * @throws InterruptedException if the current thread is interrupted
 *         while waiting
 */
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

CountDownLatch类带超时时间的等待方法如下:

代码语言:javascript复制
/**
 * Causes the current thread to wait until the latch has counted down to
 * zero, unless the thread is {@linkplain Thread#interrupt interrupted},
 * or the specified waiting time elapses.
 *
 * <p>If the current count is zero then this method returns immediately
 * with the value {@code true}.
 *
 * <p>If the current count is greater than zero then the current
 * thread becomes disabled for thread scheduling purposes and lies
 * dormant until one of three things happen:
 * <ul>
 * <li>The count reaches zero due to invocations of the
 * {@link #countDown} method; or
 * <li>Some other thread {@linkplain Thread#interrupt interrupts}
 * the current thread; or
 * <li>The specified waiting time elapses.
 * </ul>
 *
 * <p>If the count reaches zero then the method returns with the
 * value {@code true}.
 *
 * <p>If the current thread:
 * <ul>
 * <li>has its interrupted status set on entry to this method; or
 * <li>is {@linkplain Thread#interrupt interrupted} while waiting,
 * </ul>
 * then {@link InterruptedException} is thrown and the current thread's
 * interrupted status is cleared.
 *
 * <p>If the specified waiting time elapses then the value {@code false}
 * is returned.  If the time is less than or equal to zero, the method
 * will not wait at all.
 *
 * @param timeout the maximum time to wait
 * @param unit the time unit of the {@code timeout} argument
 * @return {@code true} if the count reached zero and {@code false}
 *         if the waiting time elapsed before the count reached zero
 * @throws InterruptedException if the current thread is interrupted
 *         while waiting
 */
public boolean await(long timeout, TimeUnit unit)
    throws InterruptedException {
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

在CountDownLatch类的等待方法中,以await()方法为例,其会调用其父类AQS中的acquireSharedInterruptibly()方法。

代码语言:javascript复制
/**
 * Acquires in shared mode, aborting if interrupted.  Implemented
 * by first checking interrupt status, then invoking at least once
 * {@link #tryAcquireShared}, returning on success.  Otherwise the
 * thread is queued, possibly repeatedly blocking and unblocking,
 * invoking {@link #tryAcquireShared} until success or the thread
 * is interrupted.
 * @param arg the acquire argument.
 * This value is conveyed to {@link #tryAcquireShared} but is
 * otherwise uninterpreted and can represent anything
 * you like.
 * @throws InterruptedException if the current thread is interrupted
 */
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

其中会调用Sync类的tryAcquireShared()方法。

代码语言:javascript复制
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

在tryAcquireShared方法中,会判断父类AQS中的state变量的状态值是否为0。

当AQS中的state变量的状态值为0的时候,tryAcquireShared()方法返回1,因此acquireSharedInterruptibly()方法就直接退出了。如果AQS中的state变量的状态值不为0,那么将会执行doAcquireSharedInterruptibly()方法。

代码语言:javascript复制
private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);// 往同步队列中添加节点
    boolean failed = true;
    try {
        for (;;) {// 一个死循环 跳出循环只有下面两个途径
            final Node p = node.predecessor();// 当前线程的前一个节点
            if (p == head) {// 如果是首节点
                int r = tryAcquireShared(arg);// 这个是不是似曾相识 见上面
                if (r >= 0) {
                    setHeadAndPropagate(node, r);// 处理后续节点
                    p.next = null; // help GC 这个可以借鉴
                    failed = false;
                    return;// 计数值为0 并且为头节点 跳出循环
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();// 跳出循环
        }
    } finally {
        if (failed)
            cancelAcquire(node);// 如果是打断退出的 则移除同步队列节点
    }
}

addWaiter()方法的源码如下:

代码语言:javascript复制
private Node addWaiter(Node mode) {
    Node node = new Node(Thread.currentThread(), mode);// 包装节点
    Node pred = tail;// 同步队列尾节点
    if (pred != null) {// 同步队列有尾节点 将我们的节点通过cas方式添加到队列后面
        node.prev = pred;
        if (compareAndSetTail(pred, node)) {// 以cas原子方式添加尾节点
            pred.next = node;
            return node;// 退出该方法
        }
    }
    enq(node);// 两种情况执行这个代码 1.队列尾节点为null 2.队列尾节点不为null,但是我们原子添加尾节点失败
    return node;
}

private Node enq(final Node node) {
    for (;;) {// 又是一个死循环
        Node t = tail;
        if (t == null) { // Must initialize
            if (compareAndSetHead(new Node()))// cas形式添加头节点  注意 是头节点
                tail = head;
        } else {
            node.prev = t;
            if (compareAndSetTail(t, node)) {// cas形式添加尾节点
                t.next = node;
                return t;// 结束这个方法的唯一出口 添加尾节点成功
            }
        }
    }
}

shouldParkAfterFailedAcquire()和parkAndCheckInterrupt()方法的源码如下:

代码语言:javascript复制
/**
 * Checks and updates status for a node that failed to acquire.
 * Returns true if thread should block. This is the main signal
 * control in all acquire loops.  Requires that pred == node.prev.
 *
 * @param pred node's predecessor holding status
 * @param node the node
 * @return {@code true} if thread should block
 */
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
    int ws = pred.waitStatus;
    if (ws == Node.SIGNAL)
        /*
         * This node has already set status asking a release
         * to signal it, so it can safely park.
         */
        return true;
    if (ws > 0) {
        /*
         * Predecessor was cancelled. Skip over predecessors and
         * indicate retry.
         */
        do {
            node.prev = pred = pred.prev;
        } while (pred.waitStatus > 0);
        pred.next = node;
    } else {
        /*
         * waitStatus must be 0 or PROPAGATE.  Indicate that we
         * need a signal, but don't park yet.  Caller will need to
         * retry to make sure it cannot acquire before parking.
         */
        compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
    }
    return false;
}

/**
 * Convenience method to park and then check if interrupted
 *
 * @return {@code true} if interrupted
 */
private final boolean parkAndCheckInterrupt() {
    LockSupport.park(this);
    return Thread.interrupted();
}

最后通过LockSupport.park()阻断线程执行。

CountDownLatch.countDown()方法如下:

代码语言:javascript复制
/**
 * Decrements the count of the latch, releasing all waiting threads if
 * the count reaches zero.
 *
 * <p>If the current count is greater than zero then it is decremented.
 * If the new count is zero then all waiting threads are re-enabled for
 * thread scheduling purposes.
 *
 * <p>If the current count equals zero then nothing happens.
 */
public void countDown() {
    sync.releaseShared(1);
}

countDown()方法会调用AQS的releaseShared()方法。

代码语言:javascript复制
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

tryReleaseShared()方法如下:

代码语言:javascript复制
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

doReleaseShared()方法如下:

代码语言:javascript复制
/**
 * Release action for shared mode -- signals successor and ensures
 * propagation. (Note: For exclusive mode, release just amounts
 * to calling unparkSuccessor of head if it needs signal.)
 */
private void doReleaseShared() {
    /*
     * Ensure that a release propagates, even if there are other
     * in-progress acquires/releases.  This proceeds in the usual
     * way of trying to unparkSuccessor of head if it needs
     * signal. But if it does not, status is set to PROPAGATE to
     * ensure that upon release, propagation continues.
     * Additionally, we must loop in case a new node is added
     * while we are doing this. Also, unlike other uses of
     * unparkSuccessor, we need to know if CAS to reset status
     * fails, if so rechecking.
     */
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

与CountDownLatch的第一次交互是主线程等待其他线程。主线程必须在启动其他线程后立即调用CountDownLatch.await()方法。这样主线程的操作就会在这个方法上阻塞,直到其他线程完成各自的任务。

其他N 个线程必须引用CountDownLatch对象,因为他们需要通知CountDownLatch对象,他们已经完成了各自的任务。这种通知机制是通过 CountDownLatch.countDown()方法来完成的;每调用一次这个方法,在构造函数中初始化的count值就减1。所以当N个线程都调用了这个方法,count的值等于0,然后主线程就能通过await()方法,恢复执行自己的任务。

使用场景

1.实现最大的并行性

有时我们想同时启动多个线程,实现最大程度的并行性。例如,我们想测试一个单例类。如果我们创建一个初始计数为1的CountDownLatch,并让所有线程都在这个锁上等待,那么我们可以很轻松地完成测试。我们只需调用一次countDown()方法就可以让所有的等待线程同时恢复执行。

2.开始执行前等待n个线程完成各自任务:例如应用程序启动类要确保在处理用户请求前,所有N个外部系统已经启动和运行了。

3.死锁检测:一个非常方便的使用场景是,你可以使用n个线程访问共享资源,在每次测试阶段的线程数目是不同的,并尝试产生死锁。

0 人点赞