From 1366832a41785ece0480dbf5d997b80f4080af7a Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Sun, 1 Jul 2018 18:04:48 +0800 Subject: [PATCH] add dist pass barrier --- paddle/fluid/framework/executor.cc | 18 +++++++--- paddle/fluid/framework/executor.h | 9 +++-- .../operators/distributed/grpc_client.cc | 29 ++++++++++++--- .../fluid/operators/distributed/grpc_client.h | 24 ++++++++----- .../operators/distributed/request_handler.h | 3 ++ .../distributed/request_handler_impl.cc | 36 +++++++++---------- .../fluid/operators/distributed/rpc_client.cc | 2 +- .../fluid/operators/distributed/rpc_client.h | 15 +++++--- .../fluid/operators/distributed/rpc_server.cc | 22 ++++++++++-- .../fluid/operators/distributed/rpc_server.h | 8 ++++- paddle/fluid/pybind/pybind.cc | 3 +- python/paddle/fluid/executor.py | 6 ++++ 12 files changed, 128 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index ae98fccc960..87b0ff0c802 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -48,10 +48,20 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { Executor::Executor(const platform::Place& place) : place_(place) {} #ifdef PADDLE_WITH_DISTRIBUTE -void Executor::Complete() { - ::paddle::operators::distributed::RPCClient::GetInstance< - ::paddle::operators::distributed::GRPCClient>() - ->SendComplete(); +void Executor::BeginPass() { + auto client = ::paddle::operators::distributed::RPCClient::GetInstance< + ::paddle::operators::distributed::GRPCClient>(); + + client->SendBeginPass(); + client->Wait(); +} + +void Executor::EndPass() { + auto client = ::paddle::operators::distributed::RPCClient::GetInstance< + ::paddle::operators::distributed::GRPCClient>(); + + client->SendEndPass(); + client->Wait(); } #endif diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 3aa5ffef69c..563a4b2bb65 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -46,9 +46,14 @@ class Executor { #ifdef PADDLE_WITH_DISTRIBUTE /* - * Sending signal to pserver to mark current trainer stop. + * Sending signal to pserver to mark current pass started. */ - void Complete(); + void BeginPass(); + + /* + * Sending signal to pserver to mark current pass finished. + */ + void EndPass(); #endif /* @Brief diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 8228a8c5a3e..d8dc667fe74 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -35,9 +35,17 @@ void GRPCClient::InitEventLoop() { client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); } -void GRPCClient::SendComplete() { +void GRPCClient::SendBeginPass() { for (auto& it : channels_) { - this->AsyncSendComplete(it.first); + VLOG(3) << "send begin pass to: " it.first; + this->AsyncSendBeginPass(it.first); + } +} + +void GRPCClient::SendEndPass() { + for (auto& it : channels_) { + VLOG(3) << "send end pass to " << it.first; + this->AsyncSendEndPass(it.first); } } @@ -226,19 +234,32 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, req_count_++; } -void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { +void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); s->Prepare(time_out); sendrecv::VariableMessage req; - req.set_varname(COMPLETE_MESSAGE); + req.set_varname(BEGIN_PASS_MESSAGE); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; } +void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) { + const auto ch = GetChannel(ep); + + FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); + s->Prepare(time_out); + + sendrecv::VariableMessage req; + req.set_varname(END_PASS_MESSAGE); + auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; +} + void GRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out) { diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index 7a08f2d3a4a..5dae20155ed 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -77,11 +77,12 @@ class BaseProcessor { context_.reset(new grpc::ClientContext()); var_h_ = var_info; context_->set_wait_for_ready(true); - - std::chrono::system_clock::time_point deadline = - std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); - - context_->set_deadline(deadline); + if (time_out) { + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + + std::chrono::milliseconds(time_out); + context_->set_deadline(deadline); + } } virtual void Prepare(int64_t time_out) { @@ -214,9 +215,17 @@ class GRPCClient : public RPCClient { void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) override; + void AsyncSendBeginPass(const std::string& ep, + int64_t time_out = FLAGS_rpc_deadline) override; + + void AsyncSendEndPass(const std::string& ep, + int64_t time_out = FLAGS_rpc_deadline) override; + void Wait() override; - void SendComplete() override; + void SendBeginPass() override; + + void SendEndPass() override; protected: void InitImpl() override; @@ -227,9 +236,6 @@ class GRPCClient : public RPCClient { void Proceed(); - void AsyncSendComplete(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline); - std::shared_ptr GetChannel(const std::string& ep); private: diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 90742a201ad..271306d5d20 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -37,11 +37,14 @@ constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestPrefetch[] = "RequestPrefetch"; constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; +constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" +#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV" +#define END_PASS_MESSAGE "END_PASS@RECV" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 163154c678f..5e6bff20f5f 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -55,14 +55,14 @@ bool RequestSendHandler::Handle(const std::string& varname, if (varname == BATCH_BARRIER_MESSAGE) { VLOG(3) << "sync: recv batch barrier message"; rpc_server_->IncreaseBatchBarrier(kRequestSend); - } else if (varname == COMPLETE_MESSAGE) { - VLOG(3) << "sync: recv complete message"; - rpc_server_->DecreaseClientNum(); + } else if (varname == BEGIN_PASS_MESSAGE) { + VLOG(3) << "sync: recv begin pass message"; + rpc_server_->WaitCond(kRequestSend); + rpc_server_->BeginPass(); } else { VLOG(3) << "sync: received var_name: " << varname; - if (sync_mode_) { - rpc_server_->WaitCond(kRequestSend); - } + rpc_server_->WaitCond(kRequestSend); + VLOG(3) << "sync: processing received var: " << varname; if (invar == nullptr) { LOG(ERROR) << "sync: Can not find server side var: " << varname; @@ -91,21 +91,21 @@ bool RequestGetHandler::Handle(const std::string& varname, framework::Variable** outvar, const std::string& out_var_name) { VLOG(4) << "RequestGetHandler:" << varname; - - if (varname != FETCH_BARRIER_MESSAGE) { - if (sync_mode_) { + if (sync_mode_) { + if (varname == FETCH_BARRIER_MESSAGE) { + VLOG(3) << "sync: recv fetch barrier message"; + rpc_server_->IncreaseBatchBarrier(kRequestGet); + } else if (varname == END_PASS_MESSAGE) { + rpc_server_->EndPass(); + } else { rpc_server_->WaitCond(kRequestGet); + *outvar = scope_->FindVar(varname); + } + } else { + if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) { + *outvar = scope_->FindVar(varname); } - *outvar = scope_->FindVar(varname); - return true; - } - - // FETCH_BARRIER_MESSAGE - if (sync_mode_) { - VLOG(3) << "sync: recv fetch barrier message"; - rpc_server_->IncreaseBatchBarrier(kRequestGet); } - return true; } diff --git a/paddle/fluid/operators/distributed/rpc_client.cc b/paddle/fluid/operators/distributed/rpc_client.cc index b5ec9fe5367..382b65d637c 100644 --- a/paddle/fluid/operators/distributed/rpc_client.cc +++ b/paddle/fluid/operators/distributed/rpc_client.cc @@ -16,7 +16,7 @@ #include "gflags/gflags.h" // default to 3min to avoid temprary network failures. -DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc"); +DEFINE_int32(rpc_deadline, 30000, "deadline timeouts for rpc"); namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 37783b78ecc..6479d3a97ba 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -60,10 +60,17 @@ class RPCClient { const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) = 0; - // SendComplete tells all the server that current trainer have no more data - // to train, so that the pserver can reduce it's barrier count, and continue - // to train with other trainers. - virtual void SendComplete() = 0; + virtual void AsyncSendBeginPass(const std::string& ep, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual void AsyncSendEndPass(const std::string& ep, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + // BeginePass/EndPass tells all the pserver that start/end a pass, so that + // the pserver can increase/reduce it's barrier count, and continue to train + // with other trainers. + virtual void SendBeginPass() = 0; + virtual void SendEndPass() = 0; virtual void Wait() = 0; diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc index c0520e248d4..5f4c1348375 100644 --- a/paddle/fluid/operators/distributed/rpc_server.cc +++ b/paddle/fluid/operators/distributed/rpc_server.cc @@ -44,7 +44,8 @@ void RPCServer::SavePort() const { void RPCServer::WaitBarrier(const std::string& rpc_name) { std::unique_lock lock(this->mutex_); barrier_cond_.wait(lock, [this, &rpc_name] { - return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load()); + return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) || + exit_flag_.load()); }); VLOG(3) << "batch_barrier_: " << rpc_name << " " @@ -63,10 +64,25 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { } } -void RPCServer::DecreaseClientNum() { +void RPCServer::BeginPass() { + VLOG(4) << "RPCServer begin increase pass barrier"; { - std::unique_lock lock(mutex_); + std::unique_lock locl(mutex_); + client_num_++; + VLOG(4) << "increase client_num to: " << client_num_; + } + barrier_cond_.notify_all(); +} + +void RPCServer::EndPass() { + VLOG(4) << "RPCServer begin increase pass barrier"; + { + std::unique_lock locl(mutex_); client_num_--; + VLOG(4) << "decrease client_num to: " << client_num_; + if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) { + barrier_counter_[kRequestGet]--; + } } barrier_cond_.notify_all(); } diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h index cf25e78435b..833991c8aa6 100644 --- a/paddle/fluid/operators/distributed/rpc_server.h +++ b/paddle/fluid/operators/distributed/rpc_server.h @@ -43,6 +43,9 @@ class RPCServer { bool IsExit() { return exit_flag_.load(); } int GetSelectedPort() const { return selected_port_; } + + int GetClientNum() const; + void SavePort() const; // RegisterRPC, register the rpc method name to a handler @@ -60,7 +63,10 @@ class RPCServer { void SetCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name); void IncreaseBatchBarrier(const std::string rpc_name); - void DecreaseClientNum(); + + void BeginPass(); + void EndPass(); + void ResetBarrierCounter(); protected: diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 36d08099683..3f1e2ceedbf 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -493,7 +493,8 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Executor") .def(py::init()) #ifdef PADDLE_WITH_DISTRIBUTE - .def("complete", &Executor::Complete) + .def("begin_pass", &Executor::BeginPass) + .def("end_pass", &Executor::EndPass) #endif .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, int block_id, bool create_local_scope, bool create_vars) { diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 145f1423e4b..b436dfe70af 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -348,6 +348,12 @@ class Executor(object): ] return outs + def begin_pass(self): + self.executor.begin_pass() + + def end_pass(self): + self.executor.end_pass() + def run(self, program=None, feed=None, -- GitLab