diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 89d7a62fe9aca3a71ad34b976a186a80174bfd5e..6a8b15a6b60a2e5635dc78fc877f0c8da9a2a998 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -118,6 +118,10 @@ endif() set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}") +if(WITH_DISTRIBUTE) + add_definitions(-DPADDLE_WITH_DISTRIBUTE) +endif() + if(WITH_GOLANG) # we need to symlink Paddle directory into GOPATH. If we # don't do it and we have code that depends on Paddle, go diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 4271e4c1bb6bc7b83f2633191ea2d464f4f56c4c..6bc770580640f242cfce6a9838f00210f785010a 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -83,8 +83,13 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) -cc_library(executor SRCS executor.cc DEPS op_registry device_context scope -framework_proto glog lod_rank_table feed_fetch_method) +if(WITH_DISTRIBUTE) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr) + set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +else() + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) +endif() cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index d4d6c34108b9f1e457d8eb0c36d10339b03330bd..4a6f53cba1f46214dbff3058b221f878ecf46613 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -20,6 +20,9 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/fluid/operators/detail/grpc_client.h" +#endif #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" @@ -44,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { Executor::Executor(const platform::Place& place) : place_(place) {} +#ifdef PADDLE_WITH_DISTRIBUTE +void Executor::Complete() { + ::paddle::operators::detail::RPCClient::GetInstance< + ::paddle::operators::detail::GRPCClient>() + ->SendComplete(); +} +#endif + void InitializeVariable(Variable* var, proto::VarType::Type var_type) { if (var_type == proto::VarType::LOD_TENSOR) { var->GetMutable(); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index e6f9c3d31c18f762ef2de269977e0642a79fb174..67a0761dac2a9adcdd0ce2b218c4aa505d688d56 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -44,6 +44,13 @@ class Executor { explicit Executor(const platform::Place& place); +#ifdef PADDLE_WITH_DISTRIBUTE + /* + * Sending signal to pserver to mark current trainer stop. + */ + void Complete(); +#endif + /* @Brief * Runtime evaluation of the given ProgramDesc under certain Scope * diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 6b8373b1509c898e6ae70a18833df39a4898714a..02ffe3651e1deefcf6981c3d304d64b9a01661bf 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -34,6 +34,12 @@ void GRPCClient::InitEventLoop() { client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); } +void GRPCClient::SendComplete() { + for (auto& it : channels_) { + this->AsyncSendComplete(it.first); + } +} + GRPCClient::~GRPCClient() { Wait(); cq_.Shutdown(); @@ -210,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, req_count_++; } +void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { + const auto ch = GetChannel(ep); + + BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); + s->Prepare(time_out); + + sendrecv::VariableMessage req; + req.set_varname(COMPLETE_MESSAGE); + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; +} + void GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return req_count_ == 0; }); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 8db73f875e3e2048386e91f6b5efb29b4ee7e193..44000c028b499d9ad1a0e0dd40a5e287cd61d143 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -195,6 +195,8 @@ class GRPCClient : public RPCClient { void Wait() override; + void SendComplete() override; + protected: void InitImpl() override; @@ -204,6 +206,9 @@ class GRPCClient : public RPCClient { void Proceed(); + void AsyncSendComplete(const std::string& ep, + int64_t time_out = RPCClient::rpc_time_out); + std::shared_ptr GetChannel(const std::string& ep); private: diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index fa979024e37f435b918568a1c5e603f8962f9172..595bfe3787a7f0200f3f3ffc3a67fe2ccaba9008 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -40,6 +40,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch"; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" +#define COMPLETE_MESSAGE "COMPLETE@RECV" class RPCServer; diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 5f1a346e93b1a0239af77b86d10782d67c403e23..bf277db5fb6a617dc80a765bc0d78665c397eac7 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -49,6 +49,9 @@ bool RequestSendHandler::Handle(const std::string& varname, if (varname == BATCH_BARRIER_MESSAGE) { VLOG(3) << "sync: recv batch barrier message"; rpc_server_->IncreaseBatchBarrier(kRequestSend); + } else if (varname == COMPLETE_MESSAGE) { + VLOG(3) << "sync: recv complete message"; + rpc_server_->DecreaseClientNum(); } else { VLOG(3) << "sync: received var_name: " << varname; if (sync_mode_) { diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h index 7e76ac0348574d4090793b191be0ff3ff8666b37..47c6ffb4fd7a002fc0bd8053fb3314a2fbf18fd3 100644 --- a/paddle/fluid/operators/detail/rpc_client.h +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -53,6 +53,11 @@ class RPCClient { virtual void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = rpc_time_out) = 0; + // SendComplete tells all the server that current trainer have no more data + // to train, so that the pserver can reduce it's barrier count, and continue + // to train with other trainers. + virtual void SendComplete() = 0; + virtual void Wait() = 0; static constexpr int64_t rpc_time_out = 120 * 1000; diff --git a/paddle/fluid/operators/detail/rpc_server.cc b/paddle/fluid/operators/detail/rpc_server.cc index 448763372a8c224cc68319a4a444915896b68234..cd0fe96e2301ee3304fe9a2967df58b9f7072d8d 100644 --- a/paddle/fluid/operators/detail/rpc_server.cc +++ b/paddle/fluid/operators/detail/rpc_server.cc @@ -43,7 +43,7 @@ void RPCServer::SavePort() const { void RPCServer::WaitBarrier(const std::string& rpc_name) { std::unique_lock lock(this->mutex_); - barrier_cond_.wait(lock, [=] { + barrier_cond_.wait(lock, [this, &rpc_name] { return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load()); }); @@ -53,19 +53,23 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) { void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; int b = 0; - { - std::unique_lock lock(mutex_); - b = ++barrier_counter_[rpc_name]; - } - - VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name - << ", barrier_count:" << b << ", fan_in" << client_num_; - + std::unique_lock lock(mutex_); + b = ++barrier_counter_[rpc_name]; if (b >= client_num_) { + lock.unlock(); barrier_cond_.notify_all(); + lock.lock(); } } +void RPCServer::DecreaseClientNum() { + { + std::unique_lock lock(mutex_); + client_num_--; + } + barrier_cond_.notify_all(); +} + void RPCServer::ResetBarrierCounter() { VLOG(3) << "RPCServer ResetBarrierCounter "; std::unique_lock lock(mutex_); diff --git a/paddle/fluid/operators/detail/rpc_server.h b/paddle/fluid/operators/detail/rpc_server.h index f809c13c726ac2f1c60e8cf84848c4138f631b44..2e3342428cb56c34abaca655d5906668cda8f140 100644 --- a/paddle/fluid/operators/detail/rpc_server.h +++ b/paddle/fluid/operators/detail/rpc_server.h @@ -60,7 +60,7 @@ class RPCServer { void SetCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name); void IncreaseBatchBarrier(const std::string rpc_name); - + void DecreaseClientNum(); void ResetBarrierCounter(); protected: @@ -79,8 +79,7 @@ class RPCServer { std::string bind_address_; std::atomic exit_flag_; int selected_port_; - - const int client_num_; + int client_num_; std::unordered_map rpc_call_map_; std::unordered_map rpc_thread_num_; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c88fbef63cf26c671246b15ea9872da0e7a92c1a..bd5c613f8cf794df5dfeb7517ed4350f9b3b6099 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -413,6 +413,9 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Executor") .def(py::init()) +#ifdef PADDLE_WITH_DISTRIBUTE + .def("complete", &Executor::Complete) +#endif .def("run", (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) & Executor::Run);