提交 7570d8e7 编写于 作者: Y Yancey1989

add rpc complete interface

上级 968922bd
......@@ -46,16 +46,10 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::BeginPass() {
void Executor::Complete() {
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendBeginPass();
}
void Executor::EndPass() {
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendEndPass();
->SendComplete();
}
#endif
......
......@@ -46,14 +46,10 @@ class Executor {
#ifdef PADDLE_WITH_DISTRIBUTE
/*
* Sending signal to pserver to mark current pass started.
* Sending signal to pserver to mark current trainer completed.
*/
void BeginPass();
void Complete();
/*
* Sending signal to pserver to mark current pass finished.
*/
void EndPass();
#endif
/* @Brief
......
......@@ -6,7 +6,7 @@ if(WITH_GRPC)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
proto_desc lookup_table_op SERIAL)
return()
......
......@@ -35,18 +35,10 @@ void GRPCClient::InitEventLoop() {
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
}
void GRPCClient::SendBeginPass() {
void GRPCClient::SendComplete() {
for (auto& it : channels_) {
VLOG(3) << "send begin pass to: " << it.first;
this->AsyncSendBeginPass(it.first);
}
this->Wait();
}
void GRPCClient::SendEndPass() {
for (auto& it : channels_) {
VLOG(3) << "send end pass to " << it.first;
this->AsyncSendEndPass(it.first);
VLOG(3) << "send complete message to " << it.first;
this->AsyncSendComplete(it.first);
}
this->Wait();
}
......@@ -238,32 +230,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_++;
}
void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(BEGIN_PASS_MESSAGE);
req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(END_PASS_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
......
......@@ -215,17 +215,12 @@ class GRPCClient : public RPCClient {
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBeginPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendEndPass(const std::string& ep,
void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
void SendBeginPass() override;
void SendEndPass() override;
void SendComplete() override;
protected:
void InitImpl() override;
......
......@@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
#define END_PASS_MESSAGE "END_PASS@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
......@@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == BEGIN_PASS_MESSAGE) {
VLOG(3) << "sync: recv begin pass message";
rpc_server_->WaitCond(kRequestSend);
rpc_server_->BeginPass();
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
rpc_server_->Complete();
} else {
VLOG(3) << "sync: received var_name: " << varname;
rpc_server_->WaitCond(kRequestSend);
......@@ -95,14 +94,12 @@ bool RequestGetHandler::Handle(const std::string& varname,
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
} else if (varname == END_PASS_MESSAGE) {
rpc_server_->EndPass();
} else {
rpc_server_->WaitCond(kRequestGet);
*outvar = scope_->FindVar(varname);
}
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
*outvar = scope_->FindVar(varname);
}
}
......
......@@ -60,17 +60,13 @@ class RPCClient {
const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendBeginPass(const std::string& ep,
virtual void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendEndPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
// BeginePass/EndPass tells all the pserver that start/end a pass, so that
// the pserver can increase/reduce it's barrier count, and continue to train
// Complete tells all the pserver instances that finishe the training,
// the pserver can reduce it's barrier count, and continue to train
// with other trainers.
virtual void SendBeginPass() = 0;
virtual void SendEndPass() = 0;
virtual void SendComplete() = 0;
virtual bool Wait() = 0;
......
......@@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
}
}
void RPCServer::BeginPass() {
VLOG(4) << "RPCServer begin increase pass barrier";
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_++;
VLOG(4) << "increase client_num to: " << client_num_;
}
barrier_cond_.notify_all();
}
void RPCServer::EndPass() {
VLOG(4) << "RPCServer begin increase pass barrier";
void RPCServer::Complete() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
......@@ -83,10 +72,19 @@ void RPCServer::EndPass() {
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--;
}
if (client_num_ == 0) {
exit_flag_ = true;
VLOG(1) << "No activate Trainer instance, PServer will exit...";
}
}
barrier_cond_.notify_all();
}
int RPCServer::GetClientNum() {
std::unique_lock<std::mutex> lock(mutex_);
return client_num_;
}
void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_);
......
......@@ -44,7 +44,7 @@ class RPCServer {
int GetSelectedPort() const { return selected_port_; }
int GetClientNum() const;
int GetClientNum();
void SavePort() const;
......@@ -64,8 +64,7 @@ class RPCServer {
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void BeginPass();
void EndPass();
void Complete();
void ResetBarrierCounter();
......
......@@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
void StartServer() {
void StartServer(const std::string& rpc_name) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
......@@ -107,14 +107,14 @@ void StartServer() {
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(distributed::kRequestPrefetch,
g_req_handler.get());
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
......@@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) {
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer);
std::thread server_thread(StartServer, distributed::kRequestPrefetch);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
......@@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) {
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(COMPLETE, CPU) {
g_req_handler.reset(new distributed::RequestSendHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
PADDLE_ENFORCE(client != nullptr);
std::thread server_thread(StartServer, distributed::kRequestSend);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
client->AsyncSendComplete(ep);
client->Wait();
EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
g_rpc_service->ShutDown();
server_thread.join();
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
......@@ -503,8 +503,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>())
#ifdef PADDLE_WITH_DISTRIBUTE
.def("begin_pass", &Executor::BeginPass)
.def("end_pass", &Executor::EndPass)
.def("complete", &Executor::Complete)
#endif
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars) {
......
......@@ -348,11 +348,8 @@ class Executor(object):
]
return outs
def begin_pass(self):
self.executor.begin_pass()
def end_pass(self):
self.executor.end_pass()
def complete(self):
self.executor.complete()
def run(self,
program=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册