并发工具CountDownLatch源码分析

时间:2023-12-16 11:44:02

  CountDownLatch的作用类似于Thread.join()方法,但比join()更加灵活。它可以等待多个线程(取决于实例化时声明的数量)都达到预期状态或者完成工作以后,通知其他正在等待的线程继续执行。简单的说,Thread.join()是等待具体的一个线程执行完毕,CountDownLatch等待多个线程。

  如果需要统计4个文件中的内容行数,可以用4个线程分别执行,然后用一个线程等待统计结果,最后执行数据汇总。这样场景就适合使用CountDownLatch。

  本篇从CountDownLatch的源码分析它的原理机制。再给出一个简单的使用案例。

  

  首先认识一下CountDownLatch中的内部类:

private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L; Sync(int count) {
setState(count); // 更新AQS中的state
} int getCount() {
return getState();
} protected int tryAcquireShared(int acquires) {
return (getState() == ) ? : -;
} protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == )
return false;
int nextc = c-;
if (compareAndSetState(c, nextc))
return nextc == ;
}
}
}

  其实CountDownLatch的机制和ReentrantLock有点像,都是利用AQS(AbstractQueuedSynchronizer)来实现的。CountDownLatch的内部类Sync继承AQS,重写了tryAcquireShared()方法和tryReleaseShared()方法。这里的重点是CountDownLatch的构造函数需要传入一个int值count,就是等待的线程数。这个count被Sync用来直接更新为AQS中的state。

  

1、await()等待方法

//CountDownLatch
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly();
}
//AQS
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < ) // 1
doAcquireSharedInterruptibly(arg); // 2  
}
//Sync
protected int tryAcquireShared(int acquires) {
return (getState() == ) ? : -;
}
//AQS
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= ) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
  1. 调用AQS中的tryAcquireShared()方法时,Sync重写了tryAcquireShared()方法,获取state,判断state是否为0。
  2. 如果不为0,调用doAcquireSharedInterruptibly()方法,将线程加入队列,挂起线程。

2、countDown()

public void countDown() {
sync.releaseShared();
}
//AQS
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
//Sync
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == )
return false;
int nextc = c-;
if (compareAndSetState(c, nextc))
return nextc == ;
}
}

  重点也是在于Sync重写的tryReleaseShared()方法。利用CAS算法将state减1。如果state减到0,说明所有工作线程都执行完毕,那么就唤醒等待队列中的线程。

使用示例:

public class CountDownLatchTest {
private static CountDownLatch countDownLatch = new CountDownLatch();
private static ThreadPoolExecutor threadPool = new ThreadPoolExecutor(, ,
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(10)); public static void main(String[] args) {
//等待线程
for (int i = ; i < ; i++) {
String threadName = "等待线程 " + i;
threadPool.execute(new Runnable() { @Override
public void run() {
try {
System.out.println(threadName + " 正在等待...");
//等待
countDownLatch.await();
System.out.println(threadName + " 结束等待...");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
});
}
//工作线程
for (int i = ; i < ; i++) {
String threadName = "工作线程 " + i;
threadPool.execute(new Runnable() { @Override
public void run() {
try {
System.out.println(threadName + " 进入...");
//沉睡1秒
TimeUnit.MILLISECONDS.sleep();
System.out.println(threadName + " 完成...");
//通知
countDownLatch.countDown();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
});
} threadPool.shutdown();
}
}

  执行结果为:

等待线程 1 正在等待...
等待线程 0 正在等待...
工作线程 2 进入...
工作线程 3 进入...
工作线程 4 进入...
工作线程 3 完成...
工作线程 2 完成...
工作线程 4 完成...
等待线程 0 结束等待...
等待线程 1 结束等待...

  从结果也能看到,等待线程先执行,调用countDownLatch.await()方法开始等待。每个工作线程工作完成以后,都调用countDownLatch.countDown()方法,告知自己的任务完成。countDownLatch初始参数为3,所以3个工作线程都告知自己结束以后,等待线程才开始工作。