Unverified commit 6133efd9, authored by Yancey, committed by GitHub

Merge pull request #12218 from Yancey1989/rpc_complete_interface

Add rpc complete interface
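In short, this change replaces the paired `Executor.begin_pass()` / `Executor.end_pass()` calls with a single `Executor.close()` on the Python side, backed by a new `SendComplete` RPC: when a trainer finishes, it sends one COMPLETE message to every pserver, and each pserver decrements its client barrier count so the remaining trainers can keep training. A minimal sketch of the new trainer-side usage, based on the Python API changes in this diff (the program setup and training loop are placeholders; `fluid.CPUPlace` and `fluid.default_startup_program` are just the usual fluid entry points):

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# ... build a program and run the training loop with exe.run(...) ...

# New in this PR: close() tells every pserver (via SendComplete) that this
# trainer is done, then marks the executor as closed. Calling it twice is safe.
exe.close()

# After close(), run() refuses to execute:
#   RuntimeError: Attempted to use a closed Executor
```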
......@@ -35,8 +35,7 @@ paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', def
paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.as_lodtensor ArgSpec(args=['self', 'data'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.begin_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.end_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
......
......@@ -45,19 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::BeginPass() {
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendBeginPass();
}
void Executor::EndPass() {
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendEndPass();
}
->SendComplete();
#endif
}
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
......
......@@ -44,17 +44,11 @@ class Executor {
explicit Executor(const platform::Place& place);
#ifdef PADDLE_WITH_DISTRIBUTE
/*
* Sending signal to pserver to mark current pass started.
* Close this Executor.
* Calling this method will send complete messages to all pserver instances.
*/
void BeginPass();
/*
* Sending signal to pserver to mark current pass finished.
*/
void EndPass();
#endif
void Close();
/* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope
......
......@@ -18,7 +18,7 @@ if(WITH_GRPC)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(grpc_server_test SRCS rpc_server_test.cc
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
return()
endif()
......
......@@ -36,20 +36,16 @@ void GRPCClient::InitEventLoop() {
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
}
void GRPCClient::SendBeginPass() {
for (auto& it : channels_) {
VLOG(3) << "send begin pass to: " << it.first;
this->AsyncSendBeginPass(it.first);
}
this->Wait();
}
void GRPCClient::SendEndPass() {
for (auto& it : channels_) {
VLOG(3) << "send end pass to " << it.first;
this->AsyncSendEndPass(it.first);
void GRPCClient::SendComplete() {
std::unique_lock<std::mutex> lk(completed_mutex_);
if (!completed_) {
for (auto& it : channels_) {
VLOG(3) << "send complete message to " << it.first;
this->AsyncSendComplete(it.first);
}
PADDLE_ENFORCE(this->Wait(), "internal grpc error");
completed_ = true;
}
this->Wait();
}
GRPCClient::~GRPCClient() {
......@@ -239,32 +235,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_++;
}
void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(BEGIN_PASS_MESSAGE);
req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(END_PASS_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
......
......@@ -174,7 +174,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
class GRPCClient : public RPCClient {
public:
GRPCClient() : ok_(true) {}
GRPCClient() : ok_(true), completed_(false) {}
virtual ~GRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
......@@ -201,17 +201,12 @@ class GRPCClient : public RPCClient {
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBeginPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendEndPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
void SendBeginPass() override;
void SendEndPass() override;
void SendComplete() override;
protected:
void InitImpl() override;
......@@ -238,6 +233,10 @@ class GRPCClient : public RPCClient {
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
DISABLE_COPY_AND_ASSIGN(GRPCClient);
// mutex for sending complete message only once
std::mutex completed_mutex_;
bool completed_;
};
} // namespace distributed
......
......@@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
#define END_PASS_MESSAGE "END_PASS@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
......@@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == BEGIN_PASS_MESSAGE) {
VLOG(3) << "sync: recv begin pass message";
rpc_server_->WaitCond(kRequestSend);
rpc_server_->BeginPass();
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
rpc_server_->Complete();
} else {
VLOG(3) << "sync: received var_name: " << varname;
rpc_server_->WaitCond(kRequestSend);
......@@ -94,14 +93,12 @@ bool RequestGetHandler::Handle(const std::string& varname,
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
} else if (varname == END_PASS_MESSAGE) {
rpc_server_->EndPass();
} else {
rpc_server_->WaitCond(kRequestGet);
*outvar = scope_->FindVar(varname);
}
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
*outvar = scope_->FindVar(varname);
}
}
......
......@@ -60,17 +60,13 @@ class RPCClient {
const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendBeginPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendEndPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
// BeginPass/EndPass tell all the pservers that a pass has started/ended, so that
// each pserver can increase/reduce its barrier count, and continue to train
// Complete tells all the pserver instances that this trainer has finished training,
// so each pserver can reduce its barrier count, and continue to train
// with other trainers.
virtual void SendBeginPass() = 0;
virtual void SendEndPass() = 0;
virtual void SendComplete() = 0;
virtual bool Wait() = 0;
......
......@@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
}
}
void RPCServer::BeginPass() {
VLOG(4) << "RPCServer begin increase pass barrier";
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_++;
VLOG(4) << "increase client_num to: " << client_num_;
}
barrier_cond_.notify_all();
}
void RPCServer::EndPass() {
VLOG(4) << "RPCServer begin increase pass barrier";
void RPCServer::Complete() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
......@@ -87,6 +76,11 @@ void RPCServer::EndPass() {
barrier_cond_.notify_all();
}
int RPCServer::GetClientNum() {
std::unique_lock<std::mutex> lock(mutex_);
return client_num_;
}
void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_);
......
......@@ -44,7 +44,7 @@ class RPCServer {
int GetSelectedPort() const { return selected_port_; }
int GetClientNum() const;
int GetClientNum();
void SavePort() const;
......@@ -64,8 +64,7 @@ class RPCServer {
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void BeginPass();
void EndPass();
void Complete();
void ResetBarrierCounter();
......
......@@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
void StartServer() {
void StartServer(const std::string& rpc_name) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
......@@ -107,14 +107,14 @@ void StartServer() {
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(distributed::kRequestPrefetch,
g_req_handler.get());
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
......@@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) {
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer);
std::thread server_thread(StartServer, distributed::kRequestPrefetch);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
......@@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) {
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(COMPLETE, CPU) {
g_req_handler.reset(new distributed::RequestSendHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
PADDLE_ENFORCE(client != nullptr);
std::thread server_thread(StartServer, distributed::kRequestSend);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
client->AsyncSendComplete(ep);
client->Wait();
EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
g_rpc_service->ShutDown();
server_thread.join();
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
......@@ -498,10 +498,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>())
#ifdef PADDLE_WITH_DISTRIBUTE
.def("begin_pass", &Executor::BeginPass)
.def("end_pass", &Executor::EndPass)
#endif
.def("close", &Executor::Close)
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars) {
pybind11::gil_scoped_release release;
......
......@@ -247,6 +247,7 @@ class Executor(object):
p.set_place(place)
self.executor = core.Executor(p)
self.program_caches = dict()
self._closed = False
def as_lodtensor(self, data):
"""
......@@ -348,11 +349,23 @@ class Executor(object):
]
return outs
def begin_pass(self):
self.executor.begin_pass()
def close(self):
"""
Close this executor.
def end_pass(self):
self.executor.end_pass()
You can no longer use this executor after calling this method.
In distributed training, this method frees the resources on the PServers that are
related to the current Trainer.
Example:
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> ...
>>> exe.close()
"""
if not self._closed:
self.executor.close()
self._closed = True
def run(self,
program=None,
......@@ -405,6 +418,10 @@ class Executor(object):
>>> feed={'X': x},
>>> fetch_list=[loss.name])
"""
if self._closed:
raise RuntimeError("Attempted to use a closed Executor")
if feed is None:
feed = {}
if not isinstance(feed, dict):
......
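For reference, the behavior added on the Python side boils down to the small guard pattern below. This is a simplified, hypothetical stand-in (the class name `_ClosableExecutor` and the `core_executor` parameter are illustrative only), not the actual `fluid.Executor` implementation:

```python
class _ClosableExecutor(object):
    """Toy illustration of the _closed guard this PR adds to fluid.Executor."""

    def __init__(self, core_executor):
        self._core = core_executor  # stands in for core.Executor(place)
        self._closed = False

    def close(self):
        # Idempotent: only the first call reaches the C++ Executor::Close(),
        # which sends the COMPLETE message to every pserver in distributed mode.
        if not self._closed:
            self._core.close()
            self._closed = True

    def run(self, *args, **kwargs):
        if self._closed:
            raise RuntimeError("Attempted to use a closed Executor")
        return self._core.run(*args, **kwargs)
```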