接口:
import java.util.concurrent.TimeUnit;
/**
 * Minimal lock abstraction: blocking acquire, timed acquire, and release.
 */
public interface DistributedLock {

    /**
     * Acquires the lock, waiting for as long as it takes.
     *
     * @throws Exception if the lock could not be acquired
     */
    void getLock() throws Exception;

    /**
     * Attempts to acquire the lock, waiting at most {@code time} in {@code timeUnit}.
     *
     * @param time     maximum amount of time to wait
     * @param timeUnit unit in which {@code time} is expressed
     * @return {@code true} if the lock was acquired, {@code false} otherwise
     * @throws Exception if the acquisition attempt failed
     */
    boolean getLock(long time, TimeUnit timeUnit) throws Exception;

    /**
     * Releases the lock obtained through one of the {@code getLock} methods.
     */
    void releaseLock();
}
核心公用代码:
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.I0Itec.zkclient.IZkDataListener;
import org.I0Itec.zkclient.ZkClient;
import org.I0Itec.zkclient.exception.ZkNoNodeException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
public class BaseDistributedLock {
private final static Logger logger = LoggerFactory.getLogger(BaseDistributedLock.class);
private final ZkClient client;
private final String lockPath;
private final String baseLockName;
private final String lockName;
private ThreadPoolTaskExecutor taskExecutor;
private static final Integer MAX_RETRY_COUNT = 10;
public BaseDistributedLock(ZkClient zkClient,String baseLockName,String lockName,ThreadPoolTaskExecutor taskExecutor){
this.client = zkClient;
this.baseLockName = baseLockName;
this.lockName = lockName;
this.lockPath = baseLockName.concat("/").concat(lockName);
this.taskExecutor = taskExecutor;
}
/**
* 获取锁资源
* @param time
* @param timeUnit
* @return
* @throws Exception
*/
public String doGetLock(long time, TimeUnit timeUnit) throws Exception{
String path = null;
if(timeUnit != null){
Future<String> future = taskExecutor.submit(new GetLock());
try {
path = future.get(time, timeUnit);
} catch (TimeoutException e) {
future.cancel(true);
}
}else {
path = waitLock();
}
return path;
}
/**
* 创建path,调用锁等待方法
* @return
* @throws Exception
*/
public String waitLock()throws Exception{
int retrycount = 1;
String path = null;
try {
path = createLockNode();
} catch (Exception e) {
if(retrycount++ < MAX_RETRY_COUNT){
createLockNode();
}else {
throw e;
}
}
return doWaitLock(path);
}
/**
* 等待锁资源(核心逻辑)
* @return 创建的路径
* @throws Exception
*/
public String doWaitLock(String path) throws Exception{
boolean getTheLock = false;
boolean doDelete = false;
try {
while (!getTheLock) {
List<String> children = getSortedChildren();
String sequenceNodeName = path.substring(baseLockName.length()+1);
int pathIndex = children.indexOf(sequenceNodeName);
if(pathIndex < 0){
throw new ZkNoNodeException("节点没有找到: " + sequenceNodeName);
}
if(pathIndex == 0){
getTheLock = true;
return path;
}else {
String previousPath = baseLockName.concat("/").concat(children.get(pathIndex-1));
final CountDownLatch latch = new CountDownLatch(1);
final IZkDataListener previousListener = new IZkDataListener() {
public void handleDataDeleted(String dataPath)
throws Exception {
latch.countDown();
}
public void handleDataChange(String dataPath,
Object data) throws Exception {
return;
}
};
try {
client.subscribeDataChanges(previousPath,
previousListener);
latch.await();
} catch (ZkNoNodeException e) {
logger.error("注册监听事件出现异常:"+e.getMessage());
} finally {
client.unsubscribeDataChanges(previousPath,
previousListener);
}
}
}
} catch (Exception e) {
doDelete = true;
throw e;
}finally {
if(doDelete){
deletePath(path);
}
}
return null;
}
/**
* 创建临时节点
* @return
* @throws Exception
*/
private String createLockNode()throws Exception {
return client.createEphemeralSequential(lockPath, null);
}
/**
* 如果该路径存在,那么就删除
* @param path
*/
public void deletePath(String path){
client.delete(path);
}
/**
* 另起一个线程去获取锁资源,主线程监听超时时间
* @author 13376
*
*/
class GetLock implements Callable<String>{
@Override
public String call() throws Exception {
return waitLock();
}
}
/**
* 顺序获取所有的子节点
* @return
* @throws Exception
*/
private List<String> getSortedChildren() throws Exception {
try {
List<String> children = client.getChildren(baseLockName);
Collections.sort(children, new Comparator<String>() {
public int compare(String lhs, String rhs) {
return getLockNodeNumber(lhs, lockName).compareTo(
getLockNodeNumber(rhs, lockName));
}
});
return children;
} catch (ZkNoNodeException e) {
client.createPersistent(baseLockName, true);
return getSortedChildren();
}
}
/**
* 获取零时节点名后面的自增长的数字
* @param str
* @param lockName
* @return
*/
private String getLockNodeNumber(String str, String lockName) {
int index = str.lastIndexOf(lockName);
if (index >= 0) {
index += lockName.length();
return index <= str.length() ? str.substring(index) : "";
}
return str;
}
}
对外暴露的接口方法:
import java.util.concurrent.TimeUnit;
import org.I0Itec.zkclient.ZkClient;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
/**
 * {@link DistributedLock} backed by {@link BaseDistributedLock}, using lock nodes named
 * {@code lock-<sequence>} under the given base path.
 *
 * <p>The instance remembers the node of its current acquisition in {@code path}.
 * NOTE(review): there is no synchronization, so a single instance should be acquired and
 * released by one thread at a time.
 */
public class ZooDistributedLock extends BaseDistributedLock implements
        DistributedLock {
    private static final String LOCK_NAME = "lock-";

    /** Full path of the currently held lock node, or null when the lock is not held. */
    private String path;

    public ZooDistributedLock(ZkClient zkClient, String baseLockName,
            ThreadPoolTaskExecutor taskExecutor) {
        super(zkClient, baseLockName, LOCK_NAME, taskExecutor);
    }

    /**
     * Blocks until the lock is acquired.
     *
     * @throws IllegalStateException if no lock node path was obtained
     * @throws Exception             any failure raised while acquiring
     */
    @Override
    public void getLock() throws Exception {
        if (!getLock(-1, null)) {
            // path is necessarily null here, so it is not included in the message.
            throw new IllegalStateException("连接异常,获取锁失败");
        }
    }

    /**
     * Acquires the lock, waiting at most {@code time}; waits without limit when
     * {@code timeUnit} is null.
     *
     * @return true if the lock is now held
     */
    @Override
    public boolean getLock(long time, TimeUnit timeUnit) throws Exception {
        path = this.doGetLock(time, timeUnit);
        return path != null;
    }

    /**
     * Releases the lock by deleting its node. Does nothing if the lock is not held.
     */
    @Override
    public void releaseLock() {
        String held = path;
        if (held == null) {
            return;
        }
        this.deletePath(held);
        // Cleared only after a successful delete, so a failed release can be retried.
        path = null;
    }
}