Using CountDownLatch and CyclicBarrier, and How They Are Implemented

Posted: 2021-11-13 10:32:29

I. Using CountDownLatch and CyclicBarrier

CountDownLatch:

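Definition: CountDownLatch is a thread synchronization aid that lets one or more threads block until other threads have finished their work. A useful property is that it does not require the threads calling countDown to wait for the count to reach zero before they continue; it only prevents any thread from proceeding past an await until all threads could pass.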

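Usage: initialize a CountDownLatch with a given count. Each call to countDown() decrements the count by 1, and calls to await() block until the count reaches 0. Once it reaches 0, all waiting threads are released at once, and any later call to await() returns immediately.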
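A minimal sketch of the usage just described, assuming a hypothetical class name LatchBasics and a trivial counter increment standing in for real work: each worker calls countDown() and moves on, while the calling thread blocks in await() until the count hits 0.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

class LatchBasics {
    // Runs `workers` tasks; each does its (trivial) work and then calls countDown().
    // The calling thread blocks in await() until the count has dropped to 0.
    static int runWorkers(int workers) throws InterruptedException {
        CountDownLatch done = new CountDownLatch(workers);
        AtomicInteger finished = new AtomicInteger();
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        try {
            for (int i = 0; i < workers; i++) {
                pool.execute(() -> {
                    finished.incrementAndGet(); // stand-in for real work
                    done.countDown();           // count - 1; this call never blocks
                });
            }
            done.await(); // returns only once every worker has counted down
            return finished.get();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runWorkers(5)); // prints 5
    }
}
```

Note that the workers never wait on each other here; only the thread that calls await() blocks.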

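Scenarios: (1) With a count of 1, a CountDownLatch serves as a simple on/off latch, or gate: every thread that calls await waits at the gate until some thread opens it by calling countDown(). (2) A CountDownLatch initialized with N (N >= 1) can make one thread wait until N threads have completed some operation, or until some operation has been completed N times.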
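Both scenarios can be combined in one small sketch (class name StartGate is hypothetical): a latch of 1 acts as the start gate, and a latch of N lets the main thread wait for all workers. The timed await before opening the gate shows that no worker gets through while it is closed.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class StartGate {
    // Returns how many workers saw the gate open, or -1 if any worker
    // finished before the gate was opened.
    static int runThroughGate(int workers) throws InterruptedException {
        CountDownLatch gate = new CountDownLatch(1);          // scenario (1): on/off gate
        CountDownLatch allDone = new CountDownLatch(workers); // scenario (2): wait for N threads
        AtomicInteger sawOpen = new AtomicInteger();
        boolean[] opened = {false}; // written before countDown(), read after await()

        for (int i = 0; i < workers; i++) {
            new Thread(() -> {
                try {
                    gate.await(); // every worker waits here until the gate opens
                    if (opened[0]) sawOpen.incrementAndGet();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    allDone.countDown();
                }
            }).start();
        }

        // While the gate is closed no worker can finish, so this times out (false)
        boolean finishedEarly = allDone.await(50, TimeUnit.MILLISECONDS);

        opened[0] = true;
        gate.countDown(); // 1 -> 0: releases all workers at once
        allDone.await();  // the main thread waits for the N workers
        return finishedEarly ? -1 : sawOpen.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runThroughGate(4)); // prints 4
    }
}
```

The plain boolean array is safe to share here because CountDownLatch guarantees that actions before countDown() happen-before actions after the corresponding await() returns.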

Note: the count of a CountDownLatch cannot be reset. If you need a count that resets, consider CyclicBarrier.
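A quick sketch of what "cannot be reset" means in practice (NoReset is a hypothetical name): extra countDown() calls on an exhausted latch are silently ignored, and CountDownLatch has no reset() method, so a new round needs a new instance.

```java
import java.util.concurrent.CountDownLatch;

class NoReset {
    // Counts a latch of 2 down three times; the extra call has no effect.
    static long countAfterExtraCountDowns() {
        CountDownLatch latch = new CountDownLatch(2);
        latch.countDown();
        latch.countDown();
        latch.countDown(); // count is already 0: no effect, no exception
        return latch.getCount();
    }

    public static void main(String[] args) {
        System.out.println(countAfterExtraCountDowns()); // prints 0
        // There is no reset(); the next round needs a fresh latch
        CountDownLatch nextRound = new CountDownLatch(2);
        System.out.println(nextRound.getCount()); // prints 2
    }
}
```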

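Practice: the code below uses 10 threads, each summing one group of numbers. The 10 threads should logically start computing at the same moment (in practice they cannot truly start simultaneously: without enough CPU cores there is no real parallel computation), and no thread may return until all 10 threads have finished computing.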

MyCalculator.java:

package simple.demo;

import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;

/**
 * @author jianying.wcj
 * @date 2013-8-2
 */
public class MyCalculator implements Callable<Integer> {
    /**
     * Start switch
     */
    private CountDownLatch startSwitch;
    /**
     * Stop switch
     */
    private CountDownLatch stopSwitch;
    /**
     * Number of the group to compute
     */
    private int groupNum;

    /**
     * Constructor
     */
    public MyCalculator(CountDownLatch startSwitch, CountDownLatch stopSwitch, Integer groupNum) {
        this.startSwitch = startSwitch;
        this.stopSwitch = stopSwitch;
        this.groupNum = groupNum;
    }

    @Override
    public Integer call() throws Exception {
        // Block until the main thread opens the start gate
        startSwitch.await();
        int res = compute();
        System.out.println(Thread.currentThread().getName() + " is ok wait other thread...");
        // Report this group as done, then wait until every group is done
        stopSwitch.countDown();
        stopSwitch.await();
        System.out.println(Thread.currentThread().getName() + " is stop! the group" + groupNum + " temp result is sum=" + res);
        return res;
    }

    /**
     * Sums the integers from (groupNum - 1) * 10 + 1 to groupNum * 10
     */
    public int compute() throws InterruptedException {
        int sum = 0;
        for (int i = (groupNum - 1) * 10 + 1; i <= groupNum * 10; i++) {
            sum += i;
        }
        return sum;
    }
}

MyTest.java:

package simple.demo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MyTest {
    private int groupNum = 10;
    /**
     * Start and stop switches
     */
    private CountDownLatch startSwitch = new CountDownLatch(1);
    private CountDownLatch stopSwitch = new CountDownLatch(groupNum);
    /**
     * Thread pool
     */
    private ExecutorService service = Executors.newFixedThreadPool(groupNum);
    /**
     * Holds the computation results
     */
    private List<Future<Integer>> result = new ArrayList<Future<Integer>>();

    /**
     * Starts groupNum threads to compute the values
     */
    public void init() {
        for (int i = 1; i <= groupNum; i++) {
            result.add(service.submit(new MyCalculator(startSwitch, stopSwitch, i)));
        }
        System.out.println("init is ok!");
    }

    public void printRes() throws InterruptedException, ExecutionException {
        int sum = 0;
        for (Future<Integer> f : result) {
            sum += f.get();
        }
        System.out.println("the result is " + sum);
    }

    public void start() {
        // Count 1 -> 0: releases every worker blocked on startSwitch
        this.startSwitch.countDown();
    }

    public void stop() throws InterruptedException {
        // Wait until all groupNum workers have counted down, then shut the pool down
        this.stopSwitch.await();
        this.service.shutdown();
    }

    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        MyTest myTest = new MyTest();
        myTest.init();
        System.out.println("please enter start command....");
        reader.readLine();
        myTest.start();
        myTest.stop();
        myTest.printRes();
    }
}

Output:

init is ok!
please enter start command....
pool-1-thread-1 is ok wait other thread...
pool-1-thread-2 is ok wait other thread...
pool-1-thread-3 is ok wait other thread...
pool-1-thread-4 is ok wait other thread...
pool-1-thread-6 is ok wait other thread...
pool-1-thread-5 is ok wait other thread...
pool-1-thread-8 is ok wait other thread...
pool-1-thread-7 is ok wait other thread...
pool-1-thread-9 is ok wait other thread...
pool-1-thread-10 is ok wait other thread...
pool-1-thread-10 is stop! the group10 temp result is sum=955
pool-1-thread-1 is stop! the group1 temp result is sum=55
pool-1-thread-2 is stop! the group2 temp result is sum=155
pool-1-thread-3 is stop! the group3 temp result is sum=255
pool-1-thread-4 is stop! the group4 temp result is sum=355
pool-1-thread-6 is stop! the group6 temp result is sum=555
pool-1-thread-5 is stop! the group5 temp result is sum=455
pool-1-thread-8 is stop! the group8 temp result is sum=755
pool-1-thread-7 is stop! the group7 temp result is sum=655
pool-1-thread-9 is stop! the group9 temp result is sum=855
the result is 5050

CyclicBarrier:

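Definition: CyclicBarrier is a synchronization aid that lets a group of threads wait for one another to reach a common barrier point, after which they all continue. It also supports an optional Runnable command that runs once per barrier point, after the last thread of the group arrives but before any thread is released. This barrier action is useful for updating shared state before any of the parties continue.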
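The barrier action can take over the merging step of the practice task below. In this sketch (BarrierActionSum is a hypothetical name) each thread sums one group of 10 numbers, and the action, run once by the last thread to arrive, adds up the partial sums before anyone is released.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

class BarrierActionSum {
    // Each of `groups` threads sums its own group of 10 numbers;
    // the barrier action merges the partial sums.
    static long sumWithBarrierAction(int groups) throws InterruptedException {
        long[] partial = new long[groups];
        long[] total = new long[1];
        // Runs once, in the last thread to arrive, before any thread is released
        CyclicBarrier barrier = new CyclicBarrier(groups, () -> {
            long s = 0;
            for (long p : partial) {
                s += p;
            }
            total[0] = s;
        });

        Thread[] threads = new Thread[groups];
        for (int g = 0; g < groups; g++) {
            final int group = g + 1;
            threads[g] = new Thread(() -> {
                long sum = 0;
                for (int i = (group - 1) * 10 + 1; i <= group * 10; i++) {
                    sum += i;
                }
                partial[group - 1] = sum;
                try {
                    barrier.await();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } catch (BrokenBarrierException e) {
                    // another party failed; the merged total is not valid
                }
            });
            threads[g].start();
        }
        for (Thread t : threads) {
            t.join();
        }
        return total[0];
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(sumWithBarrierAction(10)); // prints 5050
    }
}
```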

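Usage: initialize a CyclicBarrier with a count N (the number of parties). Each call to await blocks the calling thread and records one more arrival; when N threads have arrived, all of them are released. The barrier then resets itself automatically, so later calls to await block again until another N threads arrive, which is why it is called "cyclic". (Internally the count actually runs down from N to 0, as the source code in part II shows.)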
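The automatic reset can be observed by having the same threads pass one barrier several times. In this sketch (CyclicReuse is a hypothetical name) a barrier action counts the trips; with 3 parties and 4 rounds it trips 4 times without any manual reset.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicInteger;

class CyclicReuse {
    // `parties` threads each call await() on the SAME barrier `rounds` times.
    // Returns how many times the barrier tripped (counted by its barrier action).
    static int trips(int parties, int rounds) throws InterruptedException {
        AtomicInteger tripCount = new AtomicInteger();
        CyclicBarrier barrier = new CyclicBarrier(parties, tripCount::incrementAndGet);

        Thread[] threads = new Thread[parties];
        for (int p = 0; p < parties; p++) {
            threads[p] = new Thread(() -> {
                try {
                    for (int r = 0; r < rounds; r++) {
                        barrier.await(); // blocks again each round; no manual reset needed
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } catch (BrokenBarrierException e) {
                    // another party failed; stop early
                }
            });
            threads[p].start();
        }
        for (Thread t : threads) {
            t.join();
        }
        return tripCount.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(trips(3, 4)); // prints 4
    }
}
```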

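Scenario: initialize a CyclicBarrier with N and have each of N threads call await; none of them continues until all N threads have reached await, after which they continue together.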

Practice: the same task as the CountDownLatch example above.

MyCalculator.java:

package simple.demo;

import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;

public class MyCalculator implements Callable<Integer> {
    /**
     * Start switch
     */
    private CyclicBarrier startSwitch;
    /**
     * Stop switch
     */
    private CyclicBarrier stopSwitch;
    /**
     * Number of the group to compute
     */
    private int groupNum;

    /**
     * Constructor
     */
    public MyCalculator(CyclicBarrier startSwitch, CyclicBarrier stopSwitch, Integer groupNum) {
        this.startSwitch = startSwitch;
        this.stopSwitch = stopSwitch;
        this.groupNum = groupNum;
    }

    @Override
    public Integer call() throws Exception {
        // Wait until all workers and the main thread have reached the start barrier
        startSwitch.await();
        int res = compute();
        System.out.println(Thread.currentThread().getName() + " is ok wait other thread...");
        // Wait until every worker has finished computing
        stopSwitch.await();
        System.out.println(Thread.currentThread().getName() + " is stop! the group" + groupNum + " temp result is sum=" + res);
        return res;
    }

    /**
     * Sums the integers from (groupNum - 1) * 10 + 1 to groupNum * 10
     */
    public int compute() throws InterruptedException {
        int sum = 0;
        for (int i = (groupNum - 1) * 10 + 1; i <= groupNum * 10; i++) {
            sum += i;
        }
        return sum;
    }
}

MyTest.java:

package simple.demo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MyTest {
    private int groupNum = 10;
    /**
     * Start and stop switches: the start barrier has groupNum + 1 parties
     * because the main thread also waits on it
     */
    private CyclicBarrier startSwitch = new CyclicBarrier(groupNum + 1);
    private CyclicBarrier stopSwitch = new CyclicBarrier(groupNum);
    /**
     * Thread pool
     */
    private ExecutorService service = Executors.newFixedThreadPool(groupNum);
    /**
     * Holds the computation results
     */
    private List<Future<Integer>> result = new ArrayList<Future<Integer>>();

    /**
     * Starts groupNum threads to compute the values
     */
    public void init() {
        for (int i = 1; i <= groupNum; i++) {
            result.add(service.submit(new MyCalculator(startSwitch, stopSwitch, i)));
        }
        System.out.println("init is ok!");
    }

    public void printRes() throws InterruptedException, ExecutionException {
        int sum = 0;
        for (Future<Integer> f : result) {
            // get() blocks until the corresponding worker has returned
            sum += f.get();
        }
        System.out.println("the result is " + sum);
    }

    public void start() throws InterruptedException, BrokenBarrierException {
        // The main thread is the last party to arrive, which trips the start barrier
        this.startSwitch.await();
    }

    public void stop() throws InterruptedException {
        this.service.shutdown();
    }

    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException, BrokenBarrierException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        MyTest myTest = new MyTest();
        myTest.init();
        System.out.println("please enter start command....");
        reader.readLine();
        myTest.start();
        myTest.stop();
        myTest.printRes();
    }
}

Output:

init is ok!
please enter start command....
pool-1-thread-1 is ok wait other thread...
pool-1-thread-2 is ok wait other thread...
pool-1-thread-3 is ok wait other thread...
pool-1-thread-4 is ok wait other thread...
pool-1-thread-5 is ok wait other thread...
pool-1-thread-6 is ok wait other thread...
pool-1-thread-7 is ok wait other thread...
pool-1-thread-8 is ok wait other thread...
pool-1-thread-9 is ok wait other thread...
pool-1-thread-10 is ok wait other thread...
pool-1-thread-10 is stop! the group10 temp result is sum=955
pool-1-thread-1 is stop! the group1 temp result is sum=55
pool-1-thread-2 is stop! the group2 temp result is sum=155
pool-1-thread-3 is stop! the group3 temp result is sum=255
pool-1-thread-5 is stop! the group5 temp result is sum=455
pool-1-thread-6 is stop! the group6 temp result is sum=555
pool-1-thread-4 is stop! the group4 temp result is sum=355
pool-1-thread-8 is stop! the group8 temp result is sum=755
pool-1-thread-7 is stop! the group7 temp result is sum=655
pool-1-thread-9 is stop! the group9 temp result is sum=855
the result is 5050

II. How CountDownLatch and CyclicBarrier Are Implemented

The class diagram of CountDownLatch is shown below:

[Figure: CountDownLatch class diagram]

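CountDownLatch is implemented on top of AQS (AbstractQueuedSynchronizer): it defines an inner class Sync that extends AQS. The key source code is as follows.
The await method: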

 /**
* Causes the current thread to wait until the latch has counted down to
* zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
*
* <p>If the current count is zero then this method returns immediately.
*
* <p>If the current count is greater than zero then the current
* thread becomes disabled for thread scheduling purposes and lies
* dormant until one of two things happen:
* <ul>
* <li>The count reaches zero due to invocations of the
* {@link #countDown} method; or
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread.
* </ul>
*
* <p>If the current thread:
* <ul>
* <li>has its interrupted status set on entry to this method; or
* <li>is {@linkplain Thread#interrupt interrupted} while waiting,
* </ul>
* then {@link InterruptedException} is thrown and the current thread's
* interrupted status is cleared.
*
* @throws InterruptedException if the current thread is interrupted
* while waiting
*/
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

The countDown method:

/**
* Decrements the count of the latch, releasing all waiting threads if
* the count reaches zero.
*
* <p>If the current count is greater than zero then it is decremented.
* If the new count is zero then all waiting threads are re-enabled for
* thread scheduling purposes.
*
* <p>If the current count equals zero then nothing happens.
*/
public void countDown() {
sync.releaseShared(1);
}

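These are the definitions of CountDownLatch's two key methods, await and countDown; the comments explain what each one does. CountDownLatch simply reuses the state field of AQS: in ReentrantLock that field records the reentrant hold count, but here it holds the remaining count. CountDownLatch's inner class Sync overrides AQS's tryAcquireShared, which is defined as: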

public int tryAcquireShared(int acquires) {
return getState() == 0? 1 : -1;
}

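The initial value of state is the count passed to the CountDownLatch constructor. When await calls AQS's acquireSharedInterruptibly, AQS first calls tryAcquireShared(int acquires): a non-negative result (the count is 0) lets the thread return immediately, while a negative result (the count is still greater than 0) makes AQS enqueue the thread and park it. The AQS method that does the queuing and parking is: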

 /**
* Acquires in shared interruptible mode.
* @param arg the acquire argument
*/
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
break;
}
} catch (RuntimeException ex) {
cancelAcquire(node);
throw ex;
}
// Arrive here only if interrupted
cancelAcquire(node);
throw new InterruptedException();
}
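The two outcomes of tryAcquireShared can be observed from the public API. This sketch (AcquirePath is a hypothetical name) uses the timed await(long, TimeUnit), which goes through AQS's tryAcquireSharedNanos rather than acquireSharedInterruptibly but applies the same tryAcquireShared check first: with a positive count the thread is parked until the timeout and gets false; with a count of 0 it returns true without parking.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

class AcquirePath {
    // count == 1: tryAcquireShared returns -1, so the caller is queued and parked;
    // with a timed await it stays parked until the timeout and gets false back.
    static boolean awaitWhileCountPositive() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        return latch.await(20, TimeUnit.MILLISECONDS);
    }

    // count == 0: tryAcquireShared returns 1 (>= 0), so await never parks.
    static boolean awaitWhenCountZero() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        latch.countDown();
        return latch.await(0, TimeUnit.MILLISECONDS);
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(awaitWhileCountPositive()); // prints false
        System.out.println(awaitWhenCountZero());      // prints true
    }
}
```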

When countDown is called, AQS's releaseShared invokes tryReleaseShared, which Sync overrides as follows:

public boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}

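So each call decrements state by 1 (retrying the CAS if another thread changed it concurrently) until it reaches 0. tryReleaseShared returns true only for the call that moves the count from 1 to 0; AQS then wakes the queued threads, and because tryAcquireShared now returns 1, each of them gets past await. Once state is 0, further countDown calls return false and change nothing. That is all for CountDownLatch.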
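To see that these two overrides are really all a latch needs, here is a stripped-down latch built on AQS in the same way. MiniLatch is a hypothetical class for illustration, not the JDK's CountDownLatch; it only uses public or protected AbstractQueuedSynchronizer methods.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

// Hypothetical MiniLatch mirroring the Sync logic quoted above (sketch only).
class MiniLatch {
    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) {
            setState(count); // AQS state = remaining count
        }

        int count() {
            return getState();
        }

        @Override
        protected int tryAcquireShared(int acquires) {
            // >= 0: the caller may pass; < 0: AQS queues and parks the caller
            return getState() == 0 ? 1 : -1;
        }

        @Override
        protected boolean tryReleaseShared(int releases) {
            for (;;) { // CAS retry loop
                int c = getState();
                if (c == 0) {
                    return false; // already open: nothing to do
                }
                int next = c - 1;
                if (compareAndSetState(c, next)) {
                    // true only on the 1 -> 0 transition; AQS then wakes the waiters
                    return next == 0;
                }
            }
        }
    }

    private final Sync sync;

    MiniLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    void countDown() {
        sync.releaseShared(1);
    }

    long getCount() {
        return sync.count();
    }

    // Usage: N threads count down, the caller waits for all of them.
    static int demo(int workers) throws InterruptedException {
        MiniLatch latch = new MiniLatch(workers);
        AtomicInteger done = new AtomicInteger();
        for (int i = 0; i < workers; i++) {
            new Thread(() -> {
                done.incrementAndGet();
                latch.countDown();
            }).start();
        }
        latch.await();
        return done.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(demo(5)); // prints 5
    }
}
```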

The class diagram of CyclicBarrier is shown below:

[Figure: CyclicBarrier class diagram]

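CyclicBarrier is built on ReentrantLock, which in turn is built on AQS, so ultimately CyclicBarrier also rests on AQS. Internally it uses a Condition of its ReentrantLock to park the threads waiting at the barrier and to wake them up. The key source code is as follows.
The await method: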

/**
* Waits until all {@linkplain #getParties parties} have invoked
* <tt>await</tt> on this barrier.
*
* <p>If the current thread is not the last to arrive then it is
* disabled for thread scheduling purposes and lies dormant until
* one of the following things happens:
* <ul>
* <li>The last thread arrives; or
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* one of the other waiting threads; or
* <li>Some other thread times out while waiting for barrier; or
* <li>Some other thread invokes {@link #reset} on this barrier.
* </ul>
*
* <p>If the current thread:
* <ul>
* <li>has its interrupted status set on entry to this method; or
* <li>is {@linkplain Thread#interrupt interrupted} while waiting
* </ul>
* then {@link InterruptedException} is thrown and the current thread's
* interrupted status is cleared.
*
* <p>If the barrier is {@link #reset} while any thread is waiting,
* or if the barrier {@linkplain #isBroken is broken} when
* <tt>await</tt> is invoked, or while any thread is waiting, then
* {@link BrokenBarrierException} is thrown.
*
* <p>If any thread is {@linkplain Thread#interrupt interrupted} while waiting,
* then all other waiting threads will throw
* {@link BrokenBarrierException} and the barrier is placed in the broken
* state.
*
* <p>If the current thread is the last thread to arrive, and a
* non-null barrier action was supplied in the constructor, then the
* current thread runs the action before allowing the other threads to
* continue.
* If an exception occurs during the barrier action then that exception
* will be propagated in the current thread and the barrier is placed in
* the broken state.
*
* @return the arrival index of the current thread, where index
* <tt>{@link #getParties()} - 1</tt> indicates the first
* to arrive and zero indicates the last to arrive
* @throws InterruptedException if the current thread was interrupted
* while waiting
* @throws BrokenBarrierException if <em>another</em> thread was
* interrupted or timed out while the current thread was
* waiting, or the barrier was reset, or the barrier was
* broken when {@code await} was called, or the barrier
* action (if present) failed due an exception.
*/
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen;
}
}
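The "broken" rules in the Javadoc above can be exercised directly. In this sketch (BrokenBarrierDemo is a hypothetical name) two threads wait on a 3-party barrier that never fills; interrupting one of them breaks the barrier, the other receives BrokenBarrierException, and reset() then starts a fresh, unbroken generation.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

class BrokenBarrierDemo {
    // Returns true if the surviving waiter saw BrokenBarrierException, the barrier
    // was broken afterwards, and reset() cleared the broken state.
    static boolean interruptBreaksBarrier() throws InterruptedException {
        CyclicBarrier barrier = new CyclicBarrier(3); // only 2 parties will ever arrive
        boolean[] sawBroken = {false};

        Thread waiter = new Thread(() -> {
            try {
                barrier.await();
            } catch (BrokenBarrierException e) {
                sawBroken[0] = true; // woken because the barrier broke, not because it tripped
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        Thread victim = new Thread(() -> {
            try {
                barrier.await();
            } catch (InterruptedException | BrokenBarrierException e) {
                // expected: this thread is interrupted while waiting
            }
        });
        waiter.start();
        victim.start();

        // getNumberWaiting() takes the barrier's lock, so seeing 2 means both
        // threads have released it inside the Condition wait
        while (barrier.getNumberWaiting() < 2) {
            Thread.sleep(1);
        }
        victim.interrupt();
        waiter.join();
        victim.join();

        boolean wasBroken = barrier.isBroken();
        barrier.reset(); // starts a new, unbroken generation
        return sawBroken[0] && wasBroken && !barrier.isBroken();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(interruptBreaksBarrier()); // prints true
    }
}
```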

The private dowait method:

    /**
     * Main barrier code, covering the various policies.
     */
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            final Generation g = generation;

            if (g.broken)
                throw new BrokenBarrierException();

            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }

            int index = --count;
            if (index == 0) {  // tripped
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    nextGeneration();
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier();
                }
            }

            // loop until tripped, broken, interrupted, or timed out
            for (;;) {
                try {
                    if (!timed)
                        trip.await();
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        Thread.currentThread().interrupt();
                    }
                }

                if (g.broken)
                    throw new BrokenBarrierException();

                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }

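As dowait shows, each call decrements count by 1 (the new value is kept in index). When it reaches 0, the last arriving thread runs barrierCommand (if one was supplied) and then calls nextGeneration(), which wakes all waiting threads with trip.signalAll(), resets count to parties and starts a new Generation; this automatic reset is what makes the barrier reusable. breakBarrier() is called only when something goes wrong: a thread was interrupted, a timed wait expired, or the barrier action threw an exception. Its implementation is: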

 /**
* Sets current barrier generation as broken and wakes up everyone.
* Called only while holding lock.
*/
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}

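Like nextGeneration(), it calls trip.signalAll() to wake all waiting threads (trip is declared as private final Condition trip = lock.newCondition()); because generation.broken is now true, the woken threads throw BrokenBarrierException. In short, CyclicBarrier is a thin layer over the exclusive lock ReentrantLock plus a single Condition.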
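The core of that design fits in a few lines. The sketch below is a hypothetical SimpleBarrier, not the JDK class: it keeps the same ReentrantLock, Condition and Generation idea as dowait, but deliberately leaves out the barrier action, timeouts, interrupt handling and the broken state, so an interrupted waiter would leave it in an inconsistent state.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical simplified barrier (no barrier action, timeout, or broken handling).
class SimpleBarrier {
    private static final class Generation {} // object identity marks one round

    private final ReentrantLock lock = new ReentrantLock();
    private final Condition trip = lock.newCondition();
    private final int parties;
    private int count;
    private Generation generation = new Generation();

    SimpleBarrier(int parties) {
        if (parties <= 0) throw new IllegalArgumentException("parties <= 0");
        this.parties = parties;
        this.count = parties;
    }

    // Returns the arrival index: parties - 1 for the first arrival, 0 for the last.
    int await() throws InterruptedException {
        lock.lock();
        try {
            Generation g = generation;
            int index = --count;
            if (index == 0) {              // last arrival trips the barrier
                count = parties;           // reset for the next round
                generation = new Generation();
                trip.signalAll();          // wake everyone parked in this round
                return 0;
            }
            while (g == generation) {      // loop also guards against spurious wakeups
                trip.await();              // releases the lock while parked
            }
            return index;
        } finally {
            lock.unlock();
        }
    }

    // Usage: `parties` threads pass the barrier `rounds` times and add up the
    // arrival indices (each round yields parties - 1, ..., 1, 0).
    static int indexSum(int parties, int rounds) throws InterruptedException {
        SimpleBarrier barrier = new SimpleBarrier(parties);
        AtomicInteger sum = new AtomicInteger();
        Thread[] threads = new Thread[parties];
        for (int p = 0; p < parties; p++) {
            threads[p] = new Thread(() -> {
                try {
                    for (int r = 0; r < rounds; r++) {
                        sum.addAndGet(barrier.await());
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            threads[p].start();
        }
        for (Thread t : threads) {
            t.join();
        }
        return sum.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(indexSum(4, 3)); // prints 18 (3 rounds x (3 + 2 + 1 + 0))
    }
}
```

Comparing the generation object rather than the count is what lets a fast thread start the next round while slower threads from the previous round are still waking up.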