Commit 1366832a authored by Yancey1989

add dist pass barrier

Parent 5988d0c0
...
@@ -48,10 +48,20 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
 Executor::Executor(const platform::Place& place) : place_(place) {}

 #ifdef PADDLE_WITH_DISTRIBUTE
-void Executor::Complete() {
-  ::paddle::operators::distributed::RPCClient::GetInstance<
-      ::paddle::operators::distributed::GRPCClient>()
-      ->SendComplete();
+void Executor::BeginPass() {
+  auto client = ::paddle::operators::distributed::RPCClient::GetInstance<
+      ::paddle::operators::distributed::GRPCClient>();
+
+  client->SendBeginPass();
+  client->Wait();
+}
+
+void Executor::EndPass() {
+  auto client = ::paddle::operators::distributed::RPCClient::GetInstance<
+      ::paddle::operators::distributed::GRPCClient>();
+
+  client->SendEndPass();
+  client->Wait();
 }
 #endif
...
...
@@ -46,9 +46,14 @@ class Executor {
 #ifdef PADDLE_WITH_DISTRIBUTE
   /*
-   * Sending signal to pserver to mark current trainer stop.
+   * Sending signal to pserver to mark current pass started.
    */
-  void Complete();
+  void BeginPass();
+
+  /*
+   * Sending signal to pserver to mark current pass finished.
+   */
+  void EndPass();
 #endif

   /* @Brief
...
...
@@ -35,9 +35,17 @@ void GRPCClient::InitEventLoop() {
   client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
 }

-void GRPCClient::SendComplete() {
+void GRPCClient::SendBeginPass() {
   for (auto& it : channels_) {
-    this->AsyncSendComplete(it.first);
+    VLOG(3) << "send begin pass to: " << it.first;
+    this->AsyncSendBeginPass(it.first);
+  }
+}
+
+void GRPCClient::SendEndPass() {
+  for (auto& it : channels_) {
+    VLOG(3) << "send end pass to " << it.first;
+    this->AsyncSendEndPass(it.first);
   }
 }
...
...
@@ -226,19 +234,32 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
   req_count_++;
 }

-void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
+void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
   const auto ch = GetChannel(ep);

   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
   s->Prepare(time_out);

   sendrecv::VariableMessage req;
-  req.set_varname(COMPLETE_MESSAGE);
+  req.set_varname(BEGIN_PASS_MESSAGE);
   auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
   rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
   req_count_++;
 }

+void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
+  const auto ch = GetChannel(ep);
+
+  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
+  s->Prepare(time_out);
+
+  sendrecv::VariableMessage req;
+  req.set_varname(END_PASS_MESSAGE);
+  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
+  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
+  req_count_++;
+}
+
 void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                        const std::string& dir,
                                        int64_t time_out) {
...
...
@@ -77,11 +77,12 @@ class BaseProcessor {
     context_.reset(new grpc::ClientContext());
     var_h_ = var_info;
     context_->set_wait_for_ready(true);
-    if (time_out) {
-      std::chrono::system_clock::time_point deadline =
-          std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
-      context_->set_deadline(deadline);
-    }
+
+    std::chrono::system_clock::time_point deadline =
+        std::chrono::system_clock::now() +
+        std::chrono::milliseconds(time_out);
+
+    context_->set_deadline(deadline);
   }

   virtual void Prepare(int64_t time_out) {
...
@@ -214,9 +215,17 @@ class GRPCClient : public RPCClient {
   void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
                              int64_t time_out = FLAGS_rpc_deadline) override;

+  void AsyncSendBeginPass(const std::string& ep,
+                          int64_t time_out = FLAGS_rpc_deadline) override;
+
+  void AsyncSendEndPass(const std::string& ep,
+                        int64_t time_out = FLAGS_rpc_deadline) override;
+
   void Wait() override;

-  void SendComplete() override;
+  void SendBeginPass() override;
+  void SendEndPass() override;

  protected:
   void InitImpl() override;
@@ -227,9 +236,6 @@ class GRPCClient : public RPCClient {
   void Proceed();

-  void AsyncSendComplete(const std::string& ep,
-                         int64_t time_out = FLAGS_rpc_deadline);
-
   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);

  private:
...
...
@@ -37,11 +37,14 @@ constexpr char kRequestSend[] = "RequestSend";
 constexpr char kRequestGet[] = "RequestGet";
 constexpr char kRequestPrefetch[] = "RequestPrefetch";
 constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
+constexpr char kRequestPassBarrier[] = "RequestPassBarrier";

 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
 #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
 #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
 #define COMPLETE_MESSAGE "COMPLETE@RECV"
+#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
+#define END_PASS_MESSAGE "END_PASS@RECV"

 #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
 #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
...
...
@@ -55,14 +55,14 @@ bool RequestSendHandler::Handle(const std::string& varname,
   if (varname == BATCH_BARRIER_MESSAGE) {
     VLOG(3) << "sync: recv batch barrier message";
     rpc_server_->IncreaseBatchBarrier(kRequestSend);
-  } else if (varname == COMPLETE_MESSAGE) {
-    VLOG(3) << "sync: recv complete message";
-    rpc_server_->DecreaseClientNum();
+  } else if (varname == BEGIN_PASS_MESSAGE) {
+    VLOG(3) << "sync: recv begin pass message";
+    rpc_server_->WaitCond(kRequestSend);
+    rpc_server_->BeginPass();
   } else {
     VLOG(3) << "sync: received var_name: " << varname;
-    if (sync_mode_) {
-      rpc_server_->WaitCond(kRequestSend);
-    }
+    rpc_server_->WaitCond(kRequestSend);
+    VLOG(3) << "sync: processing received var: " << varname;

     if (invar == nullptr) {
       LOG(ERROR) << "sync: Can not find server side var: " << varname;
@@ -91,21 +91,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
                                framework::Variable** outvar,
                                const std::string& out_var_name) {
   VLOG(4) << "RequestGetHandler:" << varname;
-  if (varname != FETCH_BARRIER_MESSAGE) {
-    if (sync_mode_) {
+  if (sync_mode_) {
+    if (varname == FETCH_BARRIER_MESSAGE) {
+      VLOG(3) << "sync: recv fetch barrier message";
+      rpc_server_->IncreaseBatchBarrier(kRequestGet);
+    } else if (varname == END_PASS_MESSAGE) {
+      rpc_server_->EndPass();
+    } else {
       rpc_server_->WaitCond(kRequestGet);
+      *outvar = scope_->FindVar(varname);
+    }
+  } else {
+    if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
+      *outvar = scope_->FindVar(varname);
     }
-    *outvar = scope_->FindVar(varname);
-    return true;
-  }
-
-  // FETCH_BARRIER_MESSAGE
-  if (sync_mode_) {
-    VLOG(3) << "sync: recv fetch barrier message";
-    rpc_server_->IncreaseBatchBarrier(kRequestGet);
   }
   return true;
 }
...
...
@@ -16,7 +16,7 @@
 #include "gflags/gflags.h"

 // default to 3min to avoid temporary network failures.
-DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
+DEFINE_int32(rpc_deadline, 30000, "deadline timeouts for rpc");

 namespace paddle {
 namespace operators {
...
...
@@ -60,10 +60,17 @@ class RPCClient {
                                      const std::string& dir,
                                      int64_t time_out = FLAGS_rpc_deadline) = 0;

-  // SendComplete tells all the server that current trainer have no more data
-  // to train, so that the pserver can reduce it's barrier count, and continue
-  // to train with other trainers.
-  virtual void SendComplete() = 0;
+  virtual void AsyncSendBeginPass(const std::string& ep,
+                                  int64_t time_out = FLAGS_rpc_deadline) = 0;
+
+  virtual void AsyncSendEndPass(const std::string& ep,
+                                int64_t time_out = FLAGS_rpc_deadline) = 0;
+
+  // BeginPass/EndPass tells all the pserver that start/end a pass, so that
+  // the pserver can increase/reduce it's barrier count, and continue to train
+  // with other trainers.
+  virtual void SendBeginPass() = 0;
+  virtual void SendEndPass() = 0;

   virtual void Wait() = 0;
...
...
@@ -44,7 +44,8 @@ void RPCServer::SavePort() const {
 void RPCServer::WaitBarrier(const std::string& rpc_name) {
   std::unique_lock<std::mutex> lock(this->mutex_);
   barrier_cond_.wait(lock, [this, &rpc_name] {
-    return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
+    return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
+            exit_flag_.load());
   });

   VLOG(3) << "batch_barrier_: " << rpc_name << " "
@@ -63,10 +64,25 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
   }
 }

-void RPCServer::DecreaseClientNum() {
+void RPCServer::BeginPass() {
+  VLOG(4) << "RPCServer begin increase pass barrier";
   {
     std::unique_lock<std::mutex> lock(mutex_);
+    client_num_++;
+    VLOG(4) << "increase client_num to: " << client_num_;
+  }
+  barrier_cond_.notify_all();
+}
+
+void RPCServer::EndPass() {
+  VLOG(4) << "RPCServer begin decrease pass barrier";
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
     client_num_--;
+    VLOG(4) << "decrease client_num to: " << client_num_;
+    if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
+      barrier_counter_[kRequestGet]--;
+    }
   }
   barrier_cond_.notify_all();
 }
...
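For intuition, the barrier bookkeeping above can be modeled roughly as follows: BeginPass grows client_num_, EndPass shrinks it, and WaitBarrier now requires the barrier counter to exactly match a non-zero client count, presumably so that a shrinking client_num_ (trainers leaving a pass) or a zero client count cannot release the barrier prematurely. The snippet below is a toy Python model for illustration only; it is not PaddlePaddle code and all names are invented.

# Toy model of the server-side pass/batch barrier bookkeeping
# (illustration only; not PaddlePaddle code, names invented).
class ToyBarrier(object):
    def __init__(self):
        self.client_num = 0       # trainers currently inside a pass
        self.barrier_counter = 0  # barrier messages received so far

    def begin_pass(self):         # mirrors RPCServer::BeginPass
        self.client_num += 1

    def end_pass(self):           # mirrors RPCServer::EndPass
        self.client_num -= 1

    def barrier_reached(self):    # mirrors the new WaitBarrier predicate
        return self.barrier_counter == self.client_num and self.client_num != 0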
...
@@ -43,6 +43,9 @@ class RPCServer {
   bool IsExit() { return exit_flag_.load(); }
   int GetSelectedPort() const { return selected_port_; }
+
+  int GetClientNum() const;
+
   void SavePort() const;

   // RegisterRPC, register the rpc method name to a handler
@@ -60,7 +63,10 @@ class RPCServer {
   void SetCond(const std::string& rpc_name);
   void WaitCond(const std::string& rpc_name);
   void IncreaseBatchBarrier(const std::string rpc_name);
-  void DecreaseClientNum();
+
+  void BeginPass();
+  void EndPass();
+
   void ResetBarrierCounter();

  protected:
...
...
@@ -493,7 +493,8 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<framework::Executor>(m, "Executor")
       .def(py::init<const platform::Place &>())
 #ifdef PADDLE_WITH_DISTRIBUTE
-      .def("complete", &Executor::Complete)
+      .def("begin_pass", &Executor::BeginPass)
+      .def("end_pass", &Executor::EndPass)
 #endif
       .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
                      int block_id, bool create_local_scope, bool create_vars) {
...
...
@@ -348,6 +348,12 @@ class Executor(object):
         ]
         return outs

+    def begin_pass(self):
+        self.executor.begin_pass()
+
+    def end_pass(self):
+        self.executor.end_pass()
+
     def run(self,
             program=None,
             feed=None,
...
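For reference, a rough sketch of how a trainer-side script might drive the new pass barrier from Python after this change. It assumes a distributed build (PADDLE_WITH_DISTRIBUTE) and the usual distribute-transpiler setup; trainer_prog, train_reader, feeder and num_passes are placeholder names, not part of this commit.

import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# trainer_prog, train_reader, feeder and num_passes are assumed to come from
# the usual distribute-transpiler setup; they are placeholders here.
for pass_id in range(num_passes):
    exe.begin_pass()   # each pserver bumps its client_num_ for this trainer
    for data in train_reader():
        exe.run(trainer_prog, feed=feeder.feed(data))
    exe.end_pass()     # each pserver decreases its client_num_ again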