未验证 提交 be6a315f 编写于 作者: T tangwei12 提交者: GitHub

Fix/sync barrier (#25016)

* fix sync barrier with barrier monitor, test=develop
上级 8db66fc3
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -24,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
......@@ -95,22 +93,39 @@ class CGenNCCLIdOp : public framework::OperatorBase {
new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
distributed::RequestNotifyHandler notify_h(
distributed::DistributedMode::kSync, -1);
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_service->RegisterRPC(distributed::kRequestNotify, &notify_h);
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetRPCServer(rpc_service.get());
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
notify_h.SetRPCServer(rpc_service.get());
notify_h.SetScope(scope);
notify_h.SetDevCtx(&dev_ctx);
notify_h.SetProgram(&empty_program);
notify_h.SetExecutor(&executor);
distributed::BarrierMonitor::Init(1);
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service->WaitBarrier(distributed::kRequestSend);
barrier->WaitServerWeakup();
barrier->ServerWeakup();
VLOG(3) << "got nccl id and stop server...";
barrier->Stop();
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
......@@ -123,7 +138,6 @@ class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
CGenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
......
......@@ -15,6 +15,8 @@ cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_r
cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
cc_library(barrier_monitor SRCS barrier_monitor.cc DEPS enforce simple_threadpool trainer_desc_proto)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC)
......@@ -26,7 +28,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${GRPC_SRCS}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor barrier_monitor)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
bool BarrierMonitor::IncreaseBarrier(const int worker_id,
const std::string &barrier) {
release_ = false;
if (barrier == BATCH_BARRIER_MESSAGE) {
VLOG(4) << "BarrierMonitor send queue recv trainer: " << worker_id;
send_barrier_queue->Push(worker_id);
} else if (barrier == FETCH_BARRIER_MESSAGE) {
VLOG(4) << "BarrierMonitor recv queue recv trainer: " << worker_id;
recv_barrier_queue->Push(worker_id);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"unknown Message status %s, only "
"BATCH_BARRIER_MESSAGE/FETCH_BARRIER_MESSAGE",
barrier));
}
return Wait();
}
void BarrierMonitor::DecreaseWorker() {
std::unique_lock<std::mutex> lck(mutex_);
workers_--;
VLOG(1) << "decrement worker num to " << workers_;
}
void BarrierMonitor::Reset(int workers, BarrierType type) {
std::unique_lock<std::mutex> lk(server_mutex_);
workers_ = workers;
barrier_type = type;
send_barrier_queue->Clear();
recv_barrier_queue->Clear();
VLOG(2) << "reset monitor workers: " << workers_ << " type: " << barrier_type;
}
void BarrierMonitor::Monitor() {
while (!IsReady() && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
VLOG(3) << "sync at first time, wait all trainer ready";
}
while (running_) {
int timer = 0;
if (IsReady()) {
Swap(true);
} else {
VLOG(4) << "running timer: " << timer << " barrier: " << barrier_type
<< " sendQ:" << send_barrier_queue->Size()
<< " recvQ: " << recv_barrier_queue->Size();
timer++;
if (max_wait_ms == -1 || timer < max_wait_ms) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
} else {
VLOG(1) << "time out of " << max_wait_ms
<< ", need barreir: " << barrier_type << " retry";
Swap(false);
}
}
}
}
bool BarrierMonitor::IsReady() {
if (barrier_type == BarrierType::kSendBarrier) {
return static_cast<int>(send_barrier_queue->Size()) == workers_;
} else {
return static_cast<int>(recv_barrier_queue->Size()) == workers_;
}
}
void BarrierMonitor::Swap(bool is_valid) {
std::unique_lock<std::mutex> lck(mutex_);
valid_ = is_valid;
release_ = true;
if (barrier_type == BarrierType::kSendBarrier) {
barrier_type = BarrierType::kRecvBarrier;
send_barrier_queue->Clear();
VLOG(4) << "barrier monitor server clean up queue and barrier";
ServerWeakup();
VLOG(4) << "barrier monitor server weak up sync to do";
WaitServerWeakup();
VLOG(4) << "barrier monitor server weak up sync done";
} else {
barrier_type = BarrierType::kSendBarrier;
recv_barrier_queue->Clear();
VLOG(4) << "barrier monitor server switch to send barrier";
}
worker_cv_.notify_all();
}
void BarrierMonitor::Stop() {
valid_ = true;
release_ = true;
running_ = false;
barrier_type = BarrierType::kRecvBarrier;
send_barrier_queue->Clear();
recv_barrier_queue->Clear();
worker_cv_.notify_all();
server_cv_.notify_all();
if (monitor_thread_) monitor_thread_->join();
monitor_thread_ = nullptr;
}
bool BarrierMonitor::Wait() {
std::unique_lock<std::mutex> lk(mutex_);
worker_cv_.wait(lk, [this] { return (release_); });
return valid_;
}
void BarrierMonitor::WaitServerWeakup() {
std::unique_lock<std::mutex> lk(server_mutex_);
server_cv_.wait(lk);
}
void BarrierMonitor::ServerWeakup() { server_cv_.notify_all(); }
std::once_flag BarrierMonitor::init_flag_;
std::unique_ptr<BarrierMonitor> BarrierMonitor::monitor_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <chrono> // NOLINT
#include <deque>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
enum BarrierType { kSendBarrier, kRecvBarrier };
constexpr int64_t kMaxWaitMS = 120000;
template <typename T>
class BlockingQueueForBarrier {
public:
explicit BlockingQueueForBarrier(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
}
bool Push(const T &elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.push_back(elem);
}
worker_cv_.notify_one();
return true;
}
bool Push(T &&elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.emplace_back(std::move(elem));
}
worker_cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
worker_cv_.notify_one();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
void Clear() {
std::lock_guard<std::mutex> lock(mutex_);
std::deque<T>().swap(queue_);
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable worker_cv_;
};
class BarrierMonitor {
public:
explicit BarrierMonitor(int workers)
: BarrierMonitor(workers, BarrierType::kRecvBarrier, kMaxWaitMS) {}
explicit BarrierMonitor(int workers, BarrierType type, int64_t max_wait_times)
: workers_(workers), barrier_type(type), max_wait_ms(max_wait_times) {
PADDLE_ENFORCE_GT(workers, 0, platform::errors::InvalidArgument(
"trainers must have one or more"));
send_barrier_queue =
std::make_shared<BlockingQueueForBarrier<int>>(workers);
recv_barrier_queue =
std::make_shared<BlockingQueueForBarrier<int>>(workers);
running_ = true;
monitor_thread_.reset(
new std::thread(std::bind(&BarrierMonitor::Monitor, this)));
}
static BarrierMonitor *Init(int workers) {
InitImpl(workers);
return GetInstance();
}
static BarrierMonitor *GetInstance() { return monitor_.get(); }
bool IncreaseBarrier(const int worker_id, const std::string &barrier);
void DecreaseWorker();
int GetWorkerNum() { return workers_; }
void Monitor();
void Swap(bool is_valid);
void Stop();
bool IsReady();
bool Wait();
void WaitServerWeakup();
void ServerWeakup();
void WorkerWeakup();
void Reset(int workers, BarrierType type);
private:
// Init is called by GetInstance.
static void InitImpl(int workers) {
monitor_.reset(new BarrierMonitor(workers));
}
static std::once_flag init_flag_;
static std::unique_ptr<BarrierMonitor> monitor_;
int workers_;
bool running_ = false;
bool valid_ = false;
bool release_ = false;
std::condition_variable worker_cv_;
std::condition_variable server_cv_;
std::mutex server_mutex_;
std::mutex mutex_;
BarrierType barrier_type;
int64_t max_wait_ms;
std::unique_ptr<std::thread> monitor_thread_{nullptr};
std::shared_ptr<BlockingQueueForBarrier<int>> send_barrier_queue;
std::shared_ptr<BlockingQueueForBarrier<int>> recv_barrier_queue;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -306,52 +306,19 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = kBatchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
platform::CPUDeviceContext ctx;
auto* scope = new framework::Scope();
auto h = AsyncDistributeNotify(ep, ctx, *scope, BATCH_BARRIER_MESSAGE);
delete scope;
return h;
}
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
const std::string method = kFetchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
platform::CPUDeviceContext ctx;
auto* scope = new framework::Scope();
auto h = AsyncDistributeNotify(ep, ctx, *scope, FETCH_BARRIER_MESSAGE);
delete scope;
return h;
}
......@@ -384,27 +351,10 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = kSendCompleteRPC;
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_trainer_id(trainer_id_);
req.set_varname(COMPLETE_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
platform::CPUDeviceContext ctx;
auto* scope = new framework::Scope();
auto h = AsyncDistributeNotify(ep, ctx, *scope, COMPLETE_MESSAGE);
delete scope;
return h;
}
......@@ -454,10 +404,21 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer buf;
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
if (var_name_val == BATCH_BARRIER_MESSAGE ||
var_name_val == FETCH_BARRIER_MESSAGE ||
var_name_val == COMPLETE_MESSAGE) {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(var_name_val);
req.set_trainer_id(trainer_id_);
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
} else {
auto* var = p_scope->FindVar(var_name_val);
SerializeToByteBuffer(var_name_val, var, *p_ctx, &buf, "", trainer_id_);
}
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
......@@ -467,7 +428,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req,
s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", buf,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
......
......@@ -28,6 +28,7 @@
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
namespace paddle {
......@@ -38,161 +39,130 @@ namespace distributed {
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestSendHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestSendHandler:" << varname;
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
if (invar == nullptr) {
PADDLE_THROW(platform::errors::NotFound(
"sync: Can not find server side var: %s", varname));
return false;
}
if (HeartBeatMonitor::GetInstance() != nullptr) {
HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
}
if (distributed_mode_ == DistributedMode::kSync) {
return true;
}
rpc_server_->Complete();
} else {
// Async
if (distributed_mode_ != DistributedMode::kSync) {
VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE) {
PADDLE_THROW(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE");
}
HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
std::string run_varname = varname;
std::string run_varname = varname;
string::Piece part_piece("@PIECE");
string::Piece var_name_piece = string::Piece(varname);
string::Piece part_piece("@PIECE");
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, part_piece)) {
auto varname_splits = paddle::string::Split(varname, '@');
run_varname = varname_splits[0];
scope->Rename(varname, run_varname);
}
if (string::Contains(var_name_piece, part_piece)) {
auto varname_splits = paddle::string::Split(varname, '@');
PADDLE_ENFORCE_EQ(varname_splits.size(), 3);
run_varname = varname_splits[0];
scope->Rename(varname, run_varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto &grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
grad_slr.rows());
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto& grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
grad_slr.rows());
}
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope);
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope);
return true;
} else { // sync
rpc_server_->WaitCond(kRequestSend);
VLOG(3) << "sync: processing received var: " << varname;
PADDLE_ENFORCE_NOT_NULL(
invar, platform::errors::NotFound(
"sync: Can not find server side var %s.", varname));
}
}
return true;
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestGetHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name;
if (distributed_mode_ == DistributedMode::kSync) {
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
} else {
rpc_server_->WaitCond(kRequestGet);
*outvar = scope_->FindVar(varname);
}
*outvar = scope_->FindVar(varname);
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
if (enable_dc_asgd_) {
// NOTE: the format is determined by distribute_transpiler.py
std::string param_bak_name =
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
auto var = scope_->FindVar(varname);
auto t_orig = var->Get<framework::LoDTensor>();
auto param_bak = scope_->Var(param_bak_name);
auto t = param_bak->GetMutable<framework::LoDTensor>();
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
VLOG(1) << "Table name empty? " << table_name.empty();
if (distributed_mode_ == DistributedMode::kGeo) {
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims();
*outvar = scope->Var();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto* data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
if (enable_dc_asgd_) {
// NOTE: the format is determined by distribute_transpiler.py
std::string param_bak_name =
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
auto var = scope_->FindVar(varname);
auto t_orig = var->Get<framework::LoDTensor>();
auto param_bak = scope_->Var(param_bak_name);
auto t = param_bak->GetMutable<framework::LoDTensor>();
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : updated_rows) {
sstream << row_id << ", ";
}
} else {
*outvar = scope_->FindVar(varname);
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto &origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto *origin_tensor_data = origin_tensor.data<float>();
auto &dims = origin_tensor.dims();
*outvar = scope->Var();
auto *out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto *data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0],
platform::errors::OutOfRange(
"expected >= 0 and < %ld, but got %ld.", dims[0],
updated_rows[i]));
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
}
} else {
*outvar = scope_->FindVar(varname);
}
}
return true;
}
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
......@@ -207,18 +177,19 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
*outvar = scope_->FindVar(var_name_piece.ToString());
return true;
} else {
PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE);
PADDLE_THROW(platform::errors::InvalidArgument(
"GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE));
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestPrefetchHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
if (table_name.empty()) {
......@@ -236,19 +207,20 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true;
}
bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestCheckpointHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
const std::string &out_var_name,
const std::string &table_name) {
PADDLE_ENFORCE_NE(
checkpoint_notify_id, -1,
platform::errors::Unavailable(
"when checkpoint_notify_id = -1, there should be no RPC invoke."));
// TODO(tangwei12): find out why scope will be error.
auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
auto *lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
......@@ -257,33 +229,56 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
return true;
}
bool RequestNotifyHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestNotifyHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestNotifyHandler: " << varname;
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;
string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
string::Piece batch_piece(BATCH_BARRIER_MESSAGE);
string::Piece fetch_piece(FETCH_BARRIER_MESSAGE);
string::Piece complete_piece(COMPLETE_MESSAGE);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, decay_piece)) {
if (string::Contains(var_name_piece, batch_piece)) {
return BarrierMonitor::GetInstance()->IncreaseBarrier(
trainer_id, BATCH_BARRIER_MESSAGE);
} else if (string::Contains(var_name_piece, fetch_piece)) {
return BarrierMonitor::GetInstance()->IncreaseBarrier(
trainer_id, FETCH_BARRIER_MESSAGE);
} else if (string::Contains(var_name_piece, complete_piece)) {
if (HeartBeatMonitor::GetInstance() != nullptr) {
HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
}
rpc_server_->Complete();
BarrierMonitor::GetInstance()->DecreaseWorker();
return true;
} else if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update";
PADDLE_ENFORCE_NE(
lr_decay_block_id, -1,
"when lr_decay_block_id = -1, there should be no RPC invoke.");
auto* origin_var = scope_->FindVar(varname);
platform::errors::InvalidArgument(
"when lr_decay_block_id = -1, there should be no RPC invoke."));
auto *origin_var = scope_->FindVar(varname);
auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
auto* send_var = scope->FindVar(varname);
auto *send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>();
int64_t* origin_value =
int64_t *origin_value =
origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
int64_t* send_value =
int64_t *send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
origin_value[0] += send_value[0];
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
return true;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unkown varname %s with RequestNotifyHandler", varname));
}
return true;
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -24,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
......@@ -119,6 +117,7 @@ void StartServer(const std::string& rpc_name) {
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
distributed::HeartBeatMonitor::Init(2, true, "w@grad");
distributed::BarrierMonitor::Init(2);
g_req_handler->SetRPCServer(g_rpc_service.get());
......@@ -164,6 +163,9 @@ TEST(PREFETCH, CPU) {
}
}
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Stop();
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
......@@ -174,20 +176,24 @@ TEST(PREFETCH, CPU) {
TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_req_handler.reset(new distributed::RequestNotifyHandler(
distributed::DistributedMode::kSync, -1));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE(client != nullptr);
std::thread server_thread(StartServer, distributed::kRequestSend);
std::thread server_thread(StartServer, distributed::kRequestNotify);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
client->AsyncSendComplete(ep);
client->Wait();
EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
auto* barrier = distributed::BarrierMonitor::GetInstance();
EXPECT_EQ(barrier->GetWorkerNum(), 1);
barrier->Stop();
g_rpc_service->ShutDown();
server_thread.join();
......
......@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -21,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
......@@ -30,16 +28,16 @@ namespace operators {
class GenNCCLIdOp : public framework::OperatorBase {
public:
GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
GenNCCLIdOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
// put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace());
auto &dev_ctx = *pool.Get(platform::CPUPlace());
int trainer_id = Attr<int>("trainer_id");
std::vector<std::string> trainers =
......@@ -55,7 +53,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::string endpoint = trainers[trainer_id];
framework::Scope& local_scope = scope.NewScope();
framework::Scope &local_scope = scope.NewScope();
int nccl_comm_num = Attr<int>("nccl_comm_num");
int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
......@@ -171,10 +169,10 @@ class GenNCCLIdOp : public framework::OperatorBase {
}
private:
void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx,
const std::string& nccl_id_name,
const std::vector<std::string>& endpoint_list) const {
void GenerateAndSend(framework::Scope *scope,
const platform::DeviceContext &dev_ctx,
const std::string &nccl_id_name,
const std::vector<std::string> &endpoint_list) const {
auto var = scope->FindVar(nccl_id_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
......@@ -182,76 +180,96 @@ class GenNCCLIdOp : public framework::OperatorBase {
auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(id));
distributed::RPCClient* client =
distributed::RPCClient *client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (auto& ep : endpoint_list) {
for (auto &ep : endpoint_list) {
VLOG(3) << "sending nccl_id_var:" << nccl_id_name << " to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, nccl_id_name);
}
client->Wait();
for (auto& ep : endpoint_list) {
for (auto &ep : endpoint_list) {
client->AsyncSendBatchBarrier(ep);
}
client->Wait();
VLOG(3) << "sending completed...";
}
void GetIdByServer(const std::string& endpoint, framework::Scope* scope,
const platform::DeviceContext& dev_ctx, int nccl_comm_num,
void GetIdByServer(const std::string &endpoint, framework::Scope *scope,
const platform::DeviceContext &dev_ctx, int nccl_comm_num,
bool use_hierarchical_allreduce, int trainer_id,
int inter_trainer_id, int exter_trainer_id) const {
// std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
distributed::RequestNotifyHandler notify_h(
distributed::DistributedMode::kSync, -1);
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
rpc_service->RegisterRPC(distributed::kRequestNotify, &notify_h);
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetRPCServer(rpc_service.get());
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
notify_h.SetRPCServer(rpc_service.get());
notify_h.SetScope(scope);
notify_h.SetDevCtx(&dev_ctx);
notify_h.SetProgram(&empty_program);
notify_h.SetExecutor(&executor);
distributed::BarrierMonitor::Init(1);
auto *barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend);
barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
VLOG(3) << "trainer_id:" << trainer_id
<< " start getting nccl id from trainer 0, nccl_comm_no:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
if (use_hierarchical_allreduce) {
if (inter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend);
barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
VLOG(3) << "trainer_id:" << trainer_id
<< ", inter_trainer_id:" << inter_trainer_id
<< " start getting nccl id from inter_trainer:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
}
if (exter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend);
barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
VLOG(3)
<< "trainer_id:" << trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< " start getting nccl id from exter_trainer 0, nccl_comm_no:"
<< i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
}
}
......@@ -260,6 +278,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
<< ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< " got nccl id and stop server...";
barrier->Stop();
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
......@@ -272,7 +291,6 @@ class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
GenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -25,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
......@@ -38,10 +36,13 @@ DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");
namespace paddle {
namespace operators {
volatile sig_atomic_t gSignalStatus;
void RunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer();
VLOG(4) << "RunServer thread end";
}
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
......@@ -126,6 +127,7 @@ void ListenAndServOp::RunSyncLoop(
for (size_t i = 1; i < program->Size(); ++i) {
optimize_blocks_list.push_back(i);
}
auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list);
// Insert placeholder for block0 which holds current op itself,
// NOTE the first block in `optimize_prepared` should never be ran.
......@@ -135,21 +137,15 @@ void ListenAndServOp::RunSyncLoop(
// Trainers will get all parameters from pserver in the
// startup program, so we will wait RequestGet first
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter();
auto *barrier = distributed::BarrierMonitor::GetInstance();
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
VLOG(3) << "wait all clients to send gradient";
rpc_service_->SetCond(distributed::kRequestSend);
VLOG(3) << "wait all clients to send send_barrier";
rpc_service_->WaitBarrier(distributed::kRequestSend);
barrier->WaitServerWeakup();
if (rpc_service_->IsExit()) {
if (gSignalStatus != 0) {
LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(distributed::kRequestGet);
break;
}
......@@ -180,12 +176,8 @@ void ListenAndServOp::RunSyncLoop(
VLOG(3) << "ResetReceivedVars";
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
VLOG(3) << "wait all clients to get parameters back";
rpc_service_->SetCond(distributed::kRequestGet);
VLOG(3) << "wait all clients to send fetch_barrier";
rpc_service_->WaitBarrier(distributed::kRequestGet);
VLOG(3) << "ResetBarrierCounter";
rpc_service_->ResetBarrierCounter();
barrier->ServerWeakup();
VLOG(3) << "kRecvBarrier to push params to trainers";
} // while(true)
}
......@@ -281,7 +273,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
while (true) {
if (rpc_service_->IsExit()) {
if (gSignalStatus != 0) {
VLOG(4) << "get exit!rpc_processor break!";
break;
}
......@@ -391,7 +383,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), rpc_send_thread_num);
request_notify_handler_.get(), fan_in * 2);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......@@ -440,6 +432,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared_ctx;
for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
auto block_id = prefetch_block_id_list[i];
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
......@@ -448,8 +441,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// parse attr of kSparseGradToParam sparse_grad_name -> param_name
std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name;
auto sparse_grad_name_to_param_name_str =
Attr<std::vector<std::string>>(kSparseGradToParam);
for (const auto &sparse_grad_name_and_param_name :
sparse_grad_name_to_param_name_str) {
std::vector<std::string> pieces;
......@@ -477,17 +472,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit);
distributed::BarrierMonitor::Init(fan_in);
if (distributed_mode == distributed::DistributedMode::kSync) {
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
CacheVarsType(inputs, recv_scope);
// Write to a file of server selected port for python use.
SavePort();
CacheVarsType(inputs, recv_scope);
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id);
} else {
......@@ -574,9 +570,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
void SignalHandler::StopAndExit(int signal_num) {
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
remove(file_path.c_str());
exit(0);
distributed::BarrierMonitor::GetInstance()->Stop();
gSignalStatus = signal_num;
}
} // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -20,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
......@@ -42,6 +40,7 @@ namespace string = paddle::string;
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
std::unique_ptr<distributed::RequestNotifyHandler> g_notify_handler;
void StartServer() {
f::Scope scope;
......@@ -52,21 +51,35 @@ void StartServer() {
f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace());
g_req_handler->SetScope(&scope);
g_req_handler->SetDevCtx(&dev_ctx);
g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor);
g_req_handler->SetRPCServer(g_rpc_service.get());
g_notify_handler.SetRPCServer(rpc_service.get());
g_notify_handler.SetScope(scope);
g_notify_handler.SetDevCtx(&dev_ctx);
g_notify_handler.SetProgram(&empty_program);
g_notify_handler.SetExecutor(&executor);
g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
g_rpc_service->RegisterRPC(distributed::RequestNotifyHandler,
g_notify_handler.get());
distributed::BarrierMonitor::Init(1);
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(distributed::kRequestSend);
g_rpc_service->WaitBarrier(distributed::kRequestSend);
barrier->WaitServerWeakup();
barrier->ServerWeakup();
LOG(INFO) << "got nccl id and stop server...";
barrier->Stop();
g_rpc_service->ShutDown();
server_thread.join();
}
......@@ -74,6 +87,10 @@ void StartServer() {
TEST(SendNcclId, RPCServer) {
g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_notify_handler.reset(new distributed::RequestNotifyHandler(
distributed::DistributedMode::kSync, -1));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
......@@ -104,4 +121,5 @@ TEST(SendNcclId, RPCServer) {
server_thread.join();
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
g_notify_handler.reset(nullptr);
}
......@@ -94,7 +94,7 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase):
current_id=0,
role=role_maker.Role.WORKER
if training_role == "TRAINER" else role_maker.Role.SERVER,
worker_num=2,
worker_num=1,
server_endpoints=["127.0.0.1:6002"])
if training_role == "TRAINER":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册