diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 4bdfad93187b0d7e18c4408a578cf5d7ad872b28..6efb03dabe89b28f3ff1a55c4a940dfe74e8001d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -35,8 +35,7 @@ paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', def paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None) paddle.fluid.Executor.as_lodtensor ArgSpec(args=['self', 'data'], varargs=None, keywords=None, defaults=None) -paddle.fluid.Executor.begin_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) -paddle.fluid.Executor.end_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) +paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)) paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None) paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 84f67fafa19ac545ebb7a1019059e3c74c363c56..c2800c972a5501859672fbfd6921499e84d09cb0 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -45,19 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { Executor::Executor(const platform::Place& place) : place_(place) {} +void Executor::Close() { #ifdef PADDLE_WITH_DISTRIBUTE -void Executor::BeginPass() { ::paddle::operators::distributed::RPCClient::GetInstance< ::paddle::operators::distributed::GRPCClient>() - ->SendBeginPass(); -} - -void Executor::EndPass() { - ::paddle::operators::distributed::RPCClient::GetInstance< - ::paddle::operators::distributed::GRPCClient>() - ->SendEndPass(); -} + ->SendComplete(); #endif +} void InitializeVariable(Variable* var, proto::VarType::Type var_type) { if (var_type == proto::VarType::LOD_TENSOR) { diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 563a4b2bb65dad481a755f67c7f23939816ce8e8..214ca3dc492c31d4c683790a6ae051be467401c9 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -44,17 +44,11 @@ class Executor { explicit Executor(const platform::Place& place); -#ifdef PADDLE_WITH_DISTRIBUTE /* - * Sending signal to pserver to mark current pass started. + * Close this Executor. + * Calling this method will send complete messages to all pserver instances. */ - void BeginPass(); - - /* - * Sending signal to pserver to mark current pass finished. - */ - void EndPass(); -#endif + void Close(); /* @Brief * Runtime evaluation of the given ProgramDesc under certain Scope diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 6555b8101a90bba8351d2c82313ab12e572a01ee..1612927055dd4ec5ee2220bc2b285e8d9b640ea8 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -18,7 +18,7 @@ if(WITH_GRPC) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(grpc_serde_test SRCS grpc_serde_test.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) - cc_test(grpc_server_test SRCS rpc_server_test.cc + cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL) return() endif() diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 52c4bc1e7965323438de959d5eb1f3b4ef4f4cfe..265f964ddc682868c64669744b130aebbbf86692 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -36,20 +36,16 @@ void GRPCClient::InitEventLoop() { client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); } -void GRPCClient::SendBeginPass() { - for (auto& it : channels_) { - VLOG(3) << "send begin pass to: " << it.first; - this->AsyncSendBeginPass(it.first); - } - this->Wait(); -} - -void GRPCClient::SendEndPass() { - for (auto& it : channels_) { - VLOG(3) << "send end pass to " << it.first; - this->AsyncSendEndPass(it.first); +void GRPCClient::SendComplete() { + std::unique_lock lk(completed_mutex_); + if (!completed_) { + for (auto& it : channels_) { + VLOG(3) << "send complete message to " << it.first; + this->AsyncSendComplete(it.first); + } + PADDLE_ENFORCE(this->Wait(), "internal grpc error"); + completed_ = true; } - this->Wait(); } GRPCClient::~GRPCClient() { @@ -239,32 +235,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, req_count_++; } -void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) { +void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); s->Prepare(time_out); sendrecv::VariableMessage req; - req.set_varname(BEGIN_PASS_MESSAGE); + req.set_varname(COMPLETE_MESSAGE); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; } -void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep); - - FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); - s->Prepare(time_out); - - sendrecv::VariableMessage req; - req.set_varname(END_PASS_MESSAGE); - auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; -} - void GRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out) { diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index 11de84d9e265b2ca75d6d72a1d1e8797763f96a5..8351d825f817437e1b3691e916952dd9a86af491 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -174,7 +174,7 @@ class CheckpointNotifyProcessor : public BaseProcessor { class GRPCClient : public RPCClient { public: - GRPCClient() : ok_(true) {} + GRPCClient() : ok_(true), completed_(false) {} virtual ~GRPCClient(); bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, @@ -201,17 +201,12 @@ class GRPCClient : public RPCClient { void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) override; - void AsyncSendBeginPass(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncSendEndPass(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; + void AsyncSendComplete(const std::string& ep, + int64_t time_out = FLAGS_rpc_deadline) override; bool Wait() override; - void SendBeginPass() override; - - void SendEndPass() override; + void SendComplete() override; protected: void InitImpl() override; @@ -238,6 +233,10 @@ class GRPCClient : public RPCClient { // mutex for GetChannel thread safety std::mutex chan_mutex_; DISABLE_COPY_AND_ASSIGN(GRPCClient); + + // mutex for sending complete message only once + std::mutex completed_mutex_; + bool completed_; }; } // namespace distributed diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 3d61171dff98d6752be98b4b90577bfd059525ab..64ac7281848f91302bc0aa3cb81dd198e56fb653 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" -#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV" -#define END_PASS_MESSAGE "END_PASS@RECV" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index f1f84072d47e58eaa81dd66dc018e17b182bb57b..55995783c6eab10632ab2a5bca64ca856f000df1 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname, if (varname == BATCH_BARRIER_MESSAGE) { VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; rpc_server_->IncreaseBatchBarrier(kRequestSend); - } else if (varname == BEGIN_PASS_MESSAGE) { - VLOG(3) << "sync: recv begin pass message"; - rpc_server_->WaitCond(kRequestSend); - rpc_server_->BeginPass(); + } else if (varname == COMPLETE_MESSAGE) { + VLOG(3) << "sync: recv complete message"; + rpc_server_->Complete(); } else { VLOG(3) << "sync: received var_name: " << varname; rpc_server_->WaitCond(kRequestSend); @@ -94,14 +93,12 @@ bool RequestGetHandler::Handle(const std::string& varname, if (varname == FETCH_BARRIER_MESSAGE) { VLOG(3) << "sync: recv fetch barrier message"; rpc_server_->IncreaseBatchBarrier(kRequestGet); - } else if (varname == END_PASS_MESSAGE) { - rpc_server_->EndPass(); } else { rpc_server_->WaitCond(kRequestGet); *outvar = scope_->FindVar(varname); } } else { - if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) { + if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) { *outvar = scope_->FindVar(varname); } } diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 4d87376fbf776e29156b78d826f5012bc53460df..22a022a5d25e5c6628b80294494b87ca105a04c7 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -60,17 +60,13 @@ class RPCClient { const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) = 0; - virtual void AsyncSendBeginPass(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual void AsyncSendComplete(const std::string& ep, + int64_t time_out = FLAGS_rpc_deadline) = 0; - virtual void AsyncSendEndPass(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - // BeginePass/EndPass tells all the pserver that start/end a pass, so that - // the pserver can increase/reduce it's barrier count, and continue to train + // Complete tells all the pserver instances that finishe the training, + // the pserver can reduce it's barrier count, and continue to train // with other trainers. - virtual void SendBeginPass() = 0; - virtual void SendEndPass() = 0; + virtual void SendComplete() = 0; virtual bool Wait() = 0; diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc index d49ee34eeaf4e80f6fd4f8cdc548cc2b938d0f2a..83b14fa64d735d80f43bf55c798cddb2f3ea7032 100644 --- a/paddle/fluid/operators/distributed/rpc_server.cc +++ b/paddle/fluid/operators/distributed/rpc_server.cc @@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { } } -void RPCServer::BeginPass() { - VLOG(4) << "RPCServer begin increase pass barrier"; - { - std::unique_lock lock(mutex_); - client_num_++; - VLOG(4) << "increase client_num to: " << client_num_; - } - barrier_cond_.notify_all(); -} - -void RPCServer::EndPass() { - VLOG(4) << "RPCServer begin increase pass barrier"; +void RPCServer::Complete() { { std::unique_lock lock(mutex_); client_num_--; @@ -87,6 +76,11 @@ void RPCServer::EndPass() { barrier_cond_.notify_all(); } +int RPCServer::GetClientNum() { + std::unique_lock lock(mutex_); + return client_num_; +} + void RPCServer::ResetBarrierCounter() { VLOG(3) << "RPCServer ResetBarrierCounter "; std::unique_lock lock(mutex_); diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h index 833991c8aa6e7cfd10f2aa52f9218be7ff8ccebf..fd914d7a72e61bc9472876c433b65598ef5b1980 100644 --- a/paddle/fluid/operators/distributed/rpc_server.h +++ b/paddle/fluid/operators/distributed/rpc_server.h @@ -44,7 +44,7 @@ class RPCServer { int GetSelectedPort() const { return selected_port_; } - int GetClientNum() const; + int GetClientNum(); void SavePort() const; @@ -64,8 +64,7 @@ class RPCServer { void WaitCond(const std::string& rpc_name); void IncreaseBatchBarrier(const std::string rpc_name); - void BeginPass(); - void EndPass(); + void Complete(); void ResetBarrierCounter(); diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index a0693cffabcc561b0adfafc2c49027a890dd5efc..9f2360ec70d2ce5d4e16435595e109c1bf04fd13 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, } } -void StartServer() { +void StartServer(const std::string& rpc_name) { framework::ProgramDesc program; framework::Scope scope; platform::CPUPlace place; @@ -107,14 +107,14 @@ void StartServer() { std::shared_ptr> prefetch_var_name_to_prepared; prefetch_var_name_to_prepared[in_var_name] = prepared[0]; + g_req_handler->SetProgram(&program); g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared); g_req_handler->SetDevCtx(&ctx); g_req_handler->SetScope(&scope); g_req_handler->SetExecutor(&exe); - g_rpc_service->RegisterRPC(distributed::kRequestPrefetch, - g_req_handler.get()); + g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); g_req_handler->SetRPCServer(g_rpc_service.get()); std::thread server_thread( @@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) { distributed::RPCClient* client = distributed::RPCClient::GetInstance(); - std::thread server_thread(StartServer); + std::thread server_thread(StartServer, distributed::kRequestPrefetch); g_rpc_service->WaitServerReady(); int port = g_rpc_service->GetSelectedPort(); @@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) { g_rpc_service.reset(nullptr); g_req_handler.reset(nullptr); } + +TEST(COMPLETE, CPU) { + g_req_handler.reset(new distributed::RequestSendHandler(true)); + g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2)); + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(); + PADDLE_ENFORCE(client != nullptr); + std::thread server_thread(StartServer, distributed::kRequestSend); + g_rpc_service->WaitServerReady(); + int port = g_rpc_service->GetSelectedPort(); + std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); + client->AsyncSendComplete(ep); + client->Wait(); + + EXPECT_EQ(g_rpc_service->GetClientNum(), 1); + + g_rpc_service->ShutDown(); + server_thread.join(); + g_rpc_service.reset(nullptr); + g_req_handler.reset(nullptr); +} diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3e13e7b1ffebf92301df69084b058ca55783e578..ee1c8d46ddfb4f0c09591bb78dc720555dc735b4 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -498,10 +498,7 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Executor") .def(py::init()) -#ifdef PADDLE_WITH_DISTRIBUTE - .def("begin_pass", &Executor::BeginPass) - .def("end_pass", &Executor::EndPass) -#endif + .def("close", &Executor::Close) .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, int block_id, bool create_local_scope, bool create_vars) { pybind11::gil_scoped_release release; diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index f9e600cb4cb252baead87025db0e0db71e8169d2..4178971398c953236bf8de4d5cb6e93d0e33380c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -247,6 +247,7 @@ class Executor(object): p.set_place(place) self.executor = core.Executor(p) self.program_caches = dict() + self._closed = False def as_lodtensor(self, data): """ @@ -348,11 +349,23 @@ class Executor(object): ] return outs - def begin_pass(self): - self.executor.begin_pass() + def close(self): + """ + Close this executor. - def end_pass(self): - self.executor.end_pass() + You can no long use this executor after calling this method. + For the distributed training, this method would free the resource on PServers related to + the current Trainer. + + Example: + >>> cpu = core.CPUPlace() + >>> exe = Executor(cpu) + >>> ... + >>> exe.close() + """ + if not self._closed: + self.executor.close() + self._closed = True def run(self, program=None, @@ -405,6 +418,10 @@ class Executor(object): >>> feed={'X': x}, >>> fetch_list=[loss.name]) """ + + if self._closed: + raise RuntimeError("Attempted to use a closed Executor") + if feed is None: feed = {} if not isinstance(feed, dict):