提交 709f4c96 编写于 作者: G guide

[refractor]use BeanPostProcesser replace ApplicationContextAware and get zk...

[refractor]use BeanPostProcesser replace ApplicationContextAware and get zk address from properties file
上级 1df258c0
import github.javaguide.HelloService;
import github.javaguide.HelloServiceImpl;
import github.javaguide.provider.ServiceProvider;
import github.javaguide.provider.ServiceProviderImpl;
import github.javaguide.remoting.transport.netty.server.NettyServer;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
......@@ -13,6 +15,7 @@ public class NettyServerMain2 {
AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext(NettyServerMain.class);
NettyServer nettyServer = applicationContext.getBean(NettyServer.class);
nettyServer.start();
nettyServer.publishService(helloService, HelloService.class);
ServiceProvider serviceProvider = new ServiceProviderImpl();
serviceProvider.publishService(helloService);
}
}
......@@ -12,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
public class HelloServiceImpl implements HelloService {
static {
System.out.println("sdasdasdasdasd");
System.out.println("HelloServiceImpl被创建");
}
@Override
......
rpc.server.host=127.0.0.1
rpc.server.port=9998
\ No newline at end of file
rpc.zookeeper.address=127.0.0.1:2181
......@@ -11,7 +11,7 @@ import lombok.ToString;
@AllArgsConstructor
@Getter
@ToString
public enum RpcErrorMessageEnum {
public enum RpcErrorMessage {
CLIENT_CONNECT_SERVER_FAILURE("客户端连接服务端失败"),
SERVICE_INVOCATION_FAILURE("服务调用失败"),
SERVICE_CAN_NOT_BE_FOUND("没有找到指定的服务"),
......
......@@ -4,6 +4,6 @@ package github.javaguide.enumeration;
* @author shuang.kou
* @createTime 2020年06月16日 20:34:00
*/
public enum RpcMessageTypeEnum {
public enum RpcMessageType {
HEART_BEAT
}
package github.javaguide.enumeration;
public enum RpcProperties {
RPC_CONFIG_PATH("rpc.properties"),
ZK_ADDRESS("rpc.zookeeper.address");
private final String propertyValue;
RpcProperties(String propertyValue) {
this.propertyValue = propertyValue;
}
public String getPropertyValue() {
return propertyValue;
}
}
package github.javaguide.exception;
import github.javaguide.enumeration.RpcErrorMessageEnum;
import github.javaguide.enumeration.RpcErrorMessage;
/**
* @author shuang.kou
* @createTime 2020年05月12日 16:48:00
*/
public class RpcException extends RuntimeException {
public RpcException(RpcErrorMessageEnum rpcErrorMessageEnum, String detail) {
super(rpcErrorMessageEnum.getMessage() + ":" + detail);
public RpcException(RpcErrorMessage rpcErrorMessage, String detail) {
super(rpcErrorMessage.getMessage() + ":" + detail);
}
public RpcException(String message, Throwable cause) {
super(message, cause);
}
public RpcException(RpcErrorMessageEnum rpcErrorMessageEnum) {
super(rpcErrorMessageEnum.getMessage());
public RpcException(RpcErrorMessage rpcErrorMessage) {
super(rpcErrorMessage.getMessage());
}
}
......@@ -10,19 +10,19 @@ import java.util.Map;
* @createTime 2020年06月03日 15:04:00
*/
public final class SingletonFactory {
private static Map<String, Object> objectMap = new HashMap<>();
private static final Map<String, Object> OBJECT_MAP = new HashMap<>();
private SingletonFactory() {
}
public static <T> T getInstance(Class<T> c) {
String key = c.toString();
Object instance = objectMap.get(key);
Object instance = OBJECT_MAP.get(key);
synchronized (c) {
if (instance == null) {
try {
instance = c.newInstance();
objectMap.put(key, instance);
OBJECT_MAP.put(key, instance);
} catch (IllegalAccessException | InstantiationException e) {
throw new RuntimeException(e.getMessage(), e);
}
......
......@@ -9,7 +9,6 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
......@@ -25,11 +24,10 @@ public final class ThreadPoolFactoryUtils {
/**
* 通过 threadNamePrefix 来区分不同线程池(我们可以把相同 threadNamePrefix 的线程池看作是为同一业务场景服务)。
* TODO :通过信号量机制( {@link Semaphore} 满足条件)限制创建的线程池数量(线程池和线程不是越多越好)
* key: threadNamePrefix
* value: threadPool
*/
private static Map<String, ExecutorService> threadPools = new ConcurrentHashMap<>();
private static final Map<String, ExecutorService> THREAD_POOLS = new ConcurrentHashMap<>();
private ThreadPoolFactoryUtils() {
......@@ -45,12 +43,12 @@ public final class ThreadPoolFactoryUtils {
}
public static ExecutorService createCustomThreadPoolIfAbsent(CustomThreadPoolConfig customThreadPoolConfig, String threadNamePrefix, Boolean daemon) {
ExecutorService threadPool = threadPools.computeIfAbsent(threadNamePrefix, k -> createThreadPool(customThreadPoolConfig, threadNamePrefix, daemon));
ExecutorService threadPool = THREAD_POOLS.computeIfAbsent(threadNamePrefix, k -> createThreadPool(customThreadPoolConfig, threadNamePrefix, daemon));
// 如果 threadPool 被 shutdown 的话就重新创建一个
if (threadPool.isShutdown() || threadPool.isTerminated()) {
threadPools.remove(threadNamePrefix);
THREAD_POOLS.remove(threadNamePrefix);
threadPool = createThreadPool(customThreadPoolConfig, threadNamePrefix, daemon);
threadPools.put(threadNamePrefix, threadPool);
THREAD_POOLS.put(threadNamePrefix, threadPool);
}
return threadPool;
}
......@@ -60,7 +58,7 @@ public final class ThreadPoolFactoryUtils {
*/
public static void shutDownAllThreadPool() {
log.info("call shutDownAllThreadPool method");
threadPools.entrySet().parallelStream().forEach(entry -> {
THREAD_POOLS.entrySet().parallelStream().forEach(entry -> {
ExecutorService executorService = entry.getValue();
executorService.shutdown();
log.info("shut down thread pool [{}] [{}]", entry.getKey(), executorService.isTerminated());
......
package github.javaguide.utils.file;
import lombok.extern.slf4j.Slf4j;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Properties;
/**
* @author shuang.kou
* @createTime 2020年07月21日 14:25:00
**/
@Slf4j
public final class PropertiesFileUtils {
private PropertiesFileUtils() {
}
public static Properties readPropertiesFile(String fileName) {
String rootPath = Thread.currentThread().getContextClassLoader().getResource("").getPath();
String rpcConfigPath = rootPath + fileName;
Properties properties = null;
try (FileInputStream fileInputStream = new FileInputStream(rpcConfigPath)) {
properties = new Properties();
properties.load(fileInputStream);
} catch (IOException e) {
log.error("occur exception when read properties file [{}]", fileName);
}
return properties;
}
}
package github.javaguide.config;
import github.javaguide.registry.zk.util.CuratorUtils;
import github.javaguide.utils.concurrent.threadpool.ThreadPoolFactoryUtils;
import github.javaguide.utils.zk.CuratorUtils;
import lombok.extern.slf4j.Slf4j;
/**
......@@ -21,7 +21,7 @@ public class CustomShutdownHook {
public void clearAll() {
log.info("addShutdownHook for clearAll");
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
CuratorUtils.clearRegistry();
CuratorUtils.clearRegistry(CuratorUtils.getZkClient());
ThreadPoolFactoryUtils.shutDownAllThreadPool();
}));
}
......
......@@ -14,7 +14,7 @@ public interface ServiceProvider {
* @param service 服务实例对象
* @param serviceClass 服务实例对象实现的接口类
*/
void addServiceProvider(Object service, Class<?> serviceClass);
void addServiceProvider(Object service, Class<?> serviceClass);
/**
* 获取服务实例对象
......@@ -23,4 +23,11 @@ public interface ServiceProvider {
* @return 服务实例对象
*/
Object getServiceProvider(String serviceName);
/**
* 发布服务
*
* @param service 服务实例对象
*/
void publishService(Object service);
}
package github.javaguide.provider;
import github.javaguide.enumeration.RpcErrorMessageEnum;
import github.javaguide.enumeration.RpcErrorMessage;
import github.javaguide.exception.RpcException;
import github.javaguide.registry.ServiceRegistry;
import github.javaguide.registry.zk.ZkServiceRegistry;
import github.javaguide.remoting.transport.netty.server.NettyServer;
import lombok.extern.slf4j.Slf4j;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
......@@ -23,29 +29,40 @@ public class ServiceProviderImpl implements ServiceProvider {
* key:service/interface name
* value:service
*/
private static final Map<String, Object> serviceMap = new ConcurrentHashMap<>();
private static final Set<String> registeredService = ConcurrentHashMap.newKeySet();
private static final Map<String, Object> SERVICE_MAP = new ConcurrentHashMap<>();
private static final Set<String> REGISTERED_SERVICE = ConcurrentHashMap.newKeySet();
private final ServiceRegistry serviceRegistry = new ZkServiceRegistry();
/**
* note:可以修改为扫描注解注册
*/
@Override
public void addServiceProvider(Object service, Class<?> serviceClass) {
String serviceName = serviceClass.getCanonicalName();
if (registeredService.contains(serviceName)) {
if (REGISTERED_SERVICE.contains(serviceName)) {
return;
}
registeredService.add(serviceName);
serviceMap.put(serviceName, service);
REGISTERED_SERVICE.add(serviceName);
SERVICE_MAP.put(serviceName, service);
log.info("Add service: {} and interfaces:{}", serviceName, service.getClass().getInterfaces());
}
@Override
public Object getServiceProvider(String serviceName) {
Object service = serviceMap.get(serviceName);
Object service = SERVICE_MAP.get(serviceName);
if (null == service) {
throw new RpcException(RpcErrorMessageEnum.SERVICE_CAN_NOT_BE_FOUND);
throw new RpcException(RpcErrorMessage.SERVICE_CAN_NOT_BE_FOUND);
}
return service;
}
public void publishService(Object service) {
try {
String host = InetAddress.getLocalHost().getHostAddress();
Class<?> anInterface = service.getClass().getInterfaces()[0];
this.addServiceProvider(service, anInterface);
serviceRegistry.registerService(anInterface.getCanonicalName(), new InetSocketAddress(host, NettyServer.PORT));
} catch (UnknownHostException e) {
log.error("occur exception when getHostAddress", e);
}
}
}
package github.javaguide.registry;
package github.javaguide.registry.zk;
import github.javaguide.loadbalance.LoadBalance;
import github.javaguide.loadbalance.RandomLoadBalance;
import github.javaguide.utils.zk.CuratorUtils;
import github.javaguide.registry.ServiceDiscovery;
import github.javaguide.registry.zk.util.CuratorUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.curator.framework.CuratorFramework;
import java.net.InetSocketAddress;
import java.util.List;
......@@ -25,7 +27,8 @@ public class ZkServiceDiscovery implements ServiceDiscovery {
@Override
public InetSocketAddress lookupService(String serviceName) {
// 这里直接去了第一个找到的服务地址,eg:127.0.0.1:9999
List<String> serviceUrlList = CuratorUtils.getChildrenNodes(serviceName);
CuratorFramework zkClient = CuratorUtils.getZkClient();
List<String> serviceUrlList = CuratorUtils.getChildrenNodes(zkClient, serviceName);
// 负载均衡
String targetServiceUrl = loadBalance.selectServiceAddress(serviceUrlList);
log.info("成功找到服务地址:[{}]", targetServiceUrl);
......
package github.javaguide.registry;
package github.javaguide.registry.zk;
import github.javaguide.utils.zk.CuratorUtils;
import github.javaguide.registry.ServiceRegistry;
import github.javaguide.registry.zk.util.CuratorUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.curator.framework.CuratorFramework;
import java.net.InetSocketAddress;
......@@ -18,6 +20,7 @@ public class ZkServiceRegistry implements ServiceRegistry {
public void registerService(String serviceName, InetSocketAddress inetSocketAddress) {
//根节点下注册子节点:服务
String servicePath = CuratorUtils.ZK_REGISTER_ROOT_PATH + "/" + serviceName + inetSocketAddress.toString();
CuratorUtils.createPersistentNode(servicePath);
CuratorFramework zkClient = CuratorUtils.getZkClient();
CuratorUtils.createPersistentNode(zkClient, servicePath);
}
}
package github.javaguide.utils.zk;
package github.javaguide.registry.zk.util;
import github.javaguide.enumeration.RpcProperties;
import github.javaguide.exception.RpcException;
import github.javaguide.utils.file.PropertiesFileUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.curator.RetryPolicy;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.CuratorFrameworkFactory;
import org.apache.curator.framework.imps.CuratorFrameworkState;
import org.apache.curator.framework.recipes.cache.PathChildrenCache;
import org.apache.curator.framework.recipes.cache.PathChildrenCacheListener;
import org.apache.curator.retry.ExponentialBackoffRetry;
......@@ -12,6 +15,7 @@ import org.apache.zookeeper.CreateMode;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
......@@ -23,17 +27,14 @@ import java.util.concurrent.ConcurrentHashMap;
*/
@Slf4j
public final class CuratorUtils {
private static final int BASE_SLEEP_TIME = 1000;
private static final int MAX_RETRIES = 3;
private static final String CONNECT_STRING = "127.0.0.1:2181";
private static String defaultZookeeperAddress = "127.0.0.1:2181";
public static final String ZK_REGISTER_ROOT_PATH = "/my-rpc";
private static final Map<String, List<String>> serviceAddressMap = new ConcurrentHashMap<>();
private static final Set<String> registeredPathSet = ConcurrentHashMap.newKeySet();
private static final CuratorFramework zkClient;
static {
zkClient = getZkClient();
}
private static final Map<String, List<String>> SERVICE_ADDRESS_MAP = new ConcurrentHashMap<>();
private static final Set<String> REGISTERED_PATH_SET = ConcurrentHashMap.newKeySet();
private static CuratorFramework zkClient;
private CuratorUtils() {
}
......@@ -43,16 +44,16 @@ public final class CuratorUtils {
*
* @param path 节点路径
*/
public static void createPersistentNode(String path) {
public static void createPersistentNode(CuratorFramework zkClient, String path) {
try {
if (registeredPathSet.contains(path) || zkClient.checkExists().forPath(path) != null) {
if (REGISTERED_PATH_SET.contains(path) || zkClient.checkExists().forPath(path) != null) {
log.info("节点已经存在,节点为:[{}]", path);
} else {
//eg: /my-rpc/github.javaguide.HelloService/127.0.0.1:9999
zkClient.create().creatingParentsIfNeeded().withMode(CreateMode.PERSISTENT).forPath(path);
log.info("节点创建成功,节点为:[{}]", path);
}
registeredPathSet.add(path);
REGISTERED_PATH_SET.add(path);
} catch (Exception e) {
throw new RpcException(e.getMessage(), e.getCause());
}
......@@ -64,16 +65,16 @@ public final class CuratorUtils {
* @param serviceName 服务对象接口名 eg:github.javaguide.HelloService
* @return 指定字节下的所有子节点
*/
public static List<String> getChildrenNodes(String serviceName) {
if (serviceAddressMap.containsKey(serviceName)) {
return serviceAddressMap.get(serviceName);
public static List<String> getChildrenNodes(CuratorFramework zkClient, String serviceName) {
if (SERVICE_ADDRESS_MAP.containsKey(serviceName)) {
return SERVICE_ADDRESS_MAP.get(serviceName);
}
List<String> result;
String servicePath = ZK_REGISTER_ROOT_PATH + "/" + serviceName;
try {
result = zkClient.getChildren().forPath(servicePath);
serviceAddressMap.put(serviceName, result);
registerWatcher(serviceName);
SERVICE_ADDRESS_MAP.put(serviceName, result);
registerWatcher(serviceName, zkClient);
} catch (Exception e) {
throw new RpcException(e.getMessage(), e.getCause());
}
......@@ -83,27 +84,36 @@ public final class CuratorUtils {
/**
* 清空注册中心的数据
*/
public static void clearRegistry() {
registeredPathSet.stream().parallel().forEach(p -> {
public static void clearRegistry(CuratorFramework zkClient) {
REGISTERED_PATH_SET.stream().parallel().forEach(p -> {
try {
zkClient.delete().forPath(p);
} catch (Exception e) {
throw new RpcException(e.getMessage(), e.getCause());
}
});
log.info("服务端(Provider)所有注册的服务都被清空:[{}]", registeredPathSet.toString());
log.info("服务端(Provider)所有注册的服务都被清空:[{}]", REGISTERED_PATH_SET.toString());
}
private static CuratorFramework getZkClient() {
// 重试策略。重试3次,并且会增加重试之间的睡眠时间。
public static CuratorFramework getZkClient() {
// check if user has set zk address
Properties properties = PropertiesFileUtils.readPropertiesFile(RpcProperties.RPC_CONFIG_PATH.getPropertyValue());
if (properties != null) {
defaultZookeeperAddress = properties.getProperty(RpcProperties.ZK_ADDRESS.getPropertyValue());
}
// if zkClient has been started, return directly
if (zkClient != null && zkClient.getState() == CuratorFrameworkState.STARTED) {
return zkClient;
}
// Retry strategy. Retry 3 times, and will increase the sleep time between retries.
RetryPolicy retryPolicy = new ExponentialBackoffRetry(BASE_SLEEP_TIME, MAX_RETRIES);
CuratorFramework curatorFramework = CuratorFrameworkFactory.builder()
//要连接的服务器(可以是服务器列表)
.connectString(CONNECT_STRING)
zkClient = CuratorFrameworkFactory.builder()
// the server to connect to (can be a server list)
.connectString(defaultZookeeperAddress)
.retryPolicy(retryPolicy)
.build();
curatorFramework.start();
return curatorFramework;
zkClient.start();
return zkClient;
}
/**
......@@ -111,12 +121,12 @@ public final class CuratorUtils {
*
* @param serviceName 服务对象接口名 eg:github.javaguide.HelloService
*/
private static void registerWatcher(String serviceName) {
private static void registerWatcher(String serviceName, CuratorFramework zkClient) {
String servicePath = ZK_REGISTER_ROOT_PATH + "/" + serviceName;
PathChildrenCache pathChildrenCache = new PathChildrenCache(CuratorUtils.zkClient, servicePath, true);
PathChildrenCache pathChildrenCache = new PathChildrenCache(zkClient, servicePath, true);
PathChildrenCacheListener pathChildrenCacheListener = (curatorFramework, pathChildrenCacheEvent) -> {
List<String> serviceAddresses = curatorFramework.getChildren().forPath(servicePath);
serviceAddressMap.put(serviceName, serviceAddresses);
SERVICE_ADDRESS_MAP.put(serviceName, serviceAddresses);
};
pathChildrenCache.getListenable().addListener(pathChildrenCacheListener);
try {
......
package github.javaguide.remoting.dto;
import github.javaguide.enumeration.RpcErrorMessageEnum;
import github.javaguide.enumeration.RpcErrorMessage;
import github.javaguide.enumeration.RpcResponseCode;
import github.javaguide.exception.RpcException;
import lombok.extern.slf4j.Slf4j;
......@@ -21,15 +21,15 @@ public final class RpcMessageChecker {
public static void check(RpcResponse rpcResponse, RpcRequest rpcRequest) {
if (rpcResponse == null) {
throw new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
throw new RpcException(RpcErrorMessage.SERVICE_INVOCATION_FAILURE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
}
if (!rpcRequest.getRequestId().equals(rpcResponse.getRequestId())) {
throw new RpcException(RpcErrorMessageEnum.REQUEST_NOT_MATCH_RESPONSE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
throw new RpcException(RpcErrorMessage.REQUEST_NOT_MATCH_RESPONSE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
}
if (rpcResponse.getCode() == null || !rpcResponse.getCode().equals(RpcResponseCode.SUCCESS.getCode())) {
throw new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
throw new RpcException(RpcErrorMessage.SERVICE_INVOCATION_FAILURE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
}
}
}
package github.javaguide.remoting.dto;
import github.javaguide.enumeration.RpcMessageTypeEnum;
import github.javaguide.enumeration.RpcMessageType;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
......@@ -25,5 +25,5 @@ public class RpcRequest implements Serializable {
private String methodName;
private Object[] parameters;
private Class<?>[] paramTypes;
private RpcMessageTypeEnum rpcMessageTypeEnum;
private RpcMessageType rpcMessageType;
}
package github.javaguide.remoting.handler;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.enumeration.RpcResponseCode;
......@@ -19,7 +20,11 @@ import java.lang.reflect.Method;
*/
@Slf4j
public class RpcRequestHandler {
private static ServiceProvider serviceProvider = new ServiceProviderImpl();
private final ServiceProvider serviceProvider;
public RpcRequestHandler() {
serviceProvider = SingletonFactory.getInstance(ServiceProviderImpl.class);
}
/**
* 处理 rpcRequest :调用对应的方法,然后返回方法执行结果
......
......@@ -15,36 +15,37 @@ import java.util.concurrent.ConcurrentHashMap;
* @createTime 2020年05月29日 16:36:00
*/
@Slf4j
public final class ChannelProvider {
public class ChannelProvider {
private static final Map<String, Channel> channels = new ConcurrentHashMap<>();
private static final NettyClient nettyClient = SingletonFactory.getInstance(NettyClient.class);
private ChannelProvider() {
private final Map<String, Channel> channelMap;
private final NettyClient nettyClient;
public ChannelProvider() {
channelMap = new ConcurrentHashMap<>();
nettyClient = SingletonFactory.getInstance(NettyClient.class);
}
public static Channel get(InetSocketAddress inetSocketAddress) {
public Channel get(InetSocketAddress inetSocketAddress) {
String key = inetSocketAddress.toString();
// determine if there is a connection for the corresponding address
if (channels.containsKey(key)) {
Channel channel = channels.get(key);
if (channelMap.containsKey(key)) {
Channel channel = channelMap.get(key);
// if so, determine if the connection is available, and if so, get it directly
if (channel != null && channel.isActive()) {
return channel;
} else {
channels.remove(key);
channelMap.remove(key);
}
}
// otherwise, reconnect to get the Channel
Channel channel = nettyClient.doConnect(inetSocketAddress);
channels.put(key, channel);
channelMap.put(key, channel);
return channel;
}
public static void remove(InetSocketAddress inetSocketAddress) {
public void remove(InetSocketAddress inetSocketAddress) {
String key = inetSocketAddress.toString();
channels.remove(key);
log.info("Channel map size :[{}]", channels.size());
channelMap.remove(key);
log.info("Channel map size :[{}]", channelMap.size());
}
}
......@@ -32,11 +32,11 @@ import java.util.concurrent.TimeUnit;
*/
@Slf4j
public final class NettyClient {
private static final Bootstrap bootstrap;
private static final EventLoopGroup eventLoopGroup;
private final Bootstrap bootstrap;
private final EventLoopGroup eventLoopGroup;
// initialize resources such as EventLoopGroup, Bootstrap
static {
public NettyClient() {
eventLoopGroup = new NioEventLoopGroup();
bootstrap = new Bootstrap();
KryoSerializer kryoSerializer = new KryoSerializer();
......@@ -63,6 +63,7 @@ public final class NettyClient {
});
}
/**
* connect server and get the channel ,so that you can send rpc message to server
*
......
package github.javaguide.remoting.transport.netty.client;
import github.javaguide.enumeration.RpcMessageTypeEnum;
import github.javaguide.enumeration.RpcMessageType;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
......@@ -17,7 +17,7 @@ import lombok.extern.slf4j.Slf4j;
import java.net.InetSocketAddress;
/**
* 自定义客户端 ChannelHandler 来处理服务端发过来的数据
* Customize the client ChannelHandler to process the data sent by the server
*
* <p>
* 如果继承自 SimpleChannelInboundHandler 的话就不要考虑 ByteBuf 的释放 ,{@link SimpleChannelInboundHandler} 内部的
......@@ -29,21 +29,24 @@ import java.net.InetSocketAddress;
@Slf4j
public class NettyClientHandler extends ChannelInboundHandlerAdapter {
private final UnprocessedRequests unprocessedRequests;
private final ChannelProvider channelProvider;
public NettyClientHandler() {
this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
this.channelProvider = SingletonFactory.getInstance(ChannelProvider.class);
}
/**
* 读取服务端传输的消息
* Read the message transmitted by the server
*/
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
log.info("client receive msg: [{}]", msg);
RpcResponse rpcResponse = (RpcResponse) msg;
unprocessedRequests.complete(rpcResponse);
if (msg instanceof RpcResponse) {
RpcResponse<Object> rpcResponse = (RpcResponse<Object>) msg;
unprocessedRequests.complete(rpcResponse);
}
} finally {
ReferenceCountUtil.release(msg);
}
......@@ -55,8 +58,8 @@ public class NettyClientHandler extends ChannelInboundHandlerAdapter {
IdleState state = ((IdleStateEvent) evt).state();
if (state == IdleState.WRITER_IDLE) {
log.info("write idle happen [{}]", ctx.channel().remoteAddress());
Channel channel = ChannelProvider.get((InetSocketAddress) ctx.channel().remoteAddress());
RpcRequest rpcRequest = RpcRequest.builder().rpcMessageTypeEnum(RpcMessageTypeEnum.HEART_BEAT).build();
Channel channel = channelProvider.get((InetSocketAddress) ctx.channel().remoteAddress());
RpcRequest rpcRequest = RpcRequest.builder().rpcMessageType(RpcMessageType.HEART_BEAT).build();
channel.writeAndFlush(rpcRequest).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
} else {
......@@ -65,7 +68,7 @@ public class NettyClientHandler extends ChannelInboundHandlerAdapter {
}
/**
* 处理客户端消息发生异常的时候被调用
* Called when an exception occurs in processing a client message
*/
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
......
......@@ -2,7 +2,7 @@ package github.javaguide.remoting.transport.netty.client;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.registry.ServiceDiscovery;
import github.javaguide.registry.ZkServiceDiscovery;
import github.javaguide.registry.zk.ZkServiceDiscovery;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.transport.ClientTransport;
......@@ -14,7 +14,7 @@ import java.net.InetSocketAddress;
import java.util.concurrent.CompletableFuture;
/**
* 基于 Netty 传输 RpcRequest。
* transport rpcRequest based on netty.
*
* @author shuang.kou
* @createTime 2020年05月29日 11:34:00
......@@ -23,20 +23,22 @@ import java.util.concurrent.CompletableFuture;
public class NettyClientTransport implements ClientTransport {
private final ServiceDiscovery serviceDiscovery;
private final UnprocessedRequests unprocessedRequests;
private final ChannelProvider channelProvider;
public NettyClientTransport() {
this.serviceDiscovery = new ZkServiceDiscovery();
this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
this.channelProvider = SingletonFactory.getInstance(ChannelProvider.class);
}
@Override
public CompletableFuture<RpcResponse> sendRpcRequest(RpcRequest rpcRequest) {
// 构建返回值
CompletableFuture<RpcResponse> resultFuture = new CompletableFuture<>();
public CompletableFuture<RpcResponse<Object>> sendRpcRequest(RpcRequest rpcRequest) {
// build return value
CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture<>();
InetSocketAddress inetSocketAddress = serviceDiscovery.lookupService(rpcRequest.getInterfaceName());
Channel channel = ChannelProvider.get(inetSocketAddress);
Channel channel = channelProvider.get(inetSocketAddress);
if (channel != null && channel.isActive()) {
// 放入未处理的请求
// put unprocessed request
unprocessedRequests.put(rpcRequest.getRequestId(), resultFuture);
channel.writeAndFlush(rpcRequest).addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
......
......@@ -7,20 +7,20 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
/**
* 未处理的请求。
* unprocessed requests by the server.
*
* @author shuang.kou
* @createTime 2020年06月04日 17:30:00
*/
public class UnprocessedRequests {
private static Map<String, CompletableFuture<RpcResponse>> unprocessedResponseFutures = new ConcurrentHashMap<>();
private static final Map<String, CompletableFuture<RpcResponse<Object>>> UNPROCESSED_RESPONSE_FUTURES = new ConcurrentHashMap<>();
public void put(String requestId, CompletableFuture<RpcResponse> future) {
unprocessedResponseFutures.put(requestId, future);
public void put(String requestId, CompletableFuture<RpcResponse<Object>> future) {
UNPROCESSED_RESPONSE_FUTURES.put(requestId, future);
}
public void complete(RpcResponse rpcResponse) {
CompletableFuture<RpcResponse> future = unprocessedResponseFutures.remove(rpcResponse.getRequestId());
public void complete(RpcResponse<Object> rpcResponse) {
CompletableFuture<RpcResponse<Object>> future = UNPROCESSED_RESPONSE_FUTURES.remove(rpcResponse.getRequestId());
if (null != future) {
future.complete(rpcResponse);
} else {
......
......@@ -19,8 +19,8 @@ import java.util.List;
@Slf4j
public class NettyKryoDecoder extends ByteToMessageDecoder {
private Serializer serializer;
private Class<?> genericClass;
private final Serializer serializer;
private final Class<?> genericClass;
/**
* Netty传输的消息长度也就是对象序列化后对应的字节数组的大小,存储在 ByteBuf 头部
......
......@@ -16,8 +16,8 @@ import lombok.AllArgsConstructor;
*/
@AllArgsConstructor
public class NettyKryoEncoder extends MessageToByteEncoder<Object> {
private Serializer serializer;
private Class<?> genericClass;
private final Serializer serializer;
private final Class<?> genericClass;
/**
* 将对象转换为字节码然后写入到 ByteBuf 对象中
......
package github.javaguide.remoting.transport.netty.server;
import github.javaguide.annotation.RpcService;
import github.javaguide.config.CustomShutdownHook;
import github.javaguide.provider.ServiceProvider;
import github.javaguide.provider.ServiceProviderImpl;
import github.javaguide.registry.ServiceRegistry;
import github.javaguide.registry.ZkServiceRegistry;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.transport.netty.codec.kyro.NettyKryoDecoder;
......@@ -22,17 +17,12 @@ import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.IdleStateHandler;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.PropertySource;
import org.springframework.stereotype.Component;
import java.net.InetSocketAddress;
import java.util.Map;
import java.net.InetAddress;
import java.util.concurrent.TimeUnit;
/**
......@@ -44,23 +34,14 @@ import java.util.concurrent.TimeUnit;
*/
@Slf4j
@Component
@PropertySource("classpath:rpc.properties")
public class NettyServer implements InitializingBean, ApplicationContextAware {
@Value("${rpc.server.host}")
private String host;
@Value("${rpc.server.port}")
private int port;
public class NettyServer implements InitializingBean {
private final KryoSerializer kryoSerializer = new KryoSerializer();
private final ServiceRegistry serviceRegistry = new ZkServiceRegistry();
private final ServiceProvider serviceProvider = new ServiceProviderImpl();
public void publishService(Object service, Class<?> serviceClass) {
serviceProvider.addServiceProvider(service, serviceClass);
serviceRegistry.registerService(serviceClass.getCanonicalName(), new InetSocketAddress(host, port));
}
public static final int PORT = 9998;
@SneakyThrows
public void start() {
String host = InetAddress.getLocalHost().getHostAddress();
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
......@@ -87,7 +68,7 @@ public class NettyServer implements InitializingBean, ApplicationContextAware {
});
// 绑定端口,同步等待绑定成功
ChannelFuture f = b.bind(host, port).sync();
ChannelFuture f = b.bind(host, PORT).sync();
// 等待服务端监听端口关闭
f.channel().closeFuture().sync();
} catch (InterruptedException e) {
......@@ -107,10 +88,4 @@ public class NettyServer implements InitializingBean, ApplicationContextAware {
CustomShutdownHook.getCustomShutdownHook().clearAll();
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
// 获得所有被 RpcService 注解的类
Map<String, Object> registeredBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
registeredBeanMap.values().forEach(o -> publishService(o, o.getClass().getInterfaces()[0]));
}
}
package github.javaguide.remoting.transport.netty.server;
import github.javaguide.enumeration.RpcMessageTypeEnum;
import github.javaguide.enumeration.RpcMessageType;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.remoting.handler.RpcRequestHandler;
import github.javaguide.remoting.dto.RpcRequest;
......@@ -37,7 +37,7 @@ public class NettyServerHandler extends ChannelInboundHandlerAdapter {
try {
log.info("server receive msg: [{}] ", msg);
RpcRequest rpcRequest = (RpcRequest) msg;
if (rpcRequest.getRpcMessageTypeEnum() == RpcMessageTypeEnum.HEART_BEAT) {
if (rpcRequest.getRpcMessageType() == RpcMessageType.HEART_BEAT) {
log.info("receive heat beat msg from client");
return;
}
......
package github.javaguide.remoting.transport.netty.server;
import github.javaguide.annotation.RpcService;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.provider.ServiceProvider;
import github.javaguide.provider.ServiceProviderImpl;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.stereotype.Component;
/**
* call this method before creating the bean to see if the class is annotated
*
* @author shuang.kou
* @createTime 2020年07月14日 16:42:00
*/
@Component
@Slf4j
public class SpringBeanPostProcessor implements BeanPostProcessor {
private final ServiceProvider serviceProvider;
public SpringBeanPostProcessor() {
serviceProvider = SingletonFactory.getInstance(ServiceProviderImpl.class);
}
@SneakyThrows
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
if (bean.getClass().isAnnotationPresent(RpcService.class)) {
log.info("[{}] is annotated with [{}]", bean.getClass().getName(), RpcService.class.getCanonicalName());
serviceProvider.publishService(bean);
}
return bean;
}
}
......@@ -2,7 +2,7 @@ package github.javaguide.remoting.transport.socket;
import github.javaguide.exception.RpcException;
import github.javaguide.registry.ServiceDiscovery;
import github.javaguide.registry.ZkServiceDiscovery;
import github.javaguide.registry.zk.ZkServiceDiscovery;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.transport.ClientTransport;
import lombok.AllArgsConstructor;
......
......@@ -17,8 +17,8 @@ import java.net.Socket;
*/
@Slf4j
public class SocketRpcRequestHandlerRunnable implements Runnable {
private Socket socket;
private RpcRequestHandler rpcRequestHandler;
private final Socket socket;
private final RpcRequestHandler rpcRequestHandler;
public SocketRpcRequestHandlerRunnable(Socket socket) {
......
......@@ -4,7 +4,7 @@ import github.javaguide.config.CustomShutdownHook;
import github.javaguide.provider.ServiceProvider;
import github.javaguide.provider.ServiceProviderImpl;
import github.javaguide.registry.ServiceRegistry;
import github.javaguide.registry.ZkServiceRegistry;
import github.javaguide.registry.zk.ZkServiceRegistry;
import github.javaguide.utils.concurrent.threadpool.ThreadPoolFactoryUtils;
import lombok.extern.slf4j.Slf4j;
......
package github.javaguide.registry;
import github.javaguide.registry.zk.ZkServiceDiscovery;
import github.javaguide.registry.zk.ZkServiceRegistry;
import org.junit.jupiter.api.Test;
import java.net.InetSocketAddress;
......@@ -17,9 +19,9 @@ class ZkServiceRegistryTest {
void should_register_service_successful_and_lookup_service_by_service_name() {
ServiceRegistry zkServiceRegistry = new ZkServiceRegistry();
InetSocketAddress givenInetSocketAddress = new InetSocketAddress("127.0.0.1", 9333);
zkServiceRegistry.registerService("github.javaguide.registry.ZkServiceRegistry", givenInetSocketAddress);
zkServiceRegistry.registerService("github.javaguide.registry.zk.ZkServiceRegistry", givenInetSocketAddress);
ServiceDiscovery zkServiceDiscovery = new ZkServiceDiscovery();
InetSocketAddress acquiredInetSocketAddress = zkServiceDiscovery.lookupService("github.javaguide.registry.ZkServiceRegistry");
InetSocketAddress acquiredInetSocketAddress = zkServiceDiscovery.lookupService("github.javaguide.registry.zk.ZkServiceRegistry");
assertEquals(givenInetSocketAddress.toString(), acquiredInetSocketAddress.toString());
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册