diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake index abee6698e30b7e76ca42825ed225876bf2ba5ec0..79b2449fe6689993bbee8a24ae7c46b452afe0a0 100644 --- a/cmake/external/grpc.cmake +++ b/cmake/external/grpc.cmake @@ -33,7 +33,7 @@ ExternalProject_Add( extern_grpc DEPENDS protobuf zlib GIT_REPOSITORY "https://github.com/grpc/grpc.git" - GIT_TAG "v1.7.x" + GIT_TAG "v1.8.x" PREFIX ${GRPC_SOURCES_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc index 5a4db2d7e686ce84abef620f890be8f3aa82cb73..aee56ffe018aa8d0d2106df24bd9358c930a02ca 100644 --- a/paddle/operators/detail/grpc_client.cc +++ b/paddle/operators/detail/grpc_client.cc @@ -87,7 +87,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool RPCClient::wait() { +bool RPCClient::Wait() { bool ok = true; while (true) { @@ -96,7 +96,6 @@ bool RPCClient::wait() { } if (!Proceed()) { - LOG(ERROR) << "Get meets CompletionQueue error"; return false; } } @@ -110,9 +109,9 @@ bool RPCClient::Proceed() { // request counts. if (!cq_.Next(&tag, &ok)) { + LOG(ERROR) << "Get meets CompletionQueue error"; return false; } - req_count_--; GPR_ASSERT(ok); PADDLE_ENFORCE(tag); @@ -120,12 +119,15 @@ bool RPCClient::Proceed() { // TODO(gongwb): add more retries. ClientBase* c = static_cast(tag); if (!c->status_.ok()) { + LOG(ERROR) << "proc param error:" << c->var_h_.String() + << " grpc error:" << c->status_.error_message(); delete c; - return true; + return false; } c->Process(); delete c; + req_count_--; return true; } @@ -135,8 +137,12 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) { return it->second; } + grpc::ChannelArguments args; + args.SetMaxSendMessageSize(std::numeric_limits::max()); + args.SetMaxReceiveMessageSize(std::numeric_limits::max()); + auto ch = std::shared_ptr( - grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())); + grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args)); channels_[ep] = ch; return ch; diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h index d27b5ced9ece67f9b9da3b7f87ec231477603580..a62e70a2533ae52d84d010504b19fed5aeb15dc0 100644 --- a/paddle/operators/detail/grpc_client.h +++ b/paddle/operators/detail/grpc_client.h @@ -130,7 +130,7 @@ class RPCClient { const framework::Scope& scope, const std::string& var_name, int64_t time_out = 600 * 1000); - bool wait(); + bool Wait(); private: bool Proceed(); diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index e8d561a57ff59e9221400241f881cb26fb6c6f06..ac4bb5cb8227b89a2f387801e4f876b09eabbff3 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -28,12 +28,15 @@ class RequestBase { public: explicit RequestBase(sendrecv::SendRecvService::AsyncService* service, grpc::ServerCompletionQueue* cq) - : service_(service), cq_(cq), status_(PROCESS) {} + : service_(service), cq_(cq), status_(PROCESS) { + PADDLE_ENFORCE(cq_); + } virtual ~RequestBase() {} virtual void Process() { assert(false); } CallStatus Status() { return status_; } void SetStatus(CallStatus status) { status_ = status; } + virtual std::string GetReqName() { assert(false); } protected: grpc::ServerContext ctx_; @@ -56,12 +59,14 @@ class RequestSend final : public RequestBase { virtual ~RequestSend() {} + virtual std::string GetReqName() { return request_.varname(); } + virtual void Process() { MessageWithName msg_with_name = std::make_pair(request_.varname(), std::move(request_)); queue_->Push(std::move(msg_with_name)); - // TODO(gongwb): check var's info. responder_.Finish(reply_, grpc::Status::OK, this); + status_ = FINISH; } protected: @@ -81,6 +86,8 @@ class RequestGet final : public RequestBase { virtual ~RequestGet() {} + virtual std::string GetReqName() { return request_.varname(); } + virtual void Process() { // proc request. std::string var_name = request_.varname(); @@ -88,6 +95,7 @@ class RequestGet final : public RequestBase { SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_); // TODO(gongwb): check var's info. responder_.Finish(reply_, grpc::Status::OK, this); + status_ = FINISH; } protected: @@ -100,6 +108,8 @@ class RequestGet final : public RequestBase { void AsyncGRPCServer::RunSyncUpdate() { grpc::ServerBuilder builder; builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); + builder.SetMaxSendMessageSize(std::numeric_limits::max()); + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); builder.RegisterService(&service_); cq_send_ = builder.AddCompletionQueue(); @@ -159,18 +169,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { VLOG(4) << "create Requestget status:" << get->Status(); } -void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) { - std::unique_lock lock(cq_mutex_); - if (is_shut_down_) { - delete last; - last = NULL; - return; - } - - last->SetStatus(FINISH); - return; -} - void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, std::string cq_name, std::function TryToRegisterNewOne) { @@ -184,13 +182,19 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, break; } + PADDLE_ENFORCE(tag); if (wait && !done_) { Wait(); } RequestBase* base = (RequestBase*)tag; + // reference: + // https://github.com/tensorflow/tensorflow/issues/5596 + // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM + // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I if (!ok) { - VLOG(4) << cq_name << " recv no regular event"; + LOG(WARNING) << cq_name << " recv no regular event:argument name" + << base->GetReqName(); TryToRegisterNewOne(); delete base; continue; @@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, VLOG(4) << cq_name << " status:" << base->Status(); TryToRegisterNewOne(); base->Process(); - SetFinishOrDelete(base); break; } case FINISH: { diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 041fe05b2e9c37e8a91669b8f523c47b56e14cba..694e18ef49818f2d22789748d930d041de4f3586 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { std::function TryToRegisterNewOne); void TryToRegisterNewSendOne(); void TryToRegisterNewGetOne(); - void SetFinishOrDelete(RequestBase *&last); void ShutdownQueue(); private: diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 55b33343af43802e1b6b95a32603bfee806c9764..cf69c12b6864b4aaea82ecf7e9eb38e44d61a215 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -96,6 +96,8 @@ class RecvOp : public framework::OperatorBase { rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; + VLOG(4) << "param_count:" << param_count + << " trainer_count:" << trainer_count; while (!exit_flag) { // TODO(gognwb): simply this loop. // Get from multiple trainers, we don't care about order in which diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 4d145250bdc73607c8817e20fdb753f4c96e2391..203000f5aaff1a89ac6119b5a9b774f8d48c7c76 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase { client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - client_.wait(); + PADDLE_ENFORCE(client_.Wait()); } private: