一个简单的数据库连接池示例

时间:2022-12-11 08:57:13

  一个简单的数据库连接池, 即一个通过构造函数初始化连接的最大上限,并通过一个双向队列来维护连接,调用方需要先调用fetchConnection(long)方法来指定在多少毫秒内超时获取连接,连接使用完成后,需要调用releaseConnection(Connection)方法将连接放回线程池,连接池的代码实例如下:

/**
* 简易连接池
*/

package com.lbbywyt.concurrent;
import java.sql.Connection;
import java.util.LinkedList;
/**
* @author Administrator
*
*/

public class ConnectionPool {
//
private LinkedList<Connection> pool = new LinkedList<Connection>();
//构造耗时,初始化连接池大小.
public ConnectionPool(int initialSize) {
if (initialSize > 0) {
for (int i = 0; i < initialSize; i++) {
pool.addLast(ConnectionDriver.createConnection());
}
}
}
/**
* 释放连接.
* @param connection connection
*/

public void releaseConnection(Connection connection) {
if (connection != null) {
synchronized (pool) {
// 连接释放后需要进行通知,这样其他消费者能够感知到连接池中已经归还了一个连接
pool.addLast(connection);
pool.notifyAll();
}
}
}
// 在mills内无法获取到连接,将会返回null
public Connection fetchConnection(long mills) throws InterruptedException {
synchronized (pool) {
// 不设超时
if (mills <= 0) {
//不设置等待时间时,只要连接池不为空,即返回一个连接
while (pool.isEmpty()) {
pool.wait();
}
return pool.removeFirst();
} else {
long future = System.currentTimeMillis() + mills;
long remaining = mills;
while (pool.isEmpty() && remaining > 0) {
pool.wait(remaining);
remaining = future - System.currentTimeMillis();
}
Connection result = null;
if (!pool.isEmpty()) {
result = pool.removeFirst();
}
return result;
}
}
}
}

  由于java.sql.Connection是一个接口,最终的实现是由数据库驱动提供方来实现的,书中通过动态代理构造了一个Connection,该Connection的代理实现仅仅是在commit()方法调用时休眠100毫秒,示例如代码:

/**
*该代理类用于创建数据库连接.
*/

package com.lbbywyt.concurrent;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.util.concurrent.TimeUnit;

/**
* @author Administrator
*/

public class ConnectionDriver {
// 每一个动态代理类都必须要实现InvocationHandler这个接口,并且每个代理类的实例都关联到了一个handler,当我们通过代理对象调用一个方法的时候,这个方法的调用就会被转发为由InvocationHandler这个接口的
// invoke 方法来进行调用。
static class ConnectionHandler implements InvocationHandler {

// proxy:  指代我们所代理的那个真实对象
// method:  指代的是我们所要调用真实对象的某个方法的Method对象
// args:  指代的是调用真实对象某个方法时接受的参数
public Object invoke(Object proxy, Method method, Object[] args)
throws Throwable {
if (method.getName().equals("commit")) {
TimeUnit.MILLISECONDS.sleep(100);
}
return null;
}
}
// 创建一个Connection的代理,在commit时休眠100毫秒
public static final Connection createConnection() {
// newProxyInstance返回一个动态代理对象。

// loader:  一个ClassLoader对象,定义了由哪个ClassLoader对象来对生成的代理对象进行加载
// interfaces:  一个Interface对象的数组,表示的是我将要给我需要代理的对象提供一组什么接口,如果我提供了一组接口给它,那么这个代理对象就宣称实现了该接口(多态),这样我就能调用这组接口中的方法了
// h:  一个InvocationHandler对象,表示的是当我这个动态代理对象在调用方法的时候,会关联到哪一个InvocationHandler对象上
// 在newProxyInstance这个方法的第二个参数上,我们给这个代理对象提供了一组什么接口,那么我这个代理对象就会实现了这组接口,这个时候我们当然可以将这个代理对象强制类型转化为这组接口中的任意一个,
// 实例中即当我们调用connection对象的commit方法时sleep 1秒.
return (Connection) Proxy.newProxyInstance(
ConnectionDriver.class.getClassLoader(),
new Class<?>[] { Connection.class }, new ConnectionHandler());
}
}

  最后创建一个客户端类来模拟从连接池中获取连接,客户端类中使用了同步工具类CountDownLatch来保证每个线程获取线程时的公平性。CountDownLatch允许一个或多个线程一直等待,直到其他线程的操作执行完后再执行,它通过一个计数器来实现,计数器的初始值为线程的数量。每当一个线程完成了自己的任务后,计数器的值就会减1。当计数器值到达0时,它表示所有的线程已经完成了任务。

/**
* 连接池客户端.
*/

package com.lbbywyt.concurrent;

import java.sql.Connection;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

/**
* @author libaobao
*
*/

public class ConnectionPoolTest {
static ConnectionPool pool = new ConnectionPool(10);
// CountDownLatch,同步工具类,它允许一个或多个线程一直等待,直到其他线程的操作执行完后再执行。
// CountDownLatch是通过一个计数器来实现的,计数器的初始值为线程的数量。每当一个线程完成了自己的任务后,计数器的值就会减1。当计数器值到达0时,它表示所有的线程已经完成了任务
// 保证所有ConnectionRunner能够同时开始
static CountDownLatch start = new CountDownLatch(1);
// main线程将会等待所有ConnectionRunner结束后才能继续执行
static CountDownLatch end;

/**
* @param args
* @throws InterruptedException
*/

public static void main(String[] args) throws InterruptedException {
//线程数量10
int threadCount = 50;
end = new CountDownLatch(threadCount);
//每个线程20次尝试
int count = 20;
AtomicInteger got = new AtomicInteger();
AtomicInteger notGot = new AtomicInteger();
for (int i = 0; i < threadCount; i++) {
Thread thread = new Thread(new ConnetionRunner(count, got, notGot),
"ConnectionRunnerThread");
thread.start();
}

//start 使所有线程创建完之后才去获取连接,保证公平性。
start.countDown();

end.await();//等待所有线程都执行完之后才执行下面3个输出语句
System.out.println("total invoke: " + (threadCount * count));
System.out.println("got connection: " + got);
System.out.println("not got connection " + notGot);

}

static class ConnetionRunner implements Runnable {
int count;
AtomicInteger got;
AtomicInteger notGot;

public ConnetionRunner(int count, AtomicInteger got,
AtomicInteger notGot) {
this.count = count;
this.got = got;
this.notGot = notGot;
}

public void run() {
try {
Profiler.begin();
start.await();
} catch (Exception ex) {
}
while (count > 0) {
try {
// 从线程池中获取连接,如果1000ms内无法获取到,将会返回null
// 分别统计连接获取的数量got和未获取到的数量notGot
Connection connection = pool.fetchConnection(1000);
if (connection != null) {
try {
connection.createStatement();
connection.commit();
} finally {
pool.releaseConnection(connection);
got.incrementAndGet();
}
} else {
notGot.incrementAndGet();
}
} catch (Exception ex) {
ex.printStackTrace();
} finally {
count--;

}
}
Profiler.end();
end.countDown();
}
}

}

上述客户端类,共创建了50个线程,每个线程分别尝试获取20次,测试输出如下。

……
Cost: 8204 mills
Cost: 8304 mills
Cost: 8305 mills
Cost: 8504 mills
Cost: 8905 mills
total invoke: 1000
got connection: 824
not got connection: 176

Profiler.begin()和Profiler.end();用来记录每个线程获取连接的耗时,具体代码如下:

/**
* 使用ThreadLocal实现计算方法调用耗时.
*/

package com.lbbywyt.concurrent;

import java.util.concurrent.TimeUnit;

/**
* @author libaobao
*
*/

public class Profiler {

// 第一次get()方法调用时会进行初始化(如果set方法没有调用),每个线程会调用一次
private static final ThreadLocal<Long> TIME_THREADLOCAL = new ThreadLocal<Long>() {
protected Long initialValue() {
return System.currentTimeMillis();
}
};

public static final void begin() {
TIME_THREADLOCAL.set(System.currentTimeMillis());
}

public static final long end() {
long millis = System.currentTimeMillis() - TIME_THREADLOCAL.get();

System.out.println("Cost: " + millis+ " mills");

return System.currentTimeMillis() - TIME_THREADLOCAL.get();
}

public static void main(String[] args) throws Exception {
Profiler.begin();
TimeUnit.SECONDS.sleep(1);
System.out.println("Cost: " + Profiler.end() + " mills");
}

}