提交 513bac7c 编写于 作者: S shuang.kou

[v2.0]refractor move registry logic to request handler

上级 fd02d72b
# guide-rpc-framework # guide-rpc-framework
...@@ -14,6 +14,7 @@ public class NettyClientMain { ...@@ -14,6 +14,7 @@ public class NettyClientMain {
RpcClientProxy rpcClientProxy = new RpcClientProxy(rpcClient); RpcClientProxy rpcClientProxy = new RpcClientProxy(rpcClient);
HelloService helloService = rpcClientProxy.getProxy(HelloService.class); HelloService helloService = rpcClientProxy.getProxy(HelloService.class);
String hello = helloService.hello(new Hello("111", "222")); String hello = helloService.hello(new Hello("111", "222"));
System.out.println(hello); //如需使用 assert 断言,需要在 VM options 添加参数:-ea
assert "Hello description is 222".equals(hello);
} }
} }
...@@ -2,7 +2,6 @@ package github.javaguide; ...@@ -2,7 +2,6 @@ package github.javaguide;
import github.javaguide.transport.RpcClient; import github.javaguide.transport.RpcClient;
import github.javaguide.transport.RpcClientProxy; import github.javaguide.transport.RpcClientProxy;
import github.javaguide.transport.netty.client.NettyRpcClient;
import github.javaguide.transport.socket.SocketRpcClient; import github.javaguide.transport.socket.SocketRpcClient;
/** /**
......
...@@ -8,7 +8,7 @@ import org.slf4j.LoggerFactory; ...@@ -8,7 +8,7 @@ import org.slf4j.LoggerFactory;
* @createTime 2020年05月12日 17:36:00 * @createTime 2020年05月12日 17:36:00
*/ */
public class HelloServiceImpl2 { public class HelloServiceImpl2 {
private static final Logger logger = LoggerFactory.getLogger(HelloServiceImpl.class); private static final Logger logger = LoggerFactory.getLogger(HelloServiceImpl2.class);
public String hello(Hello hello) { public String hello(Hello hello) {
logger.info("HelloServiceImpl2收到: {}.", hello.getMessage()); logger.info("HelloServiceImpl2收到: {}.", hello.getMessage());
......
...@@ -45,7 +45,7 @@ public class KryoSerializer implements Serializer { ...@@ -45,7 +45,7 @@ public class KryoSerializer implements Serializer {
kryoThreadLocal.remove(); kryoThreadLocal.remove();
return output.toBytes(); return output.toBytes();
} catch (Exception e) { } catch (Exception e) {
logger.error("occur github.javaguide.exception when serialize:", e); logger.error("occur exception when serialize:", e);
throw new SerializeException("序列化失败"); throw new SerializeException("序列化失败");
} }
} }
...@@ -60,7 +60,7 @@ public class KryoSerializer implements Serializer { ...@@ -60,7 +60,7 @@ public class KryoSerializer implements Serializer {
kryoThreadLocal.remove(); kryoThreadLocal.remove();
return clazz.cast(o); return clazz.cast(o);
} catch (Exception e) { } catch (Exception e) {
logger.error("occur github.javaguide.exception when deserialize:", e); logger.error("occur exception when deserialize:", e);
throw new SerializeException("反序列化失败"); throw new SerializeException("反序列化失败");
} }
} }
......
...@@ -3,6 +3,8 @@ package github.javaguide.transport; ...@@ -3,6 +3,8 @@ package github.javaguide.transport;
import github.javaguide.dto.RpcRequest; import github.javaguide.dto.RpcRequest;
/** /**
* 实现了 RpcClient 接口的对象需要具有发送 RpcRequest 的能力
*
* @author shuang.kou * @author shuang.kou
* @createTime 2020年05月25日 17:02:00 * @createTime 2020年05月25日 17:02:00
*/ */
......
...@@ -10,22 +10,33 @@ import java.lang.reflect.Proxy; ...@@ -10,22 +10,33 @@ import java.lang.reflect.Proxy;
import java.util.UUID; import java.util.UUID;
/** /**
* 动态代理类。当动态代理对象调用一个方法的时候,实际调用的是下面的 invoke 方法
*
* @author shuang.kou * @author shuang.kou
* @createTime 2020年05月10日 19:01:00 * @createTime 2020年05月10日 19:01:00
*/ */
public class RpcClientProxy implements InvocationHandler { public class RpcClientProxy implements InvocationHandler {
private static final Logger logger = LoggerFactory.getLogger(RpcClientProxy.class); private static final Logger logger = LoggerFactory.getLogger(RpcClientProxy.class);
private RpcClient rpcClient; /**
* 用于发送请求给服务端,对应socket和netty两种实现方式
*/
private final RpcClient rpcClient;
public RpcClientProxy(RpcClient rpcClient) { public RpcClientProxy(RpcClient rpcClient) {
this.rpcClient = rpcClient; this.rpcClient = rpcClient;
} }
/**
* 通过 Proxy.newProxyInstance() 方法获取某个类的代理对象
*/
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> T getProxy(Class<T> clazz) { public <T> T getProxy(Class<T> clazz) {
return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[]{clazz}, this); return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[]{clazz}, this);
} }
/**
* 当你使用代理对象调用方法的时候实际会调用到这个方法。代理对象就是你通过上面的 getProxy 方法获取到的对象。
*/
@Override @Override
public Object invoke(Object proxy, Method method, Object[] args) { public Object invoke(Object proxy, Method method, Object[] args) {
logger.info("Call invoke method and invoked method: {}", method.getName()); logger.info("Call invoke method and invoked method: {}", method.getName());
......
...@@ -3,6 +3,8 @@ package github.javaguide.transport; ...@@ -3,6 +3,8 @@ package github.javaguide.transport;
import github.javaguide.dto.RpcRequest; import github.javaguide.dto.RpcRequest;
import github.javaguide.dto.RpcResponse; import github.javaguide.dto.RpcResponse;
import github.javaguide.enumeration.RpcResponseCode; import github.javaguide.enumeration.RpcResponseCode;
import github.javaguide.registry.DefaultServiceRegistry;
import github.javaguide.registry.ServiceRegistry;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -15,18 +17,31 @@ import java.lang.reflect.Method; ...@@ -15,18 +17,31 @@ import java.lang.reflect.Method;
*/ */
public class RpcRequestHandler { public class RpcRequestHandler {
private static final Logger logger = LoggerFactory.getLogger(RpcRequestHandler.class); private static final Logger logger = LoggerFactory.getLogger(RpcRequestHandler.class);
private static final ServiceRegistry serviceRegistry;
public Object handle(RpcRequest rpcRequest, Object service) { static {
serviceRegistry = new DefaultServiceRegistry();
}
/**
* 处理 rpcRequest 然后返回方法执行结果
*/
public Object handle(RpcRequest rpcRequest) {
Object result = null; Object result = null;
//通过注册中心获取到目标类(客户端需要调用类)
Object service = serviceRegistry.getService(rpcRequest.getInterfaceName());
try { try {
result = invokeTargetMethod(rpcRequest, service); result = invokeTargetMethod(rpcRequest, service);
logger.info("service:{} successful invoke method:{}", rpcRequest.getInterfaceName(), rpcRequest.getMethodName()); logger.info("service:{} successful invoke method:{}", rpcRequest.getInterfaceName(), rpcRequest.getMethodName());
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
logger.error("occur github.javaguide.exception", e); logger.error("occur exception", e);
} }
return result; return result;
} }
/**
* 根据 rpcRequest 和 service 对象特定的方法并返回结果
*/
private Object invokeTargetMethod(RpcRequest rpcRequest, Object service) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException { private Object invokeTargetMethod(RpcRequest rpcRequest, Object service) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
Method method = service.getClass().getMethod(rpcRequest.getMethodName(), rpcRequest.getParamTypes()); Method method = service.getClass().getMethod(rpcRequest.getMethodName(), rpcRequest.getParamTypes());
if (null == method) { if (null == method) {
......
...@@ -49,7 +49,7 @@ public class NettyClientHandler extends ChannelInboundHandlerAdapter { ...@@ -49,7 +49,7 @@ public class NettyClientHandler extends ChannelInboundHandlerAdapter {
*/ */
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
logger.error("client catch github.javaguide.exception:", cause); logger.error("client catch exception:", cause);
cause.printStackTrace(); cause.printStackTrace();
ctx.close(); ctx.close();
} }
......
...@@ -81,12 +81,13 @@ public class NettyRpcClient implements RpcClient { ...@@ -81,12 +81,13 @@ public class NettyRpcClient implements RpcClient {
futureChannel.closeFuture().sync(); futureChannel.closeFuture().sync();
AttributeKey<RpcResponse> key = AttributeKey.valueOf("rpcResponse" + rpcRequest.getRequestId()); AttributeKey<RpcResponse> key = AttributeKey.valueOf("rpcResponse" + rpcRequest.getRequestId());
RpcResponse rpcResponse = futureChannel.attr(key).get(); RpcResponse rpcResponse = futureChannel.attr(key).get();
logger.info("client get rpcResponse from channel:{}", rpcResponse);
//校验 RpcResponse 和 RpcRequest //校验 RpcResponse 和 RpcRequest
RpcMessageChecker.check(rpcResponse, rpcRequest); RpcMessageChecker.check(rpcResponse, rpcRequest);
return rpcResponse.getData(); return rpcResponse.getData();
} }
} catch (InterruptedException e) { } catch (InterruptedException e) {
logger.error("occur github.javaguide.exception when connect server:", e); logger.error("occur exception when connect server:", e);
} }
return null; return null;
} }
......
...@@ -27,7 +27,7 @@ import org.slf4j.LoggerFactory; ...@@ -27,7 +27,7 @@ import org.slf4j.LoggerFactory;
public class NettyRpcServer { public class NettyRpcServer {
private static final Logger logger = LoggerFactory.getLogger(NettyRpcServer.class); private static final Logger logger = LoggerFactory.getLogger(NettyRpcServer.class);
private final int port; private final int port;
private KryoSerializer kryoSerializer; private final KryoSerializer kryoSerializer;
public NettyRpcServer(int port) { public NettyRpcServer(int port) {
this.port = port; this.port = port;
...@@ -60,7 +60,7 @@ public class NettyRpcServer { ...@@ -60,7 +60,7 @@ public class NettyRpcServer {
// 等待服务端监听端口关闭 // 等待服务端监听端口关闭
f.channel().closeFuture().sync(); f.channel().closeFuture().sync();
} catch (InterruptedException e) { } catch (InterruptedException e) {
logger.error("occur github.javaguide.exception when start server:", e); logger.error("occur exception when start server:", e);
} finally { } finally {
bossGroup.shutdownGracefully(); bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully(); workerGroup.shutdownGracefully();
......
...@@ -2,8 +2,6 @@ package github.javaguide.transport.netty.server; ...@@ -2,8 +2,6 @@ package github.javaguide.transport.netty.server;
import github.javaguide.dto.RpcRequest; import github.javaguide.dto.RpcRequest;
import github.javaguide.dto.RpcResponse; import github.javaguide.dto.RpcResponse;
import github.javaguide.registry.DefaultServiceRegistry;
import github.javaguide.registry.ServiceRegistry;
import github.javaguide.transport.RpcRequestHandler; import github.javaguide.transport.RpcRequestHandler;
import github.javaguide.utils.concurrent.ThreadPoolFactory; import github.javaguide.utils.concurrent.ThreadPoolFactory;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
...@@ -29,14 +27,13 @@ import java.util.concurrent.ExecutorService; ...@@ -29,14 +27,13 @@ import java.util.concurrent.ExecutorService;
public class NettyServerHandler extends ChannelInboundHandlerAdapter { public class NettyServerHandler extends ChannelInboundHandlerAdapter {
private static final Logger logger = LoggerFactory.getLogger(NettyServerHandler.class); private static final Logger logger = LoggerFactory.getLogger(NettyServerHandler.class);
private static RpcRequestHandler rpcRequestHandler; private static final String THREAD_NAME_PREFIX = "netty-server-handler-rpc-pool";
private static ServiceRegistry serviceRegistry; private static final RpcRequestHandler rpcRequestHandler;
private static ExecutorService threadPool; private static final ExecutorService threadPool;
static { static {
rpcRequestHandler = new RpcRequestHandler(); rpcRequestHandler = new RpcRequestHandler();
serviceRegistry = new DefaultServiceRegistry(); threadPool = ThreadPoolFactory.createDefaultThreadPool(THREAD_NAME_PREFIX);
threadPool = ThreadPoolFactory.createDefaultThreadPool("netty-server-handler-rpc-pool");
} }
@Override @Override
...@@ -46,11 +43,8 @@ public class NettyServerHandler extends ChannelInboundHandlerAdapter { ...@@ -46,11 +43,8 @@ public class NettyServerHandler extends ChannelInboundHandlerAdapter {
try { try {
logger.info(String.format("server receive msg: %s", msg)); logger.info(String.format("server receive msg: %s", msg));
RpcRequest rpcRequest = (RpcRequest) msg; RpcRequest rpcRequest = (RpcRequest) msg;
String interfaceName = rpcRequest.getInterfaceName();
//通过注册中心获取到目标类(客户端需要调用类)
Object service = serviceRegistry.getService(interfaceName);
//执行目标方法(客户端需要执行的方法)并且返回方法结果 //执行目标方法(客户端需要执行的方法)并且返回方法结果
Object result = rpcRequestHandler.handle(rpcRequest, service); Object result = rpcRequestHandler.handle(rpcRequest);
logger.info(String.format("server get result: %s", result.toString())); logger.info(String.format("server get result: %s", result.toString()));
//返回方法执行结果给客户端 //返回方法执行结果给客户端
ChannelFuture f = ctx.writeAndFlush(RpcResponse.success(result, rpcRequest.getRequestId())); ChannelFuture f = ctx.writeAndFlush(RpcResponse.success(result, rpcRequest.getRequestId()));
...@@ -65,7 +59,7 @@ public class NettyServerHandler extends ChannelInboundHandlerAdapter { ...@@ -65,7 +59,7 @@ public class NettyServerHandler extends ChannelInboundHandlerAdapter {
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
logger.error("server catch github.javaguide.exception"); logger.error("server catch exception");
cause.printStackTrace(); cause.printStackTrace();
ctx.close(); ctx.close();
} }
......
...@@ -2,8 +2,6 @@ package github.javaguide.transport.socket; ...@@ -2,8 +2,6 @@ package github.javaguide.transport.socket;
import github.javaguide.dto.RpcRequest; import github.javaguide.dto.RpcRequest;
import github.javaguide.dto.RpcResponse; import github.javaguide.dto.RpcResponse;
import github.javaguide.registry.DefaultServiceRegistry;
import github.javaguide.registry.ServiceRegistry;
import github.javaguide.transport.RpcRequestHandler; import github.javaguide.transport.RpcRequestHandler;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -20,12 +18,10 @@ import java.net.Socket; ...@@ -20,12 +18,10 @@ import java.net.Socket;
public class SocketRpcRequestHandlerRunnable implements Runnable { public class SocketRpcRequestHandlerRunnable implements Runnable {
private static final Logger logger = LoggerFactory.getLogger(SocketRpcRequestHandlerRunnable.class); private static final Logger logger = LoggerFactory.getLogger(SocketRpcRequestHandlerRunnable.class);
private Socket socket; private Socket socket;
private static RpcRequestHandler rpcRequestHandler; private static final RpcRequestHandler rpcRequestHandler;
private static ServiceRegistry serviceRegistry;
static { static {
rpcRequestHandler = new RpcRequestHandler(); rpcRequestHandler = new RpcRequestHandler();
serviceRegistry = new DefaultServiceRegistry();
} }
public SocketRpcRequestHandlerRunnable(Socket socket) { public SocketRpcRequestHandlerRunnable(Socket socket) {
...@@ -38,13 +34,11 @@ public class SocketRpcRequestHandlerRunnable implements Runnable { ...@@ -38,13 +34,11 @@ public class SocketRpcRequestHandlerRunnable implements Runnable {
try (ObjectInputStream objectInputStream = new ObjectInputStream(socket.getInputStream()); try (ObjectInputStream objectInputStream = new ObjectInputStream(socket.getInputStream());
ObjectOutputStream objectOutputStream = new ObjectOutputStream(socket.getOutputStream())) { ObjectOutputStream objectOutputStream = new ObjectOutputStream(socket.getOutputStream())) {
RpcRequest rpcRequest = (RpcRequest) objectInputStream.readObject(); RpcRequest rpcRequest = (RpcRequest) objectInputStream.readObject();
String interfaceName = rpcRequest.getInterfaceName(); Object result = rpcRequestHandler.handle(rpcRequest);
Object service = serviceRegistry.getService(interfaceName);
Object result = rpcRequestHandler.handle(rpcRequest, service);
objectOutputStream.writeObject(RpcResponse.success(result, rpcRequest.getRequestId())); objectOutputStream.writeObject(RpcResponse.success(result, rpcRequest.getRequestId()));
objectOutputStream.flush(); objectOutputStream.flush();
} catch (IOException | ClassNotFoundException e) { } catch (IOException | ClassNotFoundException e) {
logger.error("occur github.javaguide.exception:", e); logger.error("occur exception:", e);
} }
} }
......
...@@ -15,7 +15,7 @@ import java.util.concurrent.ExecutorService; ...@@ -15,7 +15,7 @@ import java.util.concurrent.ExecutorService;
*/ */
public class SocketRpcServer { public class SocketRpcServer {
private ExecutorService threadPool; private final ExecutorService threadPool;
private static final Logger logger = LoggerFactory.getLogger(SocketRpcServer.class); private static final Logger logger = LoggerFactory.getLogger(SocketRpcServer.class);
public SocketRpcServer() { public SocketRpcServer() {
...@@ -24,7 +24,7 @@ public class SocketRpcServer { ...@@ -24,7 +24,7 @@ public class SocketRpcServer {
public void start(int port) { public void start(int port) {
try (ServerSocket server = new ServerSocket(port);) { try (ServerSocket server = new ServerSocket(port)) {
logger.info("server starts..."); logger.info("server starts...");
Socket socket; Socket socket;
while ((socket = server.accept()) != null) { while ((socket = server.accept()) != null) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册