From 7570d8e77cadf89760187a787b48693608cc8aaf Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 18 Jul 2018 16:22:19 +0800 Subject: [PATCH] add rpc complete interface --- paddle/fluid/framework/executor.cc | 10 ++---- paddle/fluid/framework/executor.h | 8 ++--- .../operators/distributed/CMakeLists.txt | 2 +- .../operators/distributed/grpc_client.cc | 31 +++---------------- .../fluid/operators/distributed/grpc_client.h | 11 ++----- .../operators/distributed/request_handler.h | 2 -- .../distributed/request_handler_impl.cc | 11 +++---- .../fluid/operators/distributed/rpc_client.h | 14 +++------ .../fluid/operators/distributed/rpc_server.cc | 22 ++++++------- .../fluid/operators/distributed/rpc_server.h | 5 ++- .../operators/distributed/rpc_server_test.cc | 29 ++++++++++++++--- paddle/fluid/pybind/pybind.cc | 3 +- python/paddle/fluid/executor.py | 7 ++--- 13 files changed, 62 insertions(+), 93 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 84f67fafa1..750a08d3a3 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -46,16 +46,10 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { Executor::Executor(const platform::Place& place) : place_(place) {} #ifdef PADDLE_WITH_DISTRIBUTE -void Executor::BeginPass() { +void Executor::Complete() { ::paddle::operators::distributed::RPCClient::GetInstance< ::paddle::operators::distributed::GRPCClient>() - ->SendBeginPass(); -} - -void Executor::EndPass() { - ::paddle::operators::distributed::RPCClient::GetInstance< - ::paddle::operators::distributed::GRPCClient>() - ->SendEndPass(); + ->SendComplete(); } #endif diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 563a4b2bb6..53ebf18bd2 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -46,14 +46,10 @@ class Executor { #ifdef PADDLE_WITH_DISTRIBUTE /* - * Sending signal to pserver to mark current pass started. + * Sending signal to pserver to mark current trainer completed. */ - void BeginPass(); + void Complete(); - /* - * Sending signal to pserver to mark current pass finished. - */ - void EndPass(); #endif /* @Brief diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 675ca36774..a6b1c43ce9 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -6,7 +6,7 @@ if(WITH_GRPC) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) - cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc + cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL) return() diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 4d60801b6a..7ef7482cab 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -35,18 +35,10 @@ void GRPCClient::InitEventLoop() { client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); } -void GRPCClient::SendBeginPass() { +void GRPCClient::SendComplete() { for (auto& it : channels_) { - VLOG(3) << "send begin pass to: " << it.first; - this->AsyncSendBeginPass(it.first); - } - this->Wait(); -} - -void GRPCClient::SendEndPass() { - for (auto& it : channels_) { - VLOG(3) << "send end pass to " << it.first; - this->AsyncSendEndPass(it.first); + VLOG(3) << "send complete message to " << it.first; + this->AsyncSendComplete(it.first); } this->Wait(); } @@ -238,32 +230,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, req_count_++; } -void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) { +void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); s->Prepare(time_out); sendrecv::VariableMessage req; - req.set_varname(BEGIN_PASS_MESSAGE); + req.set_varname(COMPLETE_MESSAGE); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; } -void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep); - - FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); - s->Prepare(time_out); - - sendrecv::VariableMessage req; - req.set_varname(END_PASS_MESSAGE); - auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; -} - void GRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out) { diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index d03a3e56ae..26cad7548e 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -215,17 +215,12 @@ class GRPCClient : public RPCClient { void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) override; - void AsyncSendBeginPass(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncSendEndPass(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; + void AsyncSendComplete(const std::string& ep, + int64_t time_out = FLAGS_rpc_deadline) override; bool Wait() override; - void SendBeginPass() override; - - void SendEndPass() override; + void SendComplete() override; protected: void InitImpl() override; diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 271306d5d2..f68f9d8f3c 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" -#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV" -#define END_PASS_MESSAGE "END_PASS@RECV" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 5e6bff20f5..d9a1ac583f 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname, if (varname == BATCH_BARRIER_MESSAGE) { VLOG(3) << "sync: recv batch barrier message"; rpc_server_->IncreaseBatchBarrier(kRequestSend); - } else if (varname == BEGIN_PASS_MESSAGE) { - VLOG(3) << "sync: recv begin pass message"; - rpc_server_->WaitCond(kRequestSend); - rpc_server_->BeginPass(); + } else if (varname == COMPLETE_MESSAGE) { + VLOG(3) << "sync: recv complete message"; + rpc_server_->Complete(); } else { VLOG(3) << "sync: received var_name: " << varname; rpc_server_->WaitCond(kRequestSend); @@ -95,14 +94,12 @@ bool RequestGetHandler::Handle(const std::string& varname, if (varname == FETCH_BARRIER_MESSAGE) { VLOG(3) << "sync: recv fetch barrier message"; rpc_server_->IncreaseBatchBarrier(kRequestGet); - } else if (varname == END_PASS_MESSAGE) { - rpc_server_->EndPass(); } else { rpc_server_->WaitCond(kRequestGet); *outvar = scope_->FindVar(varname); } } else { - if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) { + if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) { *outvar = scope_->FindVar(varname); } } diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 4d87376fbf..22a022a5d2 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -60,17 +60,13 @@ class RPCClient { const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) = 0; - virtual void AsyncSendBeginPass(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual void AsyncSendComplete(const std::string& ep, + int64_t time_out = FLAGS_rpc_deadline) = 0; - virtual void AsyncSendEndPass(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - // BeginePass/EndPass tells all the pserver that start/end a pass, so that - // the pserver can increase/reduce it's barrier count, and continue to train + // Complete tells all the pserver instances that finishe the training, + // the pserver can reduce it's barrier count, and continue to train // with other trainers. - virtual void SendBeginPass() = 0; - virtual void SendEndPass() = 0; + virtual void SendComplete() = 0; virtual bool Wait() = 0; diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc index d49ee34eea..42dc7bd10b 100644 --- a/paddle/fluid/operators/distributed/rpc_server.cc +++ b/paddle/fluid/operators/distributed/rpc_server.cc @@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { } } -void RPCServer::BeginPass() { - VLOG(4) << "RPCServer begin increase pass barrier"; - { - std::unique_lock lock(mutex_); - client_num_++; - VLOG(4) << "increase client_num to: " << client_num_; - } - barrier_cond_.notify_all(); -} - -void RPCServer::EndPass() { - VLOG(4) << "RPCServer begin increase pass barrier"; +void RPCServer::Complete() { { std::unique_lock lock(mutex_); client_num_--; @@ -83,10 +72,19 @@ void RPCServer::EndPass() { if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) { barrier_counter_[kRequestGet]--; } + if (client_num_ == 0) { + exit_flag_ = true; + VLOG(1) << "No activate Trainer instance, PServer will exit..."; + } } barrier_cond_.notify_all(); } +int RPCServer::GetClientNum() { + std::unique_lock lock(mutex_); + return client_num_; +} + void RPCServer::ResetBarrierCounter() { VLOG(3) << "RPCServer ResetBarrierCounter "; std::unique_lock lock(mutex_); diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h index 833991c8aa..fd914d7a72 100644 --- a/paddle/fluid/operators/distributed/rpc_server.h +++ b/paddle/fluid/operators/distributed/rpc_server.h @@ -44,7 +44,7 @@ class RPCServer { int GetSelectedPort() const { return selected_port_; } - int GetClientNum() const; + int GetClientNum(); void SavePort() const; @@ -64,8 +64,7 @@ class RPCServer { void WaitCond(const std::string& rpc_name); void IncreaseBatchBarrier(const std::string rpc_name); - void BeginPass(); - void EndPass(); + void Complete(); void ResetBarrierCounter(); diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index a0693cffab..9f2360ec70 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, } } -void StartServer() { +void StartServer(const std::string& rpc_name) { framework::ProgramDesc program; framework::Scope scope; platform::CPUPlace place; @@ -107,14 +107,14 @@ void StartServer() { std::shared_ptr> prefetch_var_name_to_prepared; prefetch_var_name_to_prepared[in_var_name] = prepared[0]; + g_req_handler->SetProgram(&program); g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared); g_req_handler->SetDevCtx(&ctx); g_req_handler->SetScope(&scope); g_req_handler->SetExecutor(&exe); - g_rpc_service->RegisterRPC(distributed::kRequestPrefetch, - g_req_handler.get()); + g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); g_req_handler->SetRPCServer(g_rpc_service.get()); std::thread server_thread( @@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) { distributed::RPCClient* client = distributed::RPCClient::GetInstance(); - std::thread server_thread(StartServer); + std::thread server_thread(StartServer, distributed::kRequestPrefetch); g_rpc_service->WaitServerReady(); int port = g_rpc_service->GetSelectedPort(); @@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) { g_rpc_service.reset(nullptr); g_req_handler.reset(nullptr); } + +TEST(COMPLETE, CPU) { + g_req_handler.reset(new distributed::RequestSendHandler(true)); + g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2)); + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(); + PADDLE_ENFORCE(client != nullptr); + std::thread server_thread(StartServer, distributed::kRequestSend); + g_rpc_service->WaitServerReady(); + int port = g_rpc_service->GetSelectedPort(); + std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); + client->AsyncSendComplete(ep); + client->Wait(); + + EXPECT_EQ(g_rpc_service->GetClientNum(), 1); + + g_rpc_service->ShutDown(); + server_thread.join(); + g_rpc_service.reset(nullptr); + g_req_handler.reset(nullptr); +} diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 216c4666c0..9669e4c083 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -503,8 +503,7 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Executor") .def(py::init()) #ifdef PADDLE_WITH_DISTRIBUTE - .def("begin_pass", &Executor::BeginPass) - .def("end_pass", &Executor::EndPass) + .def("complete", &Executor::Complete) #endif .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, int block_id, bool create_local_scope, bool create_vars) { diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index f9e600cb4c..d5f11619a3 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -348,11 +348,8 @@ class Executor(object): ] return outs - def begin_pass(self): - self.executor.begin_pass() - - def end_pass(self): - self.executor.end_pass() + def complete(self): + self.executor.complete() def run(self, program=None, -- GitLab