From b5a410466c70750b763d5be4c3238b73c9190c90 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 7 Oct 2019 18:24:42 +0800 Subject: [PATCH] Trainer heartbeat for async mode (#19600) Heartbeat for distributed async training. --- .../operators/distributed/CMakeLists.txt | 5 +- .../operators/distributed/grpc/grpc_client.cc | 1 + .../distributed/heart_beat_monitor.cc | 97 +++++++++++++ .../distributed/heart_beat_monitor.h | 136 ++++++++++++++++++ .../distributed/heart_beat_monitor_test.cc | 57 ++++++++ .../distributed/request_handler_impl.cc | 7 +- .../operators/distributed/rpc_server_test.cc | 4 + .../distributed_ops/listen_and_serv_op.cc | 21 ++- python/paddle/fluid/__init__.py | 2 + .../fluid/transpiler/distribute_transpiler.py | 1 + 10 files changed, 327 insertions(+), 4 deletions(-) create mode 100644 paddle/fluid/operators/distributed/heart_beat_monitor.cc create mode 100644 paddle/fluid/operators/distributed/heart_beat_monitor.h create mode 100644 paddle/fluid/operators/distributed/heart_beat_monitor_test.cc diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 8909135d234..b8b82180b3a 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -12,6 +12,9 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool) cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder) +cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool) +cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor) + # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") if(WITH_GRPC) @@ -23,7 +26,7 @@ if(WITH_GRPC) collective_client.cc collective_server.cc ${GRPC_SRCS} PROTO send_recv.proto - DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder) + DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 053fe202fe9..de61400fdf6 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -392,6 +392,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, s->Prepare(h, time_out); sendrecv::VariableMessage req; + req.set_trainer_id(trainer_id_); req.set_varname(COMPLETE_MESSAGE); platform::RecordRPCEvent record_event(method); diff --git a/paddle/fluid/operators/distributed/heart_beat_monitor.cc b/paddle/fluid/operators/distributed/heart_beat_monitor.cc new file mode 100644 index 00000000000..6736ea4336b --- /dev/null +++ b/paddle/fluid/operators/distributed/heart_beat_monitor.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" +#include // NOLINT +#include + +namespace paddle { +namespace operators { +namespace distributed { + +DEFINE_int32(worker_update_interval_secs, 900, + " the longest time interval between the worker update variables"); + +inline int GetCurrentUS() { + // current date/time based on current system + time_t t = std::time(0); + int now = static_cast(t); + return now; +} + +void HeartBeatMonitor::Update(const int worker_id, std::string be_monitored_var, + WorkerStatus status) { + if (status == UNINITED) { + LOG(WARNING) << "HeartBeatMonitor receive UNINITED status can not be used " + "in Update, something error"; + } + + if (!is_chief_) { + return; + } + + if ((be_monitored_var == be_monitored_var_ && status == RUNNING) || + status == COMPLETED) { + auto timestamp = GetCurrentUS(); + UnderMonitoredWorker& worker = worker_status_map_.at(worker_id); + + if (worker.status != COMPLETED) { + worker.status = status; + } + worker.timestamp = timestamp; + return; + } +} + +void HeartBeatMonitor::LostWorkerMonitor() { + VLOG(1) << "worker heartbeat monitor start at No.0 parameter server"; + while (running_) { + for (int id = 0; id < workers_; ++id) { + auto& worker = worker_status_map_.at(id); + + if (worker.status == UNINITED) { + VLOG(4) << "worker " << worker.id << " is under UNINITED"; + continue; + } + if (worker.status == COMPLETED) { + VLOG(4) << "worker " << worker.id << " is under COMPLETED"; + continue; + } + + auto timestamp = GetCurrentUS(); + + VLOG(4) << "worker " << worker.id << " status is " << worker.status + << " timestamp is " << worker.timestamp << " the interval is " + << timestamp - worker.timestamp; + + if (timestamp - worker.timestamp >= FLAGS_worker_update_interval_secs) { + PADDLE_THROW( + "the latest update of worker %d is %d secs ago, we doubt the " + "the worker is not alive and this may have a bad effect on the " + "fitting result, please check", + worker.id, FLAGS_worker_update_interval_secs); + } + } + + std::this_thread::sleep_for(std::chrono::milliseconds(30 * 1000)); + } + VLOG(1) << "worker heartbeat monitor stopped, thread exit"; +} + +std::once_flag HeartBeatMonitor::init_flag_; +std::unique_ptr HeartBeatMonitor::monitor_(nullptr); + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/heart_beat_monitor.h b/paddle/fluid/operators/distributed/heart_beat_monitor.h new file mode 100644 index 00000000000..639785ba513 --- /dev/null +++ b/paddle/fluid/operators/distributed/heart_beat_monitor.h @@ -0,0 +1,136 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include // NOLINT + +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace distributed { + +enum WorkerStatus { UNINITED = 0, RUNNING, COMPLETED }; + +struct UnderMonitoredWorker { + int id; + WorkerStatus status; + int timestamp; + + UnderMonitoredWorker() {} + + explicit UnderMonitoredWorker(int worker_id) { + this->id = worker_id; + this->status = UNINITED; + this->timestamp = 0; + } +}; + +class HeartBeatMonitor { + public: + explicit HeartBeatMonitor(int workers, bool is_chief, + std::string be_monitored_var) + : workers_(workers), + is_chief_(is_chief), + be_monitored_var_(be_monitored_var), + running_(true) { + PADDLE_ENFORCE_GT(workers, 0, "trainers must have one or more"); + + for (auto worker_id = 0; worker_id < workers; worker_id++) { + UnderMonitoredWorker worker(worker_id); + worker_status_map_[worker_id] = std::move(worker); + } + + // we define the No.0 pserver is the first parameter server + // only No.0 will check the heartbeat of all trainers + if (is_chief) { + monitor_thread_.reset(new std::thread( + std::bind(&HeartBeatMonitor::LostWorkerMonitor, this))); + } + } + + ~HeartBeatMonitor() { + running_ = false; + if (monitor_thread_) monitor_thread_->join(); + } + + static void Init(int workers, bool is_chief, std::string be_monitored_var) { + std::call_once(init_flag_, &HeartBeatMonitor::InitImpl, workers, is_chief, + be_monitored_var); + } + + static HeartBeatMonitor* GetInstance() { + if (monitor_ == nullptr) { + PADDLE_THROW( + "HeartBeatMonitor is not inited, call " + "HeartBeatMonitor::Init first"); + } + return monitor_.get(); + } + + void Stop() { + running_ = false; + if (!monitor_) { + VLOG(0) << "HeartBeatMonitor is not inited, do nothing"; + } else { + if (monitor_thread_) { + monitor_thread_->join(); + monitor_thread_.reset(nullptr); + } + } + } + + void Update(const int worker_id, std::string be_monitored_var, + WorkerStatus status); + + void LostWorkerMonitor(); + + private: + // Init is called by GetInstance. + static void InitImpl(int workers, bool is_chief, + std::string be_monitored_var) { + if (monitor_ == nullptr) { + monitor_.reset(new HeartBeatMonitor(workers, is_chief, be_monitored_var)); + } + } + + static std::once_flag init_flag_; + static std::unique_ptr monitor_; + + int workers_; + bool is_chief_; + std::string be_monitored_var_; + std::unordered_map worker_status_map_; + std::unique_ptr monitor_thread_{nullptr}; + std::mutex mutex_; + bool running_ = false; +}; + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/heart_beat_monitor_test.cc b/paddle/fluid/operators/distributed/heart_beat_monitor_test.cc new file mode 100644 index 00000000000..916ee43ffbf --- /dev/null +++ b/paddle/fluid/operators/distributed/heart_beat_monitor_test.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" + +#include +#include // NOLINT + +#include "gtest/gtest.h" + +namespace paddle { +namespace operators { +namespace distributed { + +void run(HeartBeatMonitor* monitor) { monitor->LostWorkerMonitor(); } + +TEST(HeartBeatMonitor, All) { + int trainers = 10; + int pserver_id = 0; + std::string var = "nce_w@GRAD.block0"; + std::string var2 = "nce_w@GRAD.block2"; + + HeartBeatMonitor::Init(trainers, pserver_id == 0, var); + + auto* monitor = HeartBeatMonitor::GetInstance(); + + std::vector ids{1, 3, 5, 7}; + + for (auto& id : ids) { + monitor->Update(id, var, RUNNING); + } + + monitor->Update(9, var2, RUNNING); + monitor->Update(2, var, COMPLETED); + + std::thread t(run, monitor); + t.detach(); + + std::this_thread::sleep_for(std::chrono::milliseconds(45 * 1000)); + + monitor->Stop(); +} + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 613e5251c12..ca150f70c74 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -22,12 +22,14 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable_helper.h" -#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/split.h" +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" +#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" + namespace paddle { namespace operators { namespace distributed { @@ -51,6 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname, rpc_server_->IncreaseBatchBarrier(kRequestSend); } else if (varname == COMPLETE_MESSAGE) { VLOG(3) << "sync: recv complete message"; + HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED); rpc_server_->Complete(); } else { // Async @@ -61,6 +64,7 @@ bool RequestSendHandler::Handle(const std::string& varname, "async mode should not recv BATCH_BARRIER_MESSAGE or " "COMPLETE_MESSAGE"); } + HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING); std::string run_varname = varname; @@ -82,6 +86,7 @@ bool RequestSendHandler::Handle(const std::string& varname, } executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(), scope); + return true; } else { // sync rpc_server_->WaitCond(kRequestSend); diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index 45e97d966fc..df52d74ed58 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_server.h" @@ -116,6 +117,9 @@ void StartServer(const std::string& rpc_name) { g_req_handler->SetExecutor(&exe); g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); + + distributed::HeartBeatMonitor::Init(2, true, "w@grad"); + g_req_handler->SetRPCServer(g_rpc_service.get()); std::thread server_thread( diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index 14b53086d1c..ba55d5c2f3d 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" +#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" @@ -338,14 +339,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, bool sync_mode = Attr("sync_mode"); bool dc_sgd = Attr("dc_asgd"); auto fan_in = Attr("Fanin"); + auto pserver_id = Attr("pserver_id"); auto inputs = Inputs("X"); PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); int checkpoint_block_id = Attr(kCheckpointBlockId); - VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in - << ", end_point:" << endpoint + VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode + << ", fan_in:" << fan_in << ", end_point:" << endpoint << ", checkpoint_block_id: " << checkpoint_block_id; rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); @@ -466,6 +468,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, } else { distributed::AsyncSparseParamUpdateRecorder::Init( fan_in, sparse_grad_name_to_param_name); + + VLOG(2) << "RunAsyncLoop"; + auto grad_to_block_id_str = + Attr>("grad_to_block_id"); + + if (grad_to_block_id_str.size() == 0) { + VLOG(0) << "there are no gradients on this parameter server"; + } else { + std::vector pieces; + split(grad_to_block_id_str[0], ':', &pieces); + distributed::HeartBeatMonitor::Init(fan_in, pserver_id == 0, pieces[0]); + } RunAsyncLoop(&executor, program, &recv_scope); } } @@ -482,6 +496,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { "IP address to listen on.") .SetDefault("127.0.0.1:6164") .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); + AddAttr("pserver_id", + "(int, default -1), the parameter server index id") + .SetDefault(-1); AddAttr>( "grad_to_block_id", "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] " diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 1c05d0a4b9a..5674cd08b71 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -189,6 +189,8 @@ def __bootstrap__(): read_env_flags.append('rpc_prefetch_thread_num') read_env_flags.append('rpc_disable_reuse_port') + read_env_flags.append('worker_update_interval_secs') + # env for communicator read_env_flags.append('communicator_independent_recv_thread') read_env_flags.append('communicator_send_queue_size') diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index f84f42f05f9..e48d7fa76e9 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1193,6 +1193,7 @@ class DistributeTranspiler(object): attrs = { "optimize_blocks": optimize_blocks, "endpoint": endpoint, + "pserver_id": self.pserver_endpoints.index(endpoint), "Fanin": self.trainer_num, "sync_mode": self.sync_mode, "grad_to_block_id": grad_to_block_id, -- GitLab