diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index ee3b3e3ccbd6593a17e5278bffb2b3010d23e925..bb9c93480df6a5f07740f34df45296ef8567fdb2 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -208,9 +208,9 @@ void AsyncGRPCServer::WaitClientGet(int count) { } } -bool AsyncGRPCServer::WaitServerReady() { +void AsyncGRPCServer::WaitServerReady() { std::unique_lock lock(this->mutex_ready_); - condition_ready_.wait(lock, [&] { return this->ready_ == 1; }); + condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); } void AsyncGRPCServer::RunSyncUpdate() { diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index d7c06fc1814380c2afc2cc3f7886199d2868c723..7f9cae21ccca8dd51f9fbe98148d01a51ac6eb84 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -47,7 +47,7 @@ class AsyncGRPCServer final { explicit AsyncGRPCServer(const std::string &address, bool sync_mode) : address_(address), sync_mode_(sync_mode), ready_(0) {} - bool WaitServerReady(); + void WaitServerReady(); void RunSyncUpdate(); // functions to sync server barrier status. @@ -120,7 +120,7 @@ class AsyncGRPCServer final { framework::Executor *executor_; int selected_port_; - std::mutext mutex_ready_; + std::mutex mutex_ready_; std::condition_variable condition_ready_; int ready_; }; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 0a4b6a08e58f368346ff8b7087882f395bce0fef..350c9c8563e884e7e993e9eb44d3b87c4116f2e4 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -80,12 +80,7 @@ static void ParallelExecuteBlocks( for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); } -static void SavePort(std::shared_ptr rpc_service) { - std::ofstream port_file; - port_file.open("/tmp/paddle.selected_port"); - port_file << rpc_service->GetSelectedPort(); - port_file.close(); -} +std::atomic_int ListenAndServOp::selected_port_{0}; ListenAndServOp::ListenAndServOp(const std::string &type, const framework::VariableNameMap &inputs, @@ -93,15 +88,27 @@ ListenAndServOp::ListenAndServOp(const std::string &type, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} -int ListenAndServOp::GetSelectedPort() const { - return rpc_service_->GetSelectedPort(); -} - void ListenAndServOp::Stop() { rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); server_thread_->join(); } +void ListenAndServOp::SavePort(const std::string &file_path) const { + // NOTE: default write file to /tmp/paddle.selected_port + selected_port_ = rpc_service_->GetSelectedPort(); + + std::ofstream port_file; + port_file.open(file_path); + port_file << selected_port_.load(); + port_file.close(); + VLOG(4) << "selected port written to " << file_path; +} + +void ListenAndServOp::WaitServerReady() { + while (selected_port_.load() == 0) { + } +} + void ListenAndServOp::RunSyncLoop(framework::Executor *executor, framework::ProgramDesc *program, framework::Scope *recv_scope, @@ -265,23 +272,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, } // while(true) } -void ListenAndServOp::StartServerThread() { - server_thread_.reset(new std::thread( - std::bind(&ListenAndServOp::ServerThreadEntry, this, rpc_service_))); -} - -void ListenAndServOp::ServerThreadEntry( - std::shared_ptr service) { - service->RunSyncUpdate(); - VLOG(4) << "RunServer thread end"; - - { - std::lock_guard lock(this->barrier_mutex_); - barrier_cond_step_ = cond; - } - barrier_condition_.notify_all(); -} - void ListenAndServOp::RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); @@ -315,9 +305,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); VLOG(3) << "wait server thread to become ready..."; + rpc_service_->WaitServerReady(); // Write to a file of server selected port for python use. - SavePort(rpc_service_); + SavePort(); if (sync_mode) { RunSyncLoop(&executor, program, &recv_scope, prefetch_block); } else { diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index c85569acdcd82d8e4e3c418d4c0b98e811adcf4c..87c0df2a8a10d9c8e464f27b06237c099ee2bda6 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include @@ -39,8 +40,6 @@ class ListenAndServOp : public framework::OperatorBase { const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs); - int GetSelectedPort() const; - void RunSyncLoop(framework::Executor* executor, framework::ProgramDesc* program, framework::Scope* recv_scope, @@ -51,20 +50,25 @@ class ListenAndServOp : public framework::OperatorBase { framework::Scope* recv_scope, framework::BlockDesc* prefetch_block) const; - void StartServerThread(); + void SavePort( + const std::string& file_path = "/tmp/paddle.selected_port") const; + + void WaitServerReady(); - void ServerThreadEntry(); + int GetSelectedPort() { return selected_port_; } void Stop() override; void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override; + static void ResetPort() { selected_port_ = 0; } + protected: mutable std::shared_ptr rpc_service_; mutable std::shared_ptr server_thread_; - std::mutext server_ready_mutex_; - std::condition_variable server_ready_; + // FIXME(wuyi): it's static so that the operator can be cloned. + static std::atomic_int selected_port_; }; } // namespace operators diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc index d2e1f3cb2ff9c8254cd4815a0f8750966a6e161c..a0b5a390db925d8e94d934e880f63ee2f565c2af 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/send_recv_op_test.cc @@ -116,6 +116,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, void StartServerNet(bool is_sparse) { f::Scope scope; p::CPUPlace place; + VLOG(4) << "before init tensor"; if (is_sparse) { InitSelectedRowsInScope(place, &scope); } else { @@ -129,6 +130,7 @@ void StartServerNet(bool is_sparse) { auto *prefetch_block = program.AppendBlock(root_block); // X for server side tensors, RX for received tensors, must be of same shape. AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); + VLOG(4) << "before attr"; f::AttributeMap attrs; attrs.insert({"endpoint", std::string("127.0.0.1:0")}); @@ -139,15 +141,19 @@ void StartServerNet(bool is_sparse) { attrs.insert({"PrefetchBlock", prefetch_block}); attrs.insert({"grad_to_block_id", std::vector({""})}); attrs.insert({"sync_mode", true}); + VLOG(4) << "before init op"; listen_and_serv_op = f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); + VLOG(4) << "before run op"; listen_and_serv_op->Run(scope, place); LOG(INFO) << "server exit"; } TEST(SendRecvOp, CPUDense) { std::thread server_thread(StartServerNet, false); - sleep(5); // wait server to start + // wait server to start + static_cast(listen_and_serv_op.get()) + ->WaitServerReady(); // local net f::Scope scope; p::CPUPlace place; @@ -181,11 +187,13 @@ TEST(SendRecvOp, CPUDense) { listen_and_serv_op->Stop(); server_thread.join(); listen_and_serv_op.reset(nullptr); + paddle::operators::ListenAndServOp::ResetPort(); } TEST(SendRecvOp, CPUSparse) { std::thread server_thread(StartServerNet, true); - sleep(3); // wait server to start + static_cast(listen_and_serv_op.get()) + ->WaitServerReady(); // local net f::Scope scope; p::CPUPlace place; @@ -226,4 +234,5 @@ TEST(SendRecvOp, CPUSparse) { listen_and_serv_op->Stop(); server_thread.join(); listen_and_serv_op.reset(); + paddle::operators::ListenAndServOp::ResetPort(); }