未验证 提交 b5a41046 编写于 作者: T tangwei12 提交者: GitHub

Trainer heartbeat for async mode (#19600)

Heartbeat for distributed async training.
上级 76ba55e8
...@@ -12,6 +12,9 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O ...@@ -12,6 +12,9 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool) cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool)
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder) cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC) if(WITH_GRPC)
...@@ -23,7 +26,7 @@ if(WITH_GRPC) ...@@ -23,7 +26,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc collective_client.cc collective_server.cc
${GRPC_SRCS} ${GRPC_SRCS}
PROTO send_recv.proto PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder) DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
......
...@@ -392,6 +392,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, ...@@ -392,6 +392,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
s->Prepare(h, time_out); s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_trainer_id(trainer_id_);
req.set_varname(COMPLETE_MESSAGE); req.set_varname(COMPLETE_MESSAGE);
platform::RecordRPCEvent record_event(method); platform::RecordRPCEvent record_event(method);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include <chrono> // NOLINT
#include <ctime>
namespace paddle {
namespace operators {
namespace distributed {
DEFINE_int32(worker_update_interval_secs, 900,
" the longest time interval between the worker update variables");
inline int GetCurrentUS() {
// current date/time based on current system
time_t t = std::time(0);
int now = static_cast<int>(t);
return now;
}
void HeartBeatMonitor::Update(const int worker_id, std::string be_monitored_var,
WorkerStatus status) {
if (status == UNINITED) {
LOG(WARNING) << "HeartBeatMonitor receive UNINITED status can not be used "
"in Update, something error";
}
if (!is_chief_) {
return;
}
if ((be_monitored_var == be_monitored_var_ && status == RUNNING) ||
status == COMPLETED) {
auto timestamp = GetCurrentUS();
UnderMonitoredWorker& worker = worker_status_map_.at(worker_id);
if (worker.status != COMPLETED) {
worker.status = status;
}
worker.timestamp = timestamp;
return;
}
}
void HeartBeatMonitor::LostWorkerMonitor() {
VLOG(1) << "worker heartbeat monitor start at No.0 parameter server";
while (running_) {
for (int id = 0; id < workers_; ++id) {
auto& worker = worker_status_map_.at(id);
if (worker.status == UNINITED) {
VLOG(4) << "worker " << worker.id << " is under UNINITED";
continue;
}
if (worker.status == COMPLETED) {
VLOG(4) << "worker " << worker.id << " is under COMPLETED";
continue;
}
auto timestamp = GetCurrentUS();
VLOG(4) << "worker " << worker.id << " status is " << worker.status
<< " timestamp is " << worker.timestamp << " the interval is "
<< timestamp - worker.timestamp;
if (timestamp - worker.timestamp >= FLAGS_worker_update_interval_secs) {
PADDLE_THROW(
"the latest update of worker %d is %d secs ago, we doubt the "
"the worker is not alive and this may have a bad effect on the "
"fitting result, please check",
worker.id, FLAGS_worker_update_interval_secs);
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(30 * 1000));
}
VLOG(1) << "worker heartbeat monitor stopped, thread exit";
}
std::once_flag HeartBeatMonitor::init_flag_;
std::unique_ptr<HeartBeatMonitor> HeartBeatMonitor::monitor_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
enum WorkerStatus { UNINITED = 0, RUNNING, COMPLETED };
struct UnderMonitoredWorker {
int id;
WorkerStatus status;
int timestamp;
UnderMonitoredWorker() {}
explicit UnderMonitoredWorker(int worker_id) {
this->id = worker_id;
this->status = UNINITED;
this->timestamp = 0;
}
};
class HeartBeatMonitor {
public:
explicit HeartBeatMonitor(int workers, bool is_chief,
std::string be_monitored_var)
: workers_(workers),
is_chief_(is_chief),
be_monitored_var_(be_monitored_var),
running_(true) {
PADDLE_ENFORCE_GT(workers, 0, "trainers must have one or more");
for (auto worker_id = 0; worker_id < workers; worker_id++) {
UnderMonitoredWorker worker(worker_id);
worker_status_map_[worker_id] = std::move(worker);
}
// we define the No.0 pserver is the first parameter server
// only No.0 will check the heartbeat of all trainers
if (is_chief) {
monitor_thread_.reset(new std::thread(
std::bind(&HeartBeatMonitor::LostWorkerMonitor, this)));
}
}
~HeartBeatMonitor() {
running_ = false;
if (monitor_thread_) monitor_thread_->join();
}
static void Init(int workers, bool is_chief, std::string be_monitored_var) {
std::call_once(init_flag_, &HeartBeatMonitor::InitImpl, workers, is_chief,
be_monitored_var);
}
static HeartBeatMonitor* GetInstance() {
if (monitor_ == nullptr) {
PADDLE_THROW(
"HeartBeatMonitor is not inited, call "
"HeartBeatMonitor::Init first");
}
return monitor_.get();
}
void Stop() {
running_ = false;
if (!monitor_) {
VLOG(0) << "HeartBeatMonitor is not inited, do nothing";
} else {
if (monitor_thread_) {
monitor_thread_->join();
monitor_thread_.reset(nullptr);
}
}
}
void Update(const int worker_id, std::string be_monitored_var,
WorkerStatus status);
void LostWorkerMonitor();
private:
// Init is called by GetInstance.
static void InitImpl(int workers, bool is_chief,
std::string be_monitored_var) {
if (monitor_ == nullptr) {
monitor_.reset(new HeartBeatMonitor(workers, is_chief, be_monitored_var));
}
}
static std::once_flag init_flag_;
static std::unique_ptr<HeartBeatMonitor> monitor_;
int workers_;
bool is_chief_;
std::string be_monitored_var_;
std::unordered_map<int, UnderMonitoredWorker> worker_status_map_;
std::unique_ptr<std::thread> monitor_thread_{nullptr};
std::mutex mutex_;
bool running_ = false;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include <algorithm>
#include <thread> // NOLINT
#include "gtest/gtest.h"
namespace paddle {
namespace operators {
namespace distributed {
void run(HeartBeatMonitor* monitor) { monitor->LostWorkerMonitor(); }
TEST(HeartBeatMonitor, All) {
int trainers = 10;
int pserver_id = 0;
std::string var = "nce_w@GRAD.block0";
std::string var2 = "nce_w@GRAD.block2";
HeartBeatMonitor::Init(trainers, pserver_id == 0, var);
auto* monitor = HeartBeatMonitor::GetInstance();
std::vector<int> ids{1, 3, 5, 7};
for (auto& id : ids) {
monitor->Update(id, var, RUNNING);
}
monitor->Update(9, var2, RUNNING);
monitor->Update(2, var, COMPLETED);
std::thread t(run, monitor);
t.detach();
std::this_thread::sleep_for(std::chrono::milliseconds(45 * 1000));
monitor->Stop();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -22,12 +22,14 @@ ...@@ -22,12 +22,14 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h" #include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
...@@ -51,6 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -51,6 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
rpc_server_->IncreaseBatchBarrier(kRequestSend); rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) { } else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message"; VLOG(3) << "sync: recv complete message";
HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
rpc_server_->Complete(); rpc_server_->Complete();
} else { } else {
// Async // Async
...@@ -61,6 +64,7 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -61,6 +64,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
"async mode should not recv BATCH_BARRIER_MESSAGE or " "async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"); "COMPLETE_MESSAGE");
} }
HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
std::string run_varname = varname; std::string run_varname = varname;
...@@ -82,6 +86,7 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -82,6 +86,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
} }
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(), executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope); scope);
return true; return true;
} else { // sync } else { // sync
rpc_server_->WaitCond(kRequestSend); rpc_server_->WaitCond(kRequestSend);
......
...@@ -25,6 +25,7 @@ limitations under the License. */ ...@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
...@@ -116,6 +117,9 @@ void StartServer(const std::string& rpc_name) { ...@@ -116,6 +117,9 @@ void StartServer(const std::string& rpc_name) {
g_req_handler->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
distributed::HeartBeatMonitor::Init(2, true, "w@grad");
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
......
...@@ -25,6 +25,7 @@ limitations under the License. */ ...@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
...@@ -338,14 +339,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -338,14 +339,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
bool sync_mode = Attr<bool>("sync_mode"); bool sync_mode = Attr<bool>("sync_mode");
bool dc_sgd = Attr<bool>("dc_asgd"); bool dc_sgd = Attr<bool>("dc_asgd");
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
auto pserver_id = Attr<int>("pserver_id");
auto inputs = Inputs("X"); auto inputs = Inputs("X");
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_block_id = Attr<int>(kCheckpointBlockId); int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode
<< ", end_point:" << endpoint << ", fan_in:" << fan_in << ", end_point:" << endpoint
<< ", checkpoint_block_id: " << checkpoint_block_id; << ", checkpoint_block_id: " << checkpoint_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
...@@ -466,6 +468,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -466,6 +468,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
} else { } else {
distributed::AsyncSparseParamUpdateRecorder::Init( distributed::AsyncSparseParamUpdateRecorder::Init(
fan_in, sparse_grad_name_to_param_name); fan_in, sparse_grad_name_to_param_name);
VLOG(2) << "RunAsyncLoop";
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
if (grad_to_block_id_str.size() == 0) {
VLOG(0) << "there are no gradients on this parameter server";
} else {
std::vector<std::string> pieces;
split(grad_to_block_id_str[0], ':', &pieces);
distributed::HeartBeatMonitor::Init(fan_in, pserver_id == 0, pieces[0]);
}
RunAsyncLoop(&executor, program, &recv_scope); RunAsyncLoop(&executor, program, &recv_scope);
} }
} }
...@@ -482,6 +496,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -482,6 +496,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"IP address to listen on.") "IP address to listen on.")
.SetDefault("127.0.0.1:6164") .SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<int>("pserver_id",
"(int, default -1), the parameter server index id")
.SetDefault(-1);
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"grad_to_block_id", "grad_to_block_id",
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] " "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
......
...@@ -189,6 +189,8 @@ def __bootstrap__(): ...@@ -189,6 +189,8 @@ def __bootstrap__():
read_env_flags.append('rpc_prefetch_thread_num') read_env_flags.append('rpc_prefetch_thread_num')
read_env_flags.append('rpc_disable_reuse_port') read_env_flags.append('rpc_disable_reuse_port')
read_env_flags.append('worker_update_interval_secs')
# env for communicator # env for communicator
read_env_flags.append('communicator_independent_recv_thread') read_env_flags.append('communicator_independent_recv_thread')
read_env_flags.append('communicator_send_queue_size') read_env_flags.append('communicator_send_queue_size')
......
...@@ -1193,6 +1193,7 @@ class DistributeTranspiler(object): ...@@ -1193,6 +1193,7 @@ class DistributeTranspiler(object):
attrs = { attrs = {
"optimize_blocks": optimize_blocks, "optimize_blocks": optimize_blocks,
"endpoint": endpoint, "endpoint": endpoint,
"pserver_id": self.pserver_endpoints.index(endpoint),
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id, "grad_to_block_id": grad_to_block_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册