提交 7570d8e7 编写于 作者: Y Yancey1989

add rpc complete interface

上级 968922bd
...@@ -46,16 +46,10 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { ...@@ -46,16 +46,10 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
void Executor::BeginPass() { void Executor::Complete() {
::paddle::operators::distributed::RPCClient::GetInstance< ::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>() ::paddle::operators::distributed::GRPCClient>()
->SendBeginPass(); ->SendComplete();
}
void Executor::EndPass() {
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendEndPass();
} }
#endif #endif
......
...@@ -46,14 +46,10 @@ class Executor { ...@@ -46,14 +46,10 @@ class Executor {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
/* /*
* Sending signal to pserver to mark current pass started. * Sending signal to pserver to mark current trainer completed.
*/ */
void BeginPass(); void Complete();
/*
* Sending signal to pserver to mark current pass finished.
*/
void EndPass();
#endif #endif
/* @Brief /* @Brief
......
...@@ -6,7 +6,7 @@ if(WITH_GRPC) ...@@ -6,7 +6,7 @@ if(WITH_GRPC)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
proto_desc lookup_table_op SERIAL) proto_desc lookup_table_op SERIAL)
return() return()
......
...@@ -35,18 +35,10 @@ void GRPCClient::InitEventLoop() { ...@@ -35,18 +35,10 @@ void GRPCClient::InitEventLoop() {
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
} }
void GRPCClient::SendBeginPass() { void GRPCClient::SendComplete() {
for (auto& it : channels_) { for (auto& it : channels_) {
VLOG(3) << "send begin pass to: " << it.first; VLOG(3) << "send complete message to " << it.first;
this->AsyncSendBeginPass(it.first); this->AsyncSendComplete(it.first);
}
this->Wait();
}
void GRPCClient::SendEndPass() {
for (auto& it : channels_) {
VLOG(3) << "send end pass to " << it.first;
this->AsyncSendEndPass(it.first);
} }
this->Wait(); this->Wait();
} }
...@@ -238,32 +230,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, ...@@ -238,32 +230,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_++; req_count_++;
} }
void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) { void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); s->Prepare(time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BEGIN_PASS_MESSAGE); req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
} }
void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(END_PASS_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::AsyncCheckpointNotify(const std::string& ep, void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir, const std::string& dir,
int64_t time_out) { int64_t time_out) {
......
...@@ -215,17 +215,12 @@ class GRPCClient : public RPCClient { ...@@ -215,17 +215,12 @@ class GRPCClient : public RPCClient {
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBeginPass(const std::string& ep, void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendEndPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override; bool Wait() override;
void SendBeginPass() override; void SendComplete() override;
void SendEndPass() override;
protected: protected:
void InitImpl() override; void InitImpl() override;
......
...@@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; ...@@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV"
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
#define END_PASS_MESSAGE "END_PASS@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
...@@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message"; VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend); rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == BEGIN_PASS_MESSAGE) { } else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv begin pass message"; VLOG(3) << "sync: recv complete message";
rpc_server_->WaitCond(kRequestSend); rpc_server_->Complete();
rpc_server_->BeginPass();
} else { } else {
VLOG(3) << "sync: received var_name: " << varname; VLOG(3) << "sync: received var_name: " << varname;
rpc_server_->WaitCond(kRequestSend); rpc_server_->WaitCond(kRequestSend);
...@@ -95,14 +94,12 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -95,14 +94,12 @@ bool RequestGetHandler::Handle(const std::string& varname,
if (varname == FETCH_BARRIER_MESSAGE) { if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message"; VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet); rpc_server_->IncreaseBatchBarrier(kRequestGet);
} else if (varname == END_PASS_MESSAGE) {
rpc_server_->EndPass();
} else { } else {
rpc_server_->WaitCond(kRequestGet); rpc_server_->WaitCond(kRequestGet);
*outvar = scope_->FindVar(varname); *outvar = scope_->FindVar(varname);
} }
} else { } else {
if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) { if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
*outvar = scope_->FindVar(varname); *outvar = scope_->FindVar(varname);
} }
} }
......
...@@ -60,17 +60,13 @@ class RPCClient { ...@@ -60,17 +60,13 @@ class RPCClient {
const std::string& dir, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendBeginPass(const std::string& ep, virtual void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendEndPass(const std::string& ep, // Complete tells all the pserver instances that finishe the training,
int64_t time_out = FLAGS_rpc_deadline) = 0; // the pserver can reduce it's barrier count, and continue to train
// BeginePass/EndPass tells all the pserver that start/end a pass, so that
// the pserver can increase/reduce it's barrier count, and continue to train
// with other trainers. // with other trainers.
virtual void SendBeginPass() = 0; virtual void SendComplete() = 0;
virtual void SendEndPass() = 0;
virtual bool Wait() = 0; virtual bool Wait() = 0;
......
...@@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { ...@@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
} }
} }
void RPCServer::BeginPass() { void RPCServer::Complete() {
VLOG(4) << "RPCServer begin increase pass barrier";
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_++;
VLOG(4) << "increase client_num to: " << client_num_;
}
barrier_cond_.notify_all();
}
void RPCServer::EndPass() {
VLOG(4) << "RPCServer begin increase pass barrier";
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
client_num_--; client_num_--;
...@@ -83,10 +72,19 @@ void RPCServer::EndPass() { ...@@ -83,10 +72,19 @@ void RPCServer::EndPass() {
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) { if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--; barrier_counter_[kRequestGet]--;
} }
if (client_num_ == 0) {
exit_flag_ = true;
VLOG(1) << "No activate Trainer instance, PServer will exit...";
}
} }
barrier_cond_.notify_all(); barrier_cond_.notify_all();
} }
int RPCServer::GetClientNum() {
std::unique_lock<std::mutex> lock(mutex_);
return client_num_;
}
void RPCServer::ResetBarrierCounter() { void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter "; VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
......
...@@ -44,7 +44,7 @@ class RPCServer { ...@@ -44,7 +44,7 @@ class RPCServer {
int GetSelectedPort() const { return selected_port_; } int GetSelectedPort() const { return selected_port_; }
int GetClientNum() const; int GetClientNum();
void SavePort() const; void SavePort() const;
...@@ -64,8 +64,7 @@ class RPCServer { ...@@ -64,8 +64,7 @@ class RPCServer {
void WaitCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name); void IncreaseBatchBarrier(const std::string rpc_name);
void BeginPass(); void Complete();
void EndPass();
void ResetBarrierCounter(); void ResetBarrierCounter();
......
...@@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, ...@@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
} }
} }
void StartServer() { void StartServer(const std::string& rpc_name) {
framework::ProgramDesc program; framework::ProgramDesc program;
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
...@@ -107,14 +107,14 @@ void StartServer() { ...@@ -107,14 +107,14 @@ void StartServer() {
std::shared_ptr<framework::ExecutorPrepareContext>> std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared; prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0]; prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program); g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared); g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx); g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope); g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(distributed::kRequestPrefetch, g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
...@@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) { ...@@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) {
distributed::RPCClient* client = distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer); std::thread server_thread(StartServer, distributed::kRequestPrefetch);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
...@@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) { ...@@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) {
g_rpc_service.reset(nullptr); g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr); g_req_handler.reset(nullptr);
} }
TEST(COMPLETE, CPU) {
g_req_handler.reset(new distributed::RequestSendHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
PADDLE_ENFORCE(client != nullptr);
std::thread server_thread(StartServer, distributed::kRequestSend);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
client->AsyncSendComplete(ep);
client->Wait();
EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
g_rpc_service->ShutDown();
server_thread.join();
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
...@@ -503,8 +503,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -503,8 +503,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor") py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>()) .def(py::init<const platform::Place &>())
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
.def("begin_pass", &Executor::BeginPass) .def("complete", &Executor::Complete)
.def("end_pass", &Executor::EndPass)
#endif #endif
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars) { int block_id, bool create_local_scope, bool create_vars) {
......
...@@ -348,11 +348,8 @@ class Executor(object): ...@@ -348,11 +348,8 @@ class Executor(object):
] ]
return outs return outs
def begin_pass(self): def complete(self):
self.executor.begin_pass() self.executor.complete()
def end_pass(self):
self.executor.end_pass()
def run(self, def run(self,
program=None, program=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册