提交 83d006bd 编写于 作者: S shuang.kou

[v3.0]use completefuture to get rpcresponse from server

上级 7090a428
......@@ -125,3 +125,7 @@ IntelliJ IDEA-> Preferences->Plugins->搜索下载CheckStyle 插件,然后按
### 如果我要自己写的话,可以参考哪些资料?
......@@ -2,7 +2,7 @@ package github.javaguide;
import github.javaguide.remoting.transport.ClientTransport;
import github.javaguide.proxy.RpcClientProxy;
import github.javaguide.remoting.transport.netty.client.NettyClientClientTransport;
import github.javaguide.remoting.transport.netty.client.NettyClientTransport;
/**
* @author shuang.kou
......@@ -10,7 +10,7 @@ import github.javaguide.remoting.transport.netty.client.NettyClientClientTranspo
*/
public class NettyClientMain {
public static void main(String[] args) {
ClientTransport rpcClient = new NettyClientClientTransport();
ClientTransport rpcClient = new NettyClientTransport();
RpcClientProxy rpcClientProxy = new RpcClientProxy(rpcClient);
HelloService helloService = rpcClientProxy.getProxy(HelloService.class);
String hello = helloService.hello(new Hello("111", "222"));
......
package github.javaguide.proxy;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.transport.ClientTransport;
import github.javaguide.remoting.transport.netty.client.NettyClientTransport;
import github.javaguide.remoting.transport.socket.SocketRpcClient;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
/**
* 动态代理类。当动态代理对象调用一个方法的时候,实际调用的是下面的 invoke 方法。
......@@ -38,6 +43,8 @@ public class RpcClientProxy implements InvocationHandler {
/**
* 当你使用代理对象调用方法的时候实际会调用到这个方法。代理对象就是你通过上面的 getProxy 方法获取到的对象。
*/
@SneakyThrows
@SuppressWarnings("unchecked")
@Override
public Object invoke(Object proxy, Method method, Object[] args) {
log.info("Call invoke method and invoked method: {}", method.getName());
......@@ -47,6 +54,15 @@ public class RpcClientProxy implements InvocationHandler {
.paramTypes(method.getParameterTypes())
.requestId(UUID.randomUUID().toString())
.build();
return clientTransport.sendRpcRequest(rpcRequest);
Object result = null;
if (clientTransport instanceof NettyClientTransport) {
CompletableFuture<RpcResponse> completableFuture = (CompletableFuture<RpcResponse>) clientTransport.sendRpcRequest(rpcRequest);
result = completableFuture.get().getData();
}
if (clientTransport instanceof SocketRpcClient) {
RpcResponse rpcResponse = (RpcResponse) clientTransport.sendRpcRequest(rpcRequest);
result = rpcResponse.getData();
}
return result;
}
}
......@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit;
@Slf4j
public class ChannelProvider {
private static Bootstrap bootstrap = NettyClient.initializeBootstrap();
private static Bootstrap bootstrap = NettyClient.getBootstrap();
private static Channel channel = null;
/**
* 最多重试次数
......@@ -51,7 +51,6 @@ public class ChannelProvider {
return;
}
if (retry == 0) {
log.error("客户端连接失败:重试次数已用完,放弃连接!");
countDownLatch.countDown();
throw new RpcException(RpcErrorMessageEnum.CLIENT_CONNECT_SERVER_FAILURE);
}
......
......@@ -2,9 +2,9 @@ package github.javaguide.remoting.transport.netty.client;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.serialize.kyro.KryoSerializer;
import github.javaguide.remoting.transport.netty.codec.kyro.NettyKryoDecoder;
import github.javaguide.remoting.transport.netty.codec.kyro.NettyKryoEncoder;
import github.javaguide.serialize.kyro.KryoSerializer;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
......@@ -25,10 +25,8 @@ public final class NettyClient {
private static Bootstrap b;
private static EventLoopGroup eventLoopGroup;
private NettyClient() {
}
// 初始化相关资源比如 EventLoopGroup、Bootstrap
static {
eventLoopGroup = new NioEventLoopGroup();
b = new Bootstrap();
......@@ -59,7 +57,9 @@ public final class NettyClient {
eventLoopGroup.shutdownGracefully();
}
public static Bootstrap initializeBootstrap() {
public static Bootstrap getBootstrap() {
return b;
}
}
package github.javaguide.remoting.transport.netty.client;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.remoting.dto.RpcResponse;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.slf4j.Slf4j;
......@@ -20,6 +20,11 @@ import lombok.extern.slf4j.Slf4j;
*/
@Slf4j
public class NettyClientHandler extends ChannelInboundHandlerAdapter {
private final UnprocessedRequests unprocessedRequests;
public NettyClientHandler() {
this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
}
/**
* 读取服务端传输的消息
......@@ -29,15 +34,7 @@ public class NettyClientHandler extends ChannelInboundHandlerAdapter {
try {
log.info("client receive msg: [{}]", msg);
RpcResponse rpcResponse = (RpcResponse) msg;
// 声明一个 AttributeKey 对象,类似于 Map 中的 key
AttributeKey<RpcResponse> key = AttributeKey.valueOf("rpcResponse" + rpcResponse.getRequestId());
/*
* AttributeMap 可以看作是一个Channel的共享数据源
* AttributeMap 的 key 是 AttributeKey,value 是 Attribute
*/
// 将服务端的返回结果保存到 AttributeMap 上
ctx.channel().attr(key).set(rpcResponse);
ctx.channel().close();
unprocessedRequests.complete(rpcResponse);
} finally {
ReferenceCountUtil.release(msg);
}
......
package github.javaguide.remoting.transport.netty.client;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.registry.ServiceDiscovery;
import github.javaguide.registry.ZkServiceDiscovery;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.transport.ClientTransport;
import github.javaguide.remoting.dto.RpcMessageChecker;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;
import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.CompletableFuture;
/**
* 基于 Netty 传输 RpcRequest。
......@@ -21,45 +20,46 @@ import java.util.concurrent.atomic.AtomicReference;
* @createTime 2020年05月29日 11:34:00
*/
@Slf4j
public class NettyClientClientTransport implements ClientTransport {
public class NettyClientTransport implements ClientTransport {
private final ServiceDiscovery serviceDiscovery;
private final UnprocessedRequests unprocessedRequests;
public NettyClientClientTransport() {
public NettyClientTransport() {
this.serviceDiscovery = new ZkServiceDiscovery();
this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
}
@Override
public Object sendRpcRequest(RpcRequest rpcRequest) {
AtomicReference<Object> result = new AtomicReference<>(null);
public CompletableFuture<RpcResponse> sendRpcRequest(RpcRequest rpcRequest) {
// 构建返回值
CompletableFuture<RpcResponse> resultFuture = new CompletableFuture<>();
try {
InetSocketAddress inetSocketAddress = serviceDiscovery.lookupService(rpcRequest.getInterfaceName());
Channel channel = ChannelProvider.get(inetSocketAddress);
if (!channel.isActive()) {
NettyClient.close();
return null;
if (channel != null && channel.isActive()) {
// 放入未处理的请求
unprocessedRequests.put(rpcRequest.getRequestId(), resultFuture);
channel.writeAndFlush(rpcRequest).addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
log.info("client send message: {}", rpcRequest);
} else {
future.channel().close();
resultFuture.completeExceptionally(future.cause());
log.error("Send failed:", future.cause());
}
});
} else {
throw new IllegalStateException();
}
channel.writeAndFlush(rpcRequest).addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
log.info("client send message: {}", rpcRequest);
} else {
future.channel().close();
log.error("Send failed:", future.cause());
}
});
channel.closeFuture().sync();
AttributeKey<RpcResponse> key = AttributeKey.valueOf("rpcResponse" + rpcRequest.getRequestId());
RpcResponse rpcResponse = channel.attr(key).get();
log.info("client get rpcResponse from channel:{}", rpcResponse);
//校验 RpcResponse 和 RpcRequest
RpcMessageChecker.check(rpcResponse, rpcRequest);
result.set(rpcResponse.getData());
} catch (InterruptedException e) {
unprocessedRequests.remove(rpcRequest.getRequestId());
log.error(e.getMessage(), e);
Thread.currentThread().interrupt();
}
return result.get();
return resultFuture;
}
}
package github.javaguide.remoting.transport.netty.client;
import github.javaguide.remoting.dto.RpcResponse;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
/**
* 未处理的请求。
*
* @author shuang.kou
* @createTime 2020年06月04日 17:30:00
*/
public class UnprocessedRequests {
private static ConcurrentHashMap<String, CompletableFuture<RpcResponse>> unprocessedResponseFutures = new ConcurrentHashMap<>();
public void put(String requestId, CompletableFuture<RpcResponse> future) {
unprocessedResponseFutures.put(requestId, future);
}
public void remove(String requestId) {
unprocessedResponseFutures.remove(requestId);
}
public void complete(RpcResponse rpcResponse) {
CompletableFuture<RpcResponse> future = unprocessedResponseFutures.remove(rpcResponse.getRequestId());
if (null != future) {
future.complete(rpcResponse);
} else {
throw new IllegalStateException();
}
}
}
package github.javaguide.remoting.transport.netty.server;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.handler.RpcRequestHandler;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.handler.RpcRequestHandler;
import github.javaguide.utils.concurrent.ThreadPoolFactoryUtils;
import github.javaguide.factory.SingletonFactory;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
......@@ -39,16 +37,18 @@ public class NettyServerHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
threadPool.execute(() -> {
log.info(String.format("server handle message from client by thread: %s", Thread.currentThread().getName()));
try {
log.info(String.format("server receive msg: %s", msg));
log.info("server receive msg: [{}] ", msg);
RpcRequest rpcRequest = (RpcRequest) msg;
//执行目标方法(客户端需要执行的方法)并且返回方法结果
Object result = rpcRequestHandler.handle(rpcRequest);
log.info(String.format("server get result: %s", result.toString()));
//返回方法执行结果给客户端
ChannelFuture f = ctx.writeAndFlush(RpcResponse.success(result, rpcRequest.getRequestId()));
f.addListener(ChannelFutureListener.CLOSE);
if (ctx.channel().isActive() && ctx.channel().isWritable()) {
//返回方法执行结果给客户端
ctx.writeAndFlush(RpcResponse.success(result, rpcRequest.getRequestId()));
} else {
log.error("not writable now, message dropped");
}
} finally {
//确保 ByteBuf 被释放,不然可能会有内存泄露问题
ReferenceCountUtil.release(msg);
......
package github.javaguide.remoting.transport.socket;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.exception.RpcException;
import github.javaguide.registry.ServiceDiscovery;
import github.javaguide.registry.ZkServiceDiscovery;
import github.javaguide.remoting.transport.ClientTransport;
import github.javaguide.remoting.dto.RpcMessageChecker;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.transport.ClientTransport;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
......@@ -44,7 +44,7 @@ public class SocketRpcClient implements ClientTransport {
RpcResponse rpcResponse = (RpcResponse) objectInputStream.readObject();
//校验 RpcResponse 和 RpcRequest
RpcMessageChecker.check(rpcResponse, rpcRequest);
return rpcResponse.getData();
return rpcResponse;
} catch (IOException | ClassNotFoundException e) {
throw new RpcException("调用服务失败:", e);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册