diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index e2b09be5a9dfff0111ab80d89bdd76b99517738f..0621fa938c9f854ef1c906620f3e474c375efb8a 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -24,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" @@ -95,22 +93,39 @@ class CGenNCCLIdOp : public framework::OperatorBase { new RPCSERVER_T(endpoint, 1)); rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); - rpc_h.SetRPCServer(rpc_service.get()); + distributed::RequestNotifyHandler notify_h( + distributed::DistributedMode::kSync, -1); + + rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); + rpc_service->RegisterRPC(distributed::kRequestNotify, ¬ify_h); framework::ProgramDesc empty_program; framework::Executor executor(dev_ctx.GetPlace()); + + rpc_h.SetRPCServer(rpc_service.get()); rpc_h.SetScope(scope); rpc_h.SetDevCtx(&dev_ctx); rpc_h.SetProgram(&empty_program); rpc_h.SetExecutor(&executor); + notify_h.SetRPCServer(rpc_service.get()); + notify_h.SetScope(scope); + notify_h.SetDevCtx(&dev_ctx); + notify_h.SetProgram(&empty_program); + notify_h.SetExecutor(&executor); + + distributed::BarrierMonitor::Init(1); + auto* barrier = distributed::BarrierMonitor::GetInstance(); + barrier->Reset(1, distributed::BarrierType::kSendBarrier); + std::thread server_thread( std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); - rpc_service->SetCond(distributed::kRequestSend); VLOG(3) << "start getting nccl id from trainer 0..."; - rpc_service->WaitBarrier(distributed::kRequestSend); + barrier->WaitServerWeakup(); + barrier->ServerWeakup(); VLOG(3) << "got nccl id and stop server..."; + barrier->Stop(); rpc_service->ShutDown(); VLOG(3) << "rpc server stopped"; server_thread.join(); @@ -123,7 +138,6 @@ class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces."); AddComment(R"DOC( CGenNCCLId operator - For trainer 0: generate a new UniqueId and send it to all the other trainers. For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. )DOC"); diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 5aa91733fe3ed1bfc51b47b331488ce2211be2fb..76b15733c51a89af716b92ff13cdef20349c9015 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -15,6 +15,8 @@ cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_r cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool) cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor) +cc_library(barrier_monitor SRCS barrier_monitor.cc DEPS enforce simple_threadpool trainer_desc_proto) + # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") if(WITH_GRPC) @@ -26,7 +28,7 @@ if(WITH_GRPC) collective_client.cc collective_server.cc ${GRPC_SRCS} PROTO send_recv.proto - DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor) + DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor barrier_monitor) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) diff --git a/paddle/fluid/operators/distributed/barrier_monitor.cc b/paddle/fluid/operators/distributed/barrier_monitor.cc new file mode 100644 index 0000000000000000000000000000000000000000..f6d82f5d8c3daea9b629c1937bfcbc5159cda461 --- /dev/null +++ b/paddle/fluid/operators/distributed/barrier_monitor.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/distributed/barrier_monitor.h" +#include + +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include // NOLINT + +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace distributed { +bool BarrierMonitor::IncreaseBarrier(const int worker_id, + const std::string &barrier) { + release_ = false; + + if (barrier == BATCH_BARRIER_MESSAGE) { + VLOG(4) << "BarrierMonitor send queue recv trainer: " << worker_id; + send_barrier_queue->Push(worker_id); + } else if (barrier == FETCH_BARRIER_MESSAGE) { + VLOG(4) << "BarrierMonitor recv queue recv trainer: " << worker_id; + recv_barrier_queue->Push(worker_id); + } else { + PADDLE_THROW(platform::errors::Unavailable( + "unknown Message status %s, only " + "BATCH_BARRIER_MESSAGE/FETCH_BARRIER_MESSAGE", + barrier)); + } + return Wait(); +} + +void BarrierMonitor::DecreaseWorker() { + std::unique_lock lck(mutex_); + workers_--; + VLOG(1) << "decrement worker num to " << workers_; +} + +void BarrierMonitor::Reset(int workers, BarrierType type) { + std::unique_lock lk(server_mutex_); + + workers_ = workers; + barrier_type = type; + + send_barrier_queue->Clear(); + recv_barrier_queue->Clear(); + VLOG(2) << "reset monitor workers: " << workers_ << " type: " << barrier_type; +} + +void BarrierMonitor::Monitor() { + while (!IsReady() && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + VLOG(3) << "sync at first time, wait all trainer ready"; + } + + while (running_) { + int timer = 0; + + if (IsReady()) { + Swap(true); + } else { + VLOG(4) << "running timer: " << timer << " barrier: " << barrier_type + << " sendQ:" << send_barrier_queue->Size() + << " recvQ: " << recv_barrier_queue->Size(); + + timer++; + if (max_wait_ms == -1 || timer < max_wait_ms) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } else { + VLOG(1) << "time out of " << max_wait_ms + << ", need barreir: " << barrier_type << " retry"; + Swap(false); + } + } + } +} + +bool BarrierMonitor::IsReady() { + if (barrier_type == BarrierType::kSendBarrier) { + return static_cast(send_barrier_queue->Size()) == workers_; + } else { + return static_cast(recv_barrier_queue->Size()) == workers_; + } +} + +void BarrierMonitor::Swap(bool is_valid) { + std::unique_lock lck(mutex_); + + valid_ = is_valid; + release_ = true; + + if (barrier_type == BarrierType::kSendBarrier) { + barrier_type = BarrierType::kRecvBarrier; + send_barrier_queue->Clear(); + VLOG(4) << "barrier monitor server clean up queue and barrier"; + ServerWeakup(); + VLOG(4) << "barrier monitor server weak up sync to do"; + WaitServerWeakup(); + VLOG(4) << "barrier monitor server weak up sync done"; + + } else { + barrier_type = BarrierType::kSendBarrier; + recv_barrier_queue->Clear(); + VLOG(4) << "barrier monitor server switch to send barrier"; + } + + worker_cv_.notify_all(); +} + +void BarrierMonitor::Stop() { + valid_ = true; + release_ = true; + running_ = false; + + barrier_type = BarrierType::kRecvBarrier; + send_barrier_queue->Clear(); + recv_barrier_queue->Clear(); + + worker_cv_.notify_all(); + server_cv_.notify_all(); + + if (monitor_thread_) monitor_thread_->join(); + monitor_thread_ = nullptr; +} + +bool BarrierMonitor::Wait() { + std::unique_lock lk(mutex_); + worker_cv_.wait(lk, [this] { return (release_); }); + return valid_; +} + +void BarrierMonitor::WaitServerWeakup() { + std::unique_lock lk(server_mutex_); + server_cv_.wait(lk); +} + +void BarrierMonitor::ServerWeakup() { server_cv_.notify_all(); } + +std::once_flag BarrierMonitor::init_flag_; +std::unique_ptr BarrierMonitor::monitor_(nullptr); + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/barrier_monitor.h b/paddle/fluid/operators/distributed/barrier_monitor.h new file mode 100644 index 0000000000000000000000000000000000000000..f9556d7720f7a7ebcadcc1f86ad6051786777041 --- /dev/null +++ b/paddle/fluid/operators/distributed/barrier_monitor.h @@ -0,0 +1,186 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include // NOLINT + +#include + +#include "paddle/fluid/operators/distributed/rpc_server.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace distributed { + +enum BarrierType { kSendBarrier, kRecvBarrier }; + +constexpr int64_t kMaxWaitMS = 120000; + +template +class BlockingQueueForBarrier { + public: + explicit BlockingQueueForBarrier(size_t capacity) : capacity_(capacity) { + PADDLE_ENFORCE_GT(capacity_, 0, + platform::errors::InvalidArgument( + "The capacity must be greater than 0.")); + } + + bool Push(const T &elem) { + { + std::unique_lock lock(mutex_); + worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + queue_.push_back(elem); + } + worker_cv_.notify_one(); + return true; + } + + bool Push(T &&elem) { + { + std::unique_lock lock(mutex_); + worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + queue_.emplace_back(std::move(elem)); + } + worker_cv_.notify_one(); + return true; + } + + T Pop() { + std::unique_lock lock(mutex_); + worker_cv_.wait(lock, [=] { return !queue_.empty(); }); + T rc(std::move(queue_.front())); + queue_.pop_front(); + worker_cv_.notify_one(); + return rc; + } + + size_t Cap() const { + std::lock_guard lock(mutex_); + return capacity_; + } + + size_t Size() const { + std::lock_guard lock(mutex_); + return queue_.size(); + } + + void Clear() { + std::lock_guard lock(mutex_); + std::deque().swap(queue_); + } + + private: + const size_t capacity_; + std::deque queue_; + + mutable std::mutex mutex_; + std::condition_variable worker_cv_; +}; + +class BarrierMonitor { + public: + explicit BarrierMonitor(int workers) + : BarrierMonitor(workers, BarrierType::kRecvBarrier, kMaxWaitMS) {} + + explicit BarrierMonitor(int workers, BarrierType type, int64_t max_wait_times) + : workers_(workers), barrier_type(type), max_wait_ms(max_wait_times) { + PADDLE_ENFORCE_GT(workers, 0, platform::errors::InvalidArgument( + "trainers must have one or more")); + + send_barrier_queue = + std::make_shared>(workers); + recv_barrier_queue = + std::make_shared>(workers); + + running_ = true; + monitor_thread_.reset( + new std::thread(std::bind(&BarrierMonitor::Monitor, this))); + } + + static BarrierMonitor *Init(int workers) { + InitImpl(workers); + return GetInstance(); + } + + static BarrierMonitor *GetInstance() { return monitor_.get(); } + + bool IncreaseBarrier(const int worker_id, const std::string &barrier); + + void DecreaseWorker(); + + int GetWorkerNum() { return workers_; } + + void Monitor(); + + void Swap(bool is_valid); + + void Stop(); + + bool IsReady(); + + bool Wait(); + + void WaitServerWeakup(); + + void ServerWeakup(); + + void WorkerWeakup(); + + void Reset(int workers, BarrierType type); + + private: + // Init is called by GetInstance. + static void InitImpl(int workers) { + monitor_.reset(new BarrierMonitor(workers)); + } + + static std::once_flag init_flag_; + static std::unique_ptr monitor_; + + int workers_; + bool running_ = false; + bool valid_ = false; + bool release_ = false; + + std::condition_variable worker_cv_; + std::condition_variable server_cv_; + + std::mutex server_mutex_; + std::mutex mutex_; + + BarrierType barrier_type; + int64_t max_wait_ms; + std::unique_ptr monitor_thread_{nullptr}; + std::shared_ptr> send_barrier_queue; + std::shared_ptr> recv_barrier_queue; +}; + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 4041f9091003e92f109606f309e7f02452f3ed69..ca93f7eb958cde66b933612f05bdfc2965cd2a75 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -306,52 +306,19 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep); - - BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - const std::string method = kBatchBarrierRPC; - VarHandlePtr h( - new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr)); - s->Prepare(h, time_out); - - sendrecv::VariableMessage req; - req.set_varname(BATCH_BARRIER_MESSAGE); - - platform::RecordRPCEvent record_event(method); - - auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - + platform::CPUDeviceContext ctx; + auto* scope = new framework::Scope(); + auto h = AsyncDistributeNotify(ep, ctx, *scope, BATCH_BARRIER_MESSAGE); + delete scope; return h; } VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep); - FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); - const std::string method = kFetchBarrierRPC; - VarHandlePtr h( - new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr)); - s->Prepare(h, time_out); - - sendrecv::VariableMessage req; - req.set_varname(FETCH_BARRIER_MESSAGE); - - platform::RecordRPCEvent record_event(method); - - auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - + platform::CPUDeviceContext ctx; + auto* scope = new framework::Scope(); + auto h = AsyncDistributeNotify(ep, ctx, *scope, FETCH_BARRIER_MESSAGE); + delete scope; return h; } @@ -384,27 +351,10 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { - const auto ch = GetChannel(ep); - - BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - const std::string method = kSendCompleteRPC; - VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr)); - s->Prepare(h, time_out); - - sendrecv::VariableMessage req; - req.set_trainer_id(trainer_id_); - req.set_varname(COMPLETE_MESSAGE); - - platform::RecordRPCEvent record_event(method); - - auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - + platform::CPUDeviceContext ctx; + auto* scope = new framework::Scope(); + auto h = AsyncDistributeNotify(ep, ctx, *scope, COMPLETE_MESSAGE); + delete scope; return h; } @@ -454,10 +404,21 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify( s->Prepare(h, time_out); framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] { - auto* var = p_scope->FindVar(var_name_val); + ::grpc::ByteBuffer buf; - ::grpc::ByteBuffer req; - SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); + if (var_name_val == BATCH_BARRIER_MESSAGE || + var_name_val == FETCH_BARRIER_MESSAGE || + var_name_val == COMPLETE_MESSAGE) { + // prepare input + sendrecv::VariableMessage req; + req.set_varname(var_name_val); + req.set_out_varname(var_name_val); + req.set_trainer_id(trainer_id_); + RequestToByteBuffer(req, &buf); + } else { + auto* var = p_scope->FindVar(var_name_val); + SerializeToByteBuffer(var_name_val, var, *p_ctx, &buf, "", trainer_id_); + } VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; @@ -467,7 +428,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify( platform::RecordRPCEvent record_event(method); auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req, + s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", buf, &cq_); call->StartCall(); call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 0205bab0504d75df4e2b8bf15326a8aec9127544..5871bd14fc8033ea50c829e99b40fc2322033b16 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -28,6 +28,7 @@ #include "paddle/fluid/string/split.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" +#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h" namespace paddle { @@ -38,161 +39,130 @@ namespace distributed { // to directory specified. constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath"; -bool RequestSendHandler::Handle(const std::string& varname, - framework::Scope* scope, - framework::Variable* invar, - framework::Variable** outvar, +bool RequestSendHandler::Handle(const std::string &varname, + framework::Scope *scope, + framework::Variable *invar, + framework::Variable **outvar, const int trainer_id, - const std::string& out_var_name, - const std::string& table_name) { + const std::string &out_var_name, + const std::string &table_name) { VLOG(4) << "RequestSendHandler:" << varname; - // Sync - if (varname == BATCH_BARRIER_MESSAGE) { - VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; - rpc_server_->IncreaseBatchBarrier(kRequestSend); - } else if (varname == COMPLETE_MESSAGE) { - VLOG(3) << "sync: recv complete message"; + if (invar == nullptr) { + PADDLE_THROW(platform::errors::NotFound( + "sync: Can not find server side var: %s", varname)); + return false; + } - if (HeartBeatMonitor::GetInstance() != nullptr) { - HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED); - } + if (distributed_mode_ == DistributedMode::kSync) { + return true; + } - rpc_server_->Complete(); - } else { - // Async - if (distributed_mode_ != DistributedMode::kSync) { - VLOG(3) << "async process var: " << varname; - if (varname == BATCH_BARRIER_MESSAGE) { - PADDLE_THROW( - "async mode should not recv BATCH_BARRIER_MESSAGE or " - "COMPLETE_MESSAGE"); - } - HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING); + HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING); - std::string run_varname = varname; + std::string run_varname = varname; + string::Piece part_piece("@PIECE"); + string::Piece var_name_piece = string::Piece(varname); - string::Piece part_piece("@PIECE"); - string::Piece var_name_piece = string::Piece(varname); + if (string::Contains(var_name_piece, part_piece)) { + auto varname_splits = paddle::string::Split(varname, '@'); + run_varname = varname_splits[0]; + scope->Rename(varname, run_varname); + } - if (string::Contains(var_name_piece, part_piece)) { - auto varname_splits = paddle::string::Split(varname, '@'); - PADDLE_ENFORCE_EQ(varname_splits.size(), 3); - run_varname = varname_splits[0]; - scope->Rename(varname, run_varname); - } + if (distributed_mode_ == DistributedMode::kGeo && + AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) { + auto &grad_slr = + scope->FindVar(run_varname)->Get(); + AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname, + grad_slr.rows()); + } - if (distributed_mode_ == DistributedMode::kGeo && - AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) { - auto& grad_slr = - scope->FindVar(run_varname)->Get(); - AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname, - grad_slr.rows()); - } - executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(), - scope); + executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(), + scope); - return true; - } else { // sync - rpc_server_->WaitCond(kRequestSend); - VLOG(3) << "sync: processing received var: " << varname; - PADDLE_ENFORCE_NOT_NULL( - invar, platform::errors::NotFound( - "sync: Can not find server side var %s.", varname)); - } - } return true; } -bool RequestGetHandler::Handle(const std::string& varname, - framework::Scope* scope, - framework::Variable* invar, - framework::Variable** outvar, +bool RequestGetHandler::Handle(const std::string &varname, + framework::Scope *scope, + framework::Variable *invar, + framework::Variable **outvar, const int trainer_id, - const std::string& out_var_name, - const std::string& table_name) { + const std::string &out_var_name, + const std::string &table_name) { VLOG(3) << "RequestGetHandler:" << varname << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id << " table_name: " << table_name; if (distributed_mode_ == DistributedMode::kSync) { - if (varname == FETCH_BARRIER_MESSAGE) { - VLOG(3) << "sync: recv fetch barrier message"; - rpc_server_->IncreaseBatchBarrier(kRequestGet); - } else { - rpc_server_->WaitCond(kRequestGet); - *outvar = scope_->FindVar(varname); - } + *outvar = scope_->FindVar(varname); } else { - if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) { - if (enable_dc_asgd_) { - // NOTE: the format is determined by distribute_transpiler.py - std::string param_bak_name = - string::Sprintf("%s.trainer_%d_bak", varname, trainer_id); - VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id; - auto var = scope_->FindVar(varname); - auto t_orig = var->Get(); - auto param_bak = scope_->Var(param_bak_name); - auto t = param_bak->GetMutable(); - t->mutable_data(dev_ctx_->GetPlace(), t_orig.type()); - VLOG(3) << "copying " << varname << " to " << param_bak_name; - framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); - } - VLOG(1) << "Table name empty? " << table_name.empty(); - if (distributed_mode_ == DistributedMode::kGeo) { - VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist " - << AsyncSparseParamUpdateRecorder::GetInstance()->HasParam( - varname); - } - if (distributed_mode_ == DistributedMode::kGeo && - AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && - !table_name.empty()) { - std::vector updated_rows; - AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( - varname, trainer_id, &updated_rows); - if (VLOG_IS_ON(3)) { - std::ostringstream sstream; - sstream << "["; - for (auto& row_id : updated_rows) { - sstream << row_id << ", "; - } - sstream << "]"; - VLOG(3) << "updated_rows size: " << updated_rows.size() << " " - << sstream.str(); - } - auto& origin_tensor = - scope_->FindVar(varname)->Get(); - auto* origin_tensor_data = origin_tensor.data(); - auto& dims = origin_tensor.dims(); - *outvar = scope->Var(); - auto* out_slr = (*outvar)->GetMutable(); - out_slr->set_rows(updated_rows); - out_slr->set_height(dims[0]); - auto out_dims = framework::make_ddim( - {static_cast(updated_rows.size()), dims[1]}); - auto* data = out_slr->mutable_value()->mutable_data( - out_dims, origin_tensor.place()); - auto width = dims[1]; - for (size_t i = 0; i < updated_rows.size(); ++i) { - PADDLE_ENFORCE_LT(updated_rows[i], dims[0]); - memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, - sizeof(float) * width); + if (enable_dc_asgd_) { + // NOTE: the format is determined by distribute_transpiler.py + std::string param_bak_name = + string::Sprintf("%s.trainer_%d_bak", varname, trainer_id); + VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id; + auto var = scope_->FindVar(varname); + auto t_orig = var->Get(); + auto param_bak = scope_->Var(param_bak_name); + auto t = param_bak->GetMutable(); + t->mutable_data(dev_ctx_->GetPlace(), t_orig.type()); + VLOG(3) << "copying " << varname << " to " << param_bak_name; + framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); + } + + if (distributed_mode_ == DistributedMode::kGeo && + AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && + !table_name.empty()) { + std::vector updated_rows; + AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( + varname, trainer_id, &updated_rows); + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto &row_id : updated_rows) { + sstream << row_id << ", "; } - } else { - *outvar = scope_->FindVar(varname); + sstream << "]"; + VLOG(3) << "updated_rows size: " << updated_rows.size() << " " + << sstream.str(); + } + auto &origin_tensor = + scope_->FindVar(varname)->Get(); + auto *origin_tensor_data = origin_tensor.data(); + auto &dims = origin_tensor.dims(); + *outvar = scope->Var(); + auto *out_slr = (*outvar)->GetMutable(); + out_slr->set_rows(updated_rows); + out_slr->set_height(dims[0]); + auto out_dims = framework::make_ddim( + {static_cast(updated_rows.size()), dims[1]}); + auto *data = out_slr->mutable_value()->mutable_data( + out_dims, origin_tensor.place()); + auto width = dims[1]; + for (size_t i = 0; i < updated_rows.size(); ++i) { + PADDLE_ENFORCE_LT(updated_rows[i], dims[0], + platform::errors::OutOfRange( + "expected >= 0 and < %ld, but got %ld.", dims[0], + updated_rows[i])); + memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, + sizeof(float) * width); } + } else { + *outvar = scope_->FindVar(varname); } } return true; } -bool RequestGetNoBarrierHandler::Handle(const std::string& varname, - framework::Scope* scope, - framework::Variable* invar, - framework::Variable** outvar, +bool RequestGetNoBarrierHandler::Handle(const std::string &varname, + framework::Scope *scope, + framework::Variable *invar, + framework::Variable **outvar, const int trainer_id, - const std::string& out_var_name, - const std::string& table_name) { + const std::string &out_var_name, + const std::string &table_name) { VLOG(4) << "RequestGetNoBarrierHandler:" << varname << " out_var_name: " << out_var_name; @@ -207,18 +177,19 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname, *outvar = scope_->FindVar(var_name_piece.ToString()); return true; } else { - PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE); + PADDLE_THROW(platform::errors::InvalidArgument( + "GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE)); } return true; } -bool RequestPrefetchHandler::Handle(const std::string& varname, - framework::Scope* scope, - framework::Variable* invar, - framework::Variable** outvar, +bool RequestPrefetchHandler::Handle(const std::string &varname, + framework::Scope *scope, + framework::Variable *invar, + framework::Variable **outvar, const int trainer_id, - const std::string& out_var_name, - const std::string& table_name) { + const std::string &out_var_name, + const std::string &table_name) { VLOG(4) << "RequestPrefetchHandler " << varname; if (table_name.empty()) { @@ -236,19 +207,20 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, return true; } -bool RequestCheckpointHandler::Handle(const std::string& varname, - framework::Scope* scope, - framework::Variable* invar, - framework::Variable** outvar, +bool RequestCheckpointHandler::Handle(const std::string &varname, + framework::Scope *scope, + framework::Variable *invar, + framework::Variable **outvar, const int trainer_id, - const std::string& out_var_name, - const std::string& table_name) { - PADDLE_ENFORCE( - checkpoint_notify_id != -1, - "when checkpoint_notify_id = -1, there should be no RPC invoke."); + const std::string &out_var_name, + const std::string &table_name) { + PADDLE_ENFORCE_NE( + checkpoint_notify_id, -1, + platform::errors::Unavailable( + "when checkpoint_notify_id = -1, there should be no RPC invoke.")); // TODO(tangwei12): find out why scope will be error. - auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); + auto *lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: " @@ -257,33 +229,56 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, return true; } -bool RequestNotifyHandler::Handle(const std::string& varname, - framework::Scope* scope, - framework::Variable* invar, - framework::Variable** outvar, +bool RequestNotifyHandler::Handle(const std::string &varname, + framework::Scope *scope, + framework::Variable *invar, + framework::Variable **outvar, const int trainer_id, - const std::string& out_var_name, - const std::string& table_name) { - VLOG(4) << "RequestNotifyHandler: " << varname; + const std::string &out_var_name, + const std::string &table_name) { VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id; string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER); + string::Piece batch_piece(BATCH_BARRIER_MESSAGE); + string::Piece fetch_piece(FETCH_BARRIER_MESSAGE); + string::Piece complete_piece(COMPLETE_MESSAGE); + string::Piece var_name_piece = string::Piece(varname); - if (string::Contains(var_name_piece, decay_piece)) { + + if (string::Contains(var_name_piece, batch_piece)) { + return BarrierMonitor::GetInstance()->IncreaseBarrier( + trainer_id, BATCH_BARRIER_MESSAGE); + } else if (string::Contains(var_name_piece, fetch_piece)) { + return BarrierMonitor::GetInstance()->IncreaseBarrier( + trainer_id, FETCH_BARRIER_MESSAGE); + } else if (string::Contains(var_name_piece, complete_piece)) { + if (HeartBeatMonitor::GetInstance() != nullptr) { + HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED); + } + rpc_server_->Complete(); + BarrierMonitor::GetInstance()->DecreaseWorker(); + return true; + } else if (string::Contains(var_name_piece, decay_piece)) { VLOG(3) << "LearningRate Decay Counter Update"; PADDLE_ENFORCE_NE( lr_decay_block_id, -1, - "when lr_decay_block_id = -1, there should be no RPC invoke."); - auto* origin_var = scope_->FindVar(varname); + platform::errors::InvalidArgument( + "when lr_decay_block_id = -1, there should be no RPC invoke.")); + auto *origin_var = scope_->FindVar(varname); auto origin_var_tensor = origin_var->Get(); - auto* send_var = scope->FindVar(varname); + auto *send_var = scope->FindVar(varname); auto send_var_tensor = send_var->Get(); - int64_t* origin_value = + int64_t *origin_value = origin_var_tensor.mutable_data(origin_var_tensor.place()); - int64_t* send_value = + int64_t *send_value = send_var_tensor.mutable_data(send_var_tensor.place()); origin_value[0] += send_value[0]; executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_); + + return true; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "unkown varname %s with RequestNotifyHandler", varname)); } return true; } diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index d36a433db7dda89b5a9edb6fb8db8552ecce7854..bc17c84645116df7868107a6acf3de620dd9f798 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -24,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" @@ -119,6 +117,7 @@ void StartServer(const std::string& rpc_name) { g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); distributed::HeartBeatMonitor::Init(2, true, "w@grad"); + distributed::BarrierMonitor::Init(2); g_req_handler->SetRPCServer(g_rpc_service.get()); @@ -164,6 +163,9 @@ TEST(PREFETCH, CPU) { } } + auto* barrier = distributed::BarrierMonitor::GetInstance(); + barrier->Stop(); + g_rpc_service->ShutDown(); server_thread.join(); LOG(INFO) << "begin reset"; @@ -174,20 +176,24 @@ TEST(PREFETCH, CPU) { TEST(COMPLETE, CPU) { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); - g_req_handler.reset( - new distributed::RequestSendHandler(distributed::DistributedMode::kSync)); + g_req_handler.reset(new distributed::RequestNotifyHandler( + distributed::DistributedMode::kSync, -1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2)); + distributed::RPCClient* client = distributed::RPCClient::GetInstance(0); PADDLE_ENFORCE(client != nullptr); - std::thread server_thread(StartServer, distributed::kRequestSend); + std::thread server_thread(StartServer, distributed::kRequestNotify); g_rpc_service->WaitServerReady(); int port = g_rpc_service->GetSelectedPort(); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); client->AsyncSendComplete(ep); client->Wait(); - EXPECT_EQ(g_rpc_service->GetClientNum(), 1); + auto* barrier = distributed::BarrierMonitor::GetInstance(); + EXPECT_EQ(barrier->GetWorkerNum(), 1); + + barrier->Stop(); g_rpc_service->ShutDown(); server_thread.join(); diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index 79f14d75d279d0ae1a68bf857ab9f46d6b71c42f..244d3ece48ecc201465a6badeb5cd44bbf71f4a8 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -2,9 +2,9 @@ include(operators) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node) else() - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node) if(WITH_BRPC_RDMA) find_library(IBVERBS_LIBRARY NAMES ibverbs) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) diff --git a/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc b/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc index e63f882478351cde16bde969b86e020181d6d4e5..cf8322905297156ba5e36c5b21e009739daa194f 100644 --- a/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/platform/nccl_helper.h" @@ -30,16 +28,16 @@ namespace operators { class GenNCCLIdOp : public framework::OperatorBase { public: - GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) + GenNCCLIdOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); // put nccl id in CPUPlace - auto& dev_ctx = *pool.Get(platform::CPUPlace()); + auto &dev_ctx = *pool.Get(platform::CPUPlace()); int trainer_id = Attr("trainer_id"); std::vector trainers = @@ -55,7 +53,7 @@ class GenNCCLIdOp : public framework::OperatorBase { std::string endpoint = trainers[trainer_id]; - framework::Scope& local_scope = scope.NewScope(); + framework::Scope &local_scope = scope.NewScope(); int nccl_comm_num = Attr("nccl_comm_num"); int use_hierarchical_allreduce = Attr("use_hierarchical_allreduce"); @@ -171,10 +169,10 @@ class GenNCCLIdOp : public framework::OperatorBase { } private: - void GenerateAndSend(framework::Scope* scope, - const platform::DeviceContext& dev_ctx, - const std::string& nccl_id_name, - const std::vector& endpoint_list) const { + void GenerateAndSend(framework::Scope *scope, + const platform::DeviceContext &dev_ctx, + const std::string &nccl_id_name, + const std::vector &endpoint_list) const { auto var = scope->FindVar(nccl_id_name); PADDLE_ENFORCE_NOT_NULL( var, platform::errors::NotFound("Variable with name %s is not found", @@ -182,76 +180,96 @@ class GenNCCLIdOp : public framework::OperatorBase { auto id = var->GetMutable(); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(id)); - distributed::RPCClient* client = + distributed::RPCClient *client = distributed::RPCClient::GetInstance(0); - for (auto& ep : endpoint_list) { + for (auto &ep : endpoint_list) { VLOG(3) << "sending nccl_id_var:" << nccl_id_name << " to " << ep; client->AsyncSendVar(ep, dev_ctx, *scope, nccl_id_name); } client->Wait(); - for (auto& ep : endpoint_list) { + for (auto &ep : endpoint_list) { client->AsyncSendBatchBarrier(ep); } client->Wait(); VLOG(3) << "sending completed..."; } - void GetIdByServer(const std::string& endpoint, framework::Scope* scope, - const platform::DeviceContext& dev_ctx, int nccl_comm_num, + void GetIdByServer(const std::string &endpoint, framework::Scope *scope, + const platform::DeviceContext &dev_ctx, int nccl_comm_num, bool use_hierarchical_allreduce, int trainer_id, int inter_trainer_id, int exter_trainer_id) const { // std::string endpoint = Attr("endpoint"); // NOTE: Can not use unique_ptr here because the default // deleter will call GRPC Server's base class's dtor and // that will cause a wired crash. - distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync); + std::unique_ptr rpc_service( new RPCSERVER_T(endpoint, 1)); + distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync); + + distributed::RequestNotifyHandler notify_h( + distributed::DistributedMode::kSync, -1); + rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); - rpc_h.SetRPCServer(rpc_service.get()); + rpc_service->RegisterRPC(distributed::kRequestNotify, ¬ify_h); framework::ProgramDesc empty_program; framework::Executor executor(dev_ctx.GetPlace()); + + rpc_h.SetRPCServer(rpc_service.get()); rpc_h.SetScope(scope); rpc_h.SetDevCtx(&dev_ctx); rpc_h.SetProgram(&empty_program); rpc_h.SetExecutor(&executor); + notify_h.SetRPCServer(rpc_service.get()); + notify_h.SetScope(scope); + notify_h.SetDevCtx(&dev_ctx); + notify_h.SetProgram(&empty_program); + notify_h.SetExecutor(&executor); + + distributed::BarrierMonitor::Init(1); + auto *barrier = distributed::BarrierMonitor::GetInstance(); + barrier->Reset(1, distributed::BarrierType::kSendBarrier); + std::thread server_thread( std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); for (int i = 0; i < nccl_comm_num; i++) { - rpc_service->SetCond(distributed::kRequestSend); + barrier->WaitServerWeakup(); + barrier->Reset(1, distributed::BarrierType::kSendBarrier); + barrier->ServerWeakup(); + VLOG(3) << "trainer_id:" << trainer_id << " start getting nccl id from trainer 0, nccl_comm_no:" << i; - rpc_service->WaitBarrier(distributed::kRequestSend); - rpc_service->ResetBarrierCounter(); } if (use_hierarchical_allreduce) { if (inter_trainer_id > 0) { for (int i = 0; i < nccl_comm_num; i++) { - rpc_service->SetCond(distributed::kRequestSend); + barrier->WaitServerWeakup(); + barrier->Reset(1, distributed::BarrierType::kSendBarrier); + barrier->ServerWeakup(); + VLOG(3) << "trainer_id:" << trainer_id << ", inter_trainer_id:" << inter_trainer_id << " start getting nccl id from inter_trainer:" << i; - rpc_service->WaitBarrier(distributed::kRequestSend); - rpc_service->ResetBarrierCounter(); } } if (exter_trainer_id > 0) { for (int i = 0; i < nccl_comm_num; i++) { - rpc_service->SetCond(distributed::kRequestSend); + barrier->WaitServerWeakup(); + barrier->Reset(1, distributed::BarrierType::kSendBarrier); + barrier->ServerWeakup(); + VLOG(3) << "trainer_id:" << trainer_id << ", exter_trainer_id:" << exter_trainer_id << " start getting nccl id from exter_trainer 0, nccl_comm_no:" << i; - rpc_service->WaitBarrier(distributed::kRequestSend); - rpc_service->ResetBarrierCounter(); } } } @@ -260,6 +278,7 @@ class GenNCCLIdOp : public framework::OperatorBase { << ", inter_trainer_id:" << inter_trainer_id << ", exter_trainer_id:" << exter_trainer_id << " got nccl id and stop server..."; + barrier->Stop(); rpc_service->ShutDown(); VLOG(3) << "rpc server stopped"; server_thread.join(); @@ -272,7 +291,6 @@ class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces."); AddComment(R"DOC( GenNCCLId operator - For trainer 0: generate a new UniqueId and send it to all the other trainers. For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. )DOC"); diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index d40df6f9de0c1e22ea892993d66a2cdfa808b1c7..c8c0316e74739622f46cb577ae051fc88dd39bb7 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -25,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" +#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" @@ -38,10 +36,13 @@ DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch"); namespace paddle { namespace operators { +volatile sig_atomic_t gSignalStatus; + void RunServer(std::shared_ptr service) { service->StartServer(); VLOG(4) << "RunServer thread end"; } + static void split(const std::string &str, char sep, std::vector *pieces) { pieces->clear(); @@ -126,6 +127,7 @@ void ListenAndServOp::RunSyncLoop( for (size_t i = 1; i < program->Size(); ++i) { optimize_blocks_list.push_back(i); } + auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list); // Insert placeholder for block0 which holds current op itself, // NOTE the first block in `optimize_prepared` should never be ran. @@ -135,21 +137,15 @@ void ListenAndServOp::RunSyncLoop( // Trainers will get all parameters from pserver in the // startup program, so we will wait RequestGet first - rpc_service_->SetCond(distributed::kRequestGet); - rpc_service_->WaitBarrier(distributed::kRequestGet); - rpc_service_->ResetBarrierCounter(); + auto *barrier = distributed::BarrierMonitor::GetInstance(); while (true) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. - VLOG(3) << "wait all clients to send gradient"; - rpc_service_->SetCond(distributed::kRequestSend); - VLOG(3) << "wait all clients to send send_barrier"; - rpc_service_->WaitBarrier(distributed::kRequestSend); + barrier->WaitServerWeakup(); - if (rpc_service_->IsExit()) { + if (gSignalStatus != 0) { LOG(WARNING) << "get exit!rpc_processor break!"; - rpc_service_->SetCond(distributed::kRequestGet); break; } @@ -180,12 +176,8 @@ void ListenAndServOp::RunSyncLoop( VLOG(3) << "ResetReceivedVars"; ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars()); - VLOG(3) << "wait all clients to get parameters back"; - rpc_service_->SetCond(distributed::kRequestGet); - VLOG(3) << "wait all clients to send fetch_barrier"; - rpc_service_->WaitBarrier(distributed::kRequestGet); - VLOG(3) << "ResetBarrierCounter"; - rpc_service_->ResetBarrierCounter(); + barrier->ServerWeakup(); + VLOG(3) << "kRecvBarrier to push params to trainers"; } // while(true) } @@ -281,7 +273,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); while (true) { - if (rpc_service_->IsExit()) { + if (gSignalStatus != 0) { VLOG(4) << "get exit!rpc_processor break!"; break; } @@ -391,7 +383,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, request_get_no_barrier_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestNotify, - request_notify_handler_.get(), rpc_send_thread_num); + request_notify_handler_.get(), fan_in * 2); auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -440,6 +432,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, std::unordered_map> prefetch_var_name_to_prepared_ctx; + for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) { auto block_id = prefetch_block_id_list[i]; auto prefetch_var_name = block_id_to_prefetch_var_name[block_id]; @@ -448,8 +441,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, // parse attr of kSparseGradToParam sparse_grad_name -> param_name std::unordered_map sparse_grad_name_to_param_name; + auto sparse_grad_name_to_param_name_str = Attr>(kSparseGradToParam); + for (const auto &sparse_grad_name_and_param_name : sparse_grad_name_to_param_name_str) { std::vector pieces; @@ -477,17 +472,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, signal(SIGINT, SignalHandler::StopAndExit); signal(SIGTERM, SignalHandler::StopAndExit); + distributed::BarrierMonitor::Init(fan_in); + if (distributed_mode == distributed::DistributedMode::kSync) { // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); VLOG(3) << "wait server thread to become ready..."; rpc_service_->WaitServerReady(); - - CacheVarsType(inputs, recv_scope); - // Write to a file of server selected port for python use. SavePort(); + CacheVarsType(inputs, recv_scope); + RunSyncLoop(&executor, program, &recv_scope, &dev_ctx, prefetch_block_id_list, checkpoint_block_id); } else { @@ -574,9 +570,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { void SignalHandler::StopAndExit(int signal_num) { // Do not use VLOG here for the device for printing maybe already released. // exit will release interal allocated resoureces. - auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); - remove(file_path.c_str()); - exit(0); + distributed::BarrierMonitor::GetInstance()->Stop(); + gSignalStatus = signal_num; } } // namespace operators diff --git a/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc b/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc index b65621a0886b02fd8d3c029c979348469014cadc..3f5fdadc22342bf17f54d86e39bdc5114915c001 100644 --- a/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc +++ b/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" @@ -42,6 +40,7 @@ namespace string = paddle::string; std::unique_ptr g_rpc_service; std::unique_ptr g_req_handler; +std::unique_ptr g_notify_handler; void StartServer() { f::Scope scope; @@ -52,21 +51,35 @@ void StartServer() { f::ProgramDesc empty_program; f::Executor executor(dev_ctx.GetPlace()); + g_req_handler->SetScope(&scope); g_req_handler->SetDevCtx(&dev_ctx); g_req_handler->SetProgram(&empty_program); g_req_handler->SetExecutor(&executor); + g_req_handler->SetRPCServer(g_rpc_service.get()); + + g_notify_handler.SetRPCServer(rpc_service.get()); + g_notify_handler.SetScope(scope); + g_notify_handler.SetDevCtx(&dev_ctx); + g_notify_handler.SetProgram(&empty_program); + g_notify_handler.SetExecutor(&executor); g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get()); - g_req_handler->SetRPCServer(g_rpc_service.get()); + g_rpc_service->RegisterRPC(distributed::RequestNotifyHandler, + g_notify_handler.get()); + + distributed::BarrierMonitor::Init(1); + auto* barrier = distributed::BarrierMonitor::GetInstance(); + barrier->Reset(1, distributed::BarrierType::kSendBarrier); std::thread server_thread( std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); - g_rpc_service->SetCond(distributed::kRequestSend); - g_rpc_service->WaitBarrier(distributed::kRequestSend); + barrier->WaitServerWeakup(); + barrier->ServerWeakup(); LOG(INFO) << "got nccl id and stop server..."; + barrier->Stop(); g_rpc_service->ShutDown(); server_thread.join(); } @@ -74,6 +87,10 @@ void StartServer() { TEST(SendNcclId, RPCServer) { g_req_handler.reset( new distributed::RequestSendHandler(distributed::DistributedMode::kSync)); + + g_notify_handler.reset(new distributed::RequestNotifyHandler( + distributed::DistributedMode::kSync, -1)); + g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); std::thread server_thread(StartServer); @@ -104,4 +121,5 @@ TEST(SendNcclId, RPCServer) { server_thread.join(); g_rpc_service.reset(nullptr); g_req_handler.reset(nullptr); + g_notify_handler.reset(nullptr); } diff --git a/python/paddle/fluid/tests/unittests/test_communicator_half_async.py b/python/paddle/fluid/tests/unittests/test_communicator_half_async.py index 8a7904db95f7a1b8088197fdf16969e1ccfefae2..31b476eac0566f962ea452bf1fde62f5cb3c5169 100644 --- a/python/paddle/fluid/tests/unittests/test_communicator_half_async.py +++ b/python/paddle/fluid/tests/unittests/test_communicator_half_async.py @@ -94,7 +94,7 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase): current_id=0, role=role_maker.Role.WORKER if training_role == "TRAINER" else role_maker.Role.SERVER, - worker_num=2, + worker_num=1, server_endpoints=["127.0.0.1:6002"]) if training_role == "TRAINER":