From 34865f2de39d9b8884260f8bd578e77c589f4195 Mon Sep 17 00:00:00 2001
From: Wu Yi
Date: Tue, 12 Jun 2018 12:20:28 +0800
Subject: [PATCH] Trainer send term signal (#11220)

* wip

* use executor.complete to end trainer

* fix build

* fix build with distribute off

* fix typo

* fix cmake typo

* fix build
---
 cmake/configure.cmake                         |  4 ++++
 paddle/fluid/framework/CMakeLists.txt         |  9 ++++++--
 paddle/fluid/framework/executor.cc            | 11 ++++++++++
 paddle/fluid/framework/executor.h             |  7 ++++++
 paddle/fluid/operators/detail/grpc_client.cc  | 19 ++++++++++++++++
 paddle/fluid/operators/detail/grpc_client.h   |  5 +++++
 .../fluid/operators/detail/request_handler.h  |  1 +
 .../operators/detail/request_handler_impl.cc  |  3 +++
 paddle/fluid/operators/detail/rpc_client.h    |  5 +++++
 paddle/fluid/operators/detail/rpc_server.cc   | 22 +++++++++++--------
 paddle/fluid/operators/detail/rpc_server.h    |  5 ++---
 paddle/fluid/pybind/pybind.cc                 |  3 +++
 12 files changed, 80 insertions(+), 14 deletions(-)

diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index 89d7a62fe9a..6a8b15a6b60 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -118,6 +118,10 @@ endif()
 set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
 
+if(WITH_DISTRIBUTE)
+  add_definitions(-DPADDLE_WITH_DISTRIBUTE)
+endif()
+
 if(WITH_GOLANG)
   # we need to symlink Paddle directory into GOPATH. If we
   # don't do it and we have code that depends on Paddle, go
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 4271e4c1bb6..6bc77058064 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -83,8 +83,13 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
 
 cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
 
-cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
-framework_proto glog lod_rank_table feed_fetch_method)
+if(WITH_DISTRIBUTE)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr)
+  set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
+  set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+else()
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
+endif()
 
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
 
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index d4d6c34108b..4a6f53cba1f 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -20,6 +20,9 @@ limitations under the License. */
*/ #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/fluid/operators/detail/grpc_client.h" +#endif #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" @@ -44,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { Executor::Executor(const platform::Place& place) : place_(place) {} +#ifdef PADDLE_WITH_DISTRIBUTE +void Executor::Complete() { + ::paddle::operators::detail::RPCClient::GetInstance< + ::paddle::operators::detail::GRPCClient>() + ->SendComplete(); +} +#endif + void InitializeVariable(Variable* var, proto::VarType::Type var_type) { if (var_type == proto::VarType::LOD_TENSOR) { var->GetMutable(); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index e6f9c3d31c1..67a0761dac2 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -44,6 +44,13 @@ class Executor { explicit Executor(const platform::Place& place); +#ifdef PADDLE_WITH_DISTRIBUTE + /* + * Sending signal to pserver to mark current trainer stop. + */ + void Complete(); +#endif + /* @Brief * Runtime evaluation of the given ProgramDesc under certain Scope * diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 6b8373b1509..02ffe3651e1 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -34,6 +34,12 @@ void GRPCClient::InitEventLoop() { client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); } +void GRPCClient::SendComplete() { + for (auto& it : channels_) { + this->AsyncSendComplete(it.first); + } +} + GRPCClient::~GRPCClient() { Wait(); cq_.Shutdown(); @@ -210,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, req_count_++; } +void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { + const auto ch = GetChannel(ep); + + BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); + s->Prepare(time_out); + + sendrecv::VariableMessage req; + req.set_varname(COMPLETE_MESSAGE); + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; +} + void GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return req_count_ == 0; }); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 8db73f875e3..44000c028b4 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -195,6 +195,8 @@ class GRPCClient : public RPCClient { void Wait() override; + void SendComplete() override; + protected: void InitImpl() override; @@ -204,6 +206,9 @@ class GRPCClient : public RPCClient { void Proceed(); + void AsyncSendComplete(const std::string& ep, + int64_t time_out = RPCClient::rpc_time_out); + std::shared_ptr GetChannel(const std::string& ep); private: diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index fa979024e37..595bfe3787a 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -40,6 +40,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch"; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define 
FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" +#define COMPLETE_MESSAGE "COMPLETE@RECV" class RPCServer; diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 5f1a346e93b..bf277db5fb6 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -49,6 +49,9 @@ bool RequestSendHandler::Handle(const std::string& varname, if (varname == BATCH_BARRIER_MESSAGE) { VLOG(3) << "sync: recv batch barrier message"; rpc_server_->IncreaseBatchBarrier(kRequestSend); + } else if (varname == COMPLETE_MESSAGE) { + VLOG(3) << "sync: recv complete message"; + rpc_server_->DecreaseClientNum(); } else { VLOG(3) << "sync: received var_name: " << varname; if (sync_mode_) { diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h index 7e76ac03485..47c6ffb4fd7 100644 --- a/paddle/fluid/operators/detail/rpc_client.h +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -53,6 +53,11 @@ class RPCClient { virtual void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = rpc_time_out) = 0; + // SendComplete tells all the server that current trainer have no more data + // to train, so that the pserver can reduce it's barrier count, and continue + // to train with other trainers. + virtual void SendComplete() = 0; + virtual void Wait() = 0; static constexpr int64_t rpc_time_out = 120 * 1000; diff --git a/paddle/fluid/operators/detail/rpc_server.cc b/paddle/fluid/operators/detail/rpc_server.cc index 448763372a8..cd0fe96e230 100644 --- a/paddle/fluid/operators/detail/rpc_server.cc +++ b/paddle/fluid/operators/detail/rpc_server.cc @@ -43,7 +43,7 @@ void RPCServer::SavePort() const { void RPCServer::WaitBarrier(const std::string& rpc_name) { std::unique_lock lock(this->mutex_); - barrier_cond_.wait(lock, [=] { + barrier_cond_.wait(lock, [this, &rpc_name] { return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load()); }); @@ -53,19 +53,23 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) { void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; int b = 0; - { - std::unique_lock lock(mutex_); - b = ++barrier_counter_[rpc_name]; - } - - VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name - << ", barrier_count:" << b << ", fan_in" << client_num_; - + std::unique_lock lock(mutex_); + b = ++barrier_counter_[rpc_name]; if (b >= client_num_) { + lock.unlock(); barrier_cond_.notify_all(); + lock.lock(); } } +void RPCServer::DecreaseClientNum() { + { + std::unique_lock lock(mutex_); + client_num_--; + } + barrier_cond_.notify_all(); +} + void RPCServer::ResetBarrierCounter() { VLOG(3) << "RPCServer ResetBarrierCounter "; std::unique_lock lock(mutex_); diff --git a/paddle/fluid/operators/detail/rpc_server.h b/paddle/fluid/operators/detail/rpc_server.h index f809c13c726..2e3342428cb 100644 --- a/paddle/fluid/operators/detail/rpc_server.h +++ b/paddle/fluid/operators/detail/rpc_server.h @@ -60,7 +60,7 @@ class RPCServer { void SetCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name); void IncreaseBatchBarrier(const std::string rpc_name); - + void DecreaseClientNum(); void ResetBarrierCounter(); protected: @@ -79,8 +79,7 @@ class RPCServer { std::string bind_address_; std::atomic exit_flag_; int selected_port_; - - const int client_num_; + int client_num_; std::unordered_map rpc_call_map_; std::unordered_map rpc_thread_num_; 
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index c88fbef63cf..bd5c613f8cf 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -413,6 +413,9 @@ All parameter, weight, gradient are variables in Paddle.
 
   py::class_<Executor>(m, "Executor")
       .def(py::init<const platform::Place &>())
+#ifdef PADDLE_WITH_DISTRIBUTE
+      .def("complete", &Executor::Complete)
+#endif
       .def("run", (void (Executor::*)(const ProgramDesc &, Scope *, int, bool,
                                       bool)) &
                       Executor::Run);
--
GitLab
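
The core idea of this patch is the interaction between the batch barrier and the new COMPLETE_MESSAGE path: WaitBarrier blocks until every live trainer has checked in, and a trainer that runs out of data sends COMPLETE so the server stops counting it. Below is a minimal, self-contained C++ sketch of that behavior; BarrierModel and its method names are illustrative stand-ins for this note only, not Paddle's real RPCServer API.

// A minimal model (not Paddle code) of the barrier logic this patch changes:
// WaitBarrier() blocks until every live client has reported a batch barrier,
// and DecreaseClientNum() models a trainer sending COMPLETE_MESSAGE so the
// remaining trainers are no longer blocked on the one that finished early.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class BarrierModel {
 public:
  explicit BarrierModel(int client_num) : client_num_(client_num) {}

  // A client reached the batch barrier (BATCH_BARRIER_MESSAGE).
  void IncreaseBatchBarrier() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (++counter_ >= client_num_) {
      lock.unlock();
      cond_.notify_all();
    }
  }

  // A client finished all of its data (COMPLETE_MESSAGE): wait for one fewer.
  void DecreaseClientNum() {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      --client_num_;
    }
    cond_.notify_all();  // the barrier may already be satisfied now
  }

  // The server blocks here until all remaining clients reach the barrier.
  void WaitBarrier() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return counter_ >= client_num_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cond_;
  int counter_ = 0;
  int client_num_;
};

int main() {
  BarrierModel barrier(3);
  // Two trainers hit the barrier; the third runs out of data and completes.
  std::thread t1([&] { barrier.IncreaseBatchBarrier(); });
  std::thread t2([&] { barrier.IncreaseBatchBarrier(); });
  std::thread t3([&] { barrier.DecreaseClientNum(); });
  barrier.WaitBarrier();  // released once 2 barrier reports >= 2 live clients
  t1.join();
  t2.join();
  t3.join();
  std::cout << "barrier released" << std::endl;
  return 0;
}

Note the design point this models: DecreaseClientNum() must notify the condition variable after shrinking client_num_, otherwise a trainer that exits early could leave the server waiting for a barrier count that can no longer be reached. The unlock-before-notify in IncreaseBatchBarrier mirrors the change in rpc_server.cc, so woken waiters do not immediately block on the still-held mutex.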