diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index 0621fa938c9f854ef1c906620f3e474c375efb8a..e2b09be5a9dfff0111ab80d89bdd76b99517738f 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" @@ -93,39 +95,22 @@ class CGenNCCLIdOp : public framework::OperatorBase { new RPCSERVER_T(endpoint, 1)); rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); - distributed::RequestNotifyHandler notify_h( - distributed::DistributedMode::kSync, -1); - - rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); - rpc_service->RegisterRPC(distributed::kRequestNotify, ¬ify_h); + rpc_h.SetRPCServer(rpc_service.get()); framework::ProgramDesc empty_program; framework::Executor executor(dev_ctx.GetPlace()); - - rpc_h.SetRPCServer(rpc_service.get()); rpc_h.SetScope(scope); rpc_h.SetDevCtx(&dev_ctx); rpc_h.SetProgram(&empty_program); rpc_h.SetExecutor(&executor); - notify_h.SetRPCServer(rpc_service.get()); - notify_h.SetScope(scope); - notify_h.SetDevCtx(&dev_ctx); - notify_h.SetProgram(&empty_program); - notify_h.SetExecutor(&executor); - - distributed::BarrierMonitor::Init(1); - auto* barrier = distributed::BarrierMonitor::GetInstance(); - barrier->Reset(1, distributed::BarrierType::kSendBarrier); - std::thread server_thread( std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); + rpc_service->SetCond(distributed::kRequestSend); VLOG(3) << "start getting nccl id from trainer 0..."; - barrier->WaitServerWeakup(); - barrier->ServerWeakup(); + rpc_service->WaitBarrier(distributed::kRequestSend); VLOG(3) << "got nccl id and stop server..."; - barrier->Stop(); rpc_service->ShutDown(); VLOG(3) << "rpc server stopped"; server_thread.join(); @@ -138,6 +123,7 @@ class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces."); AddComment(R"DOC( CGenNCCLId operator + For trainer 0: generate a new UniqueId and send it to all the other trainers. For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. )DOC"); diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 3cf15fe94168590f31c488648c0e67a82b7d1102..5aa91733fe3ed1bfc51b47b331488ce2211be2fb 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -15,8 +15,6 @@ cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_r cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool) cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor) -cc_library(barrier_monitor SRCS barrier_monitor.cc DEPS enforce simple_threadpool trainer_desc_proto device_context) - # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") if(WITH_GRPC) @@ -28,7 +26,7 @@ if(WITH_GRPC) collective_client.cc collective_server.cc ${GRPC_SRCS} PROTO send_recv.proto - DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor barrier_monitor) + DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) diff --git a/paddle/fluid/operators/distributed/barrier_monitor.cc b/paddle/fluid/operators/distributed/barrier_monitor.cc deleted file mode 100644 index f6d82f5d8c3daea9b629c1937bfcbc5159cda461..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/distributed/barrier_monitor.cc +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/barrier_monitor.h" -#include - -#include -#include // NOLINT -#include -#include -#include -#include -#include -#include - -#include // NOLINT - -#include - -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace distributed { -bool BarrierMonitor::IncreaseBarrier(const int worker_id, - const std::string &barrier) { - release_ = false; - - if (barrier == BATCH_BARRIER_MESSAGE) { - VLOG(4) << "BarrierMonitor send queue recv trainer: " << worker_id; - send_barrier_queue->Push(worker_id); - } else if (barrier == FETCH_BARRIER_MESSAGE) { - VLOG(4) << "BarrierMonitor recv queue recv trainer: " << worker_id; - recv_barrier_queue->Push(worker_id); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "unknown Message status %s, only " - "BATCH_BARRIER_MESSAGE/FETCH_BARRIER_MESSAGE", - barrier)); - } - return Wait(); -} - -void BarrierMonitor::DecreaseWorker() { - std::unique_lock lck(mutex_); - workers_--; - VLOG(1) << "decrement worker num to " << workers_; -} - -void BarrierMonitor::Reset(int workers, BarrierType type) { - std::unique_lock lk(server_mutex_); - - workers_ = workers; - barrier_type = type; - - send_barrier_queue->Clear(); - recv_barrier_queue->Clear(); - VLOG(2) << "reset monitor workers: " << workers_ << " type: " << barrier_type; -} - -void BarrierMonitor::Monitor() { - while (!IsReady() && running_) { - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - VLOG(3) << "sync at first time, wait all trainer ready"; - } - - while (running_) { - int timer = 0; - - if (IsReady()) { - Swap(true); - } else { - VLOG(4) << "running timer: " << timer << " barrier: " << barrier_type - << " sendQ:" << send_barrier_queue->Size() - << " recvQ: " << recv_barrier_queue->Size(); - - timer++; - if (max_wait_ms == -1 || timer < max_wait_ms) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } else { - VLOG(1) << "time out of " << max_wait_ms - << ", need barreir: " << barrier_type << " retry"; - Swap(false); - } - } - } -} - -bool BarrierMonitor::IsReady() { - if (barrier_type == BarrierType::kSendBarrier) { - return static_cast(send_barrier_queue->Size()) == workers_; - } else { - return static_cast(recv_barrier_queue->Size()) == workers_; - } -} - -void BarrierMonitor::Swap(bool is_valid) { - std::unique_lock lck(mutex_); - - valid_ = is_valid; - release_ = true; - - if (barrier_type == BarrierType::kSendBarrier) { - barrier_type = BarrierType::kRecvBarrier; - send_barrier_queue->Clear(); - VLOG(4) << "barrier monitor server clean up queue and barrier"; - ServerWeakup(); - VLOG(4) << "barrier monitor server weak up sync to do"; - WaitServerWeakup(); - VLOG(4) << "barrier monitor server weak up sync done"; - - } else { - barrier_type = BarrierType::kSendBarrier; - recv_barrier_queue->Clear(); - VLOG(4) << "barrier monitor server switch to send barrier"; - } - - worker_cv_.notify_all(); -} - -void BarrierMonitor::Stop() { - valid_ = true; - release_ = true; - running_ = false; - - barrier_type = BarrierType::kRecvBarrier; - send_barrier_queue->Clear(); - recv_barrier_queue->Clear(); - - worker_cv_.notify_all(); - server_cv_.notify_all(); - - if (monitor_thread_) monitor_thread_->join(); - monitor_thread_ = nullptr; -} - -bool BarrierMonitor::Wait() { - std::unique_lock lk(mutex_); - worker_cv_.wait(lk, [this] { return (release_); }); - return valid_; -} - -void BarrierMonitor::WaitServerWeakup() { - std::unique_lock lk(server_mutex_); - server_cv_.wait(lk); -} - -void BarrierMonitor::ServerWeakup() { server_cv_.notify_all(); } - -std::once_flag BarrierMonitor::init_flag_; -std::unique_ptr BarrierMonitor::monitor_(nullptr); - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/barrier_monitor.h b/paddle/fluid/operators/distributed/barrier_monitor.h deleted file mode 100644 index f9556d7720f7a7ebcadcc1f86ad6051786777041..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/distributed/barrier_monitor.h +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include // NOLINT -#include -#include -#include // NOLINT -#include -#include -#include -#include -#include -#include - -#include // NOLINT - -#include - -#include "paddle/fluid/operators/distributed/rpc_server.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace distributed { - -enum BarrierType { kSendBarrier, kRecvBarrier }; - -constexpr int64_t kMaxWaitMS = 120000; - -template -class BlockingQueueForBarrier { - public: - explicit BlockingQueueForBarrier(size_t capacity) : capacity_(capacity) { - PADDLE_ENFORCE_GT(capacity_, 0, - platform::errors::InvalidArgument( - "The capacity must be greater than 0.")); - } - - bool Push(const T &elem) { - { - std::unique_lock lock(mutex_); - worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; }); - queue_.push_back(elem); - } - worker_cv_.notify_one(); - return true; - } - - bool Push(T &&elem) { - { - std::unique_lock lock(mutex_); - worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; }); - queue_.emplace_back(std::move(elem)); - } - worker_cv_.notify_one(); - return true; - } - - T Pop() { - std::unique_lock lock(mutex_); - worker_cv_.wait(lock, [=] { return !queue_.empty(); }); - T rc(std::move(queue_.front())); - queue_.pop_front(); - worker_cv_.notify_one(); - return rc; - } - - size_t Cap() const { - std::lock_guard lock(mutex_); - return capacity_; - } - - size_t Size() const { - std::lock_guard lock(mutex_); - return queue_.size(); - } - - void Clear() { - std::lock_guard lock(mutex_); - std::deque().swap(queue_); - } - - private: - const size_t capacity_; - std::deque queue_; - - mutable std::mutex mutex_; - std::condition_variable worker_cv_; -}; - -class BarrierMonitor { - public: - explicit BarrierMonitor(int workers) - : BarrierMonitor(workers, BarrierType::kRecvBarrier, kMaxWaitMS) {} - - explicit BarrierMonitor(int workers, BarrierType type, int64_t max_wait_times) - : workers_(workers), barrier_type(type), max_wait_ms(max_wait_times) { - PADDLE_ENFORCE_GT(workers, 0, platform::errors::InvalidArgument( - "trainers must have one or more")); - - send_barrier_queue = - std::make_shared>(workers); - recv_barrier_queue = - std::make_shared>(workers); - - running_ = true; - monitor_thread_.reset( - new std::thread(std::bind(&BarrierMonitor::Monitor, this))); - } - - static BarrierMonitor *Init(int workers) { - InitImpl(workers); - return GetInstance(); - } - - static BarrierMonitor *GetInstance() { return monitor_.get(); } - - bool IncreaseBarrier(const int worker_id, const std::string &barrier); - - void DecreaseWorker(); - - int GetWorkerNum() { return workers_; } - - void Monitor(); - - void Swap(bool is_valid); - - void Stop(); - - bool IsReady(); - - bool Wait(); - - void WaitServerWeakup(); - - void ServerWeakup(); - - void WorkerWeakup(); - - void Reset(int workers, BarrierType type); - - private: - // Init is called by GetInstance. - static void InitImpl(int workers) { - monitor_.reset(new BarrierMonitor(workers)); - } - - static std::once_flag init_flag_; - static std::unique_ptr monitor_; - - int workers_; - bool running_ = false; - bool valid_ = false; - bool release_ = false; - - std::condition_variable worker_cv_; - std::condition_variable server_cv_; - - std::mutex server_mutex_; - std::mutex mutex_; - - BarrierType barrier_type; - int64_t max_wait_ms; - std::unique_ptr monitor_thread_{nullptr}; - std::shared_ptr> send_barrier_queue; - std::shared_ptr> recv_barrier_queue; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index ca93f7eb958cde66b933612f05bdfc2965cd2a75..0652f8691218dc688732bd4243315b188cd0b053 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -260,7 +260,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, while (true) { GetProcessor* s = new GetProcessor(ch); VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); - s->Prepare(h, time_out); + s->Prepare(h, kPrefetchTimeout); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, s, method, h, table_name_val, this] { @@ -306,19 +306,52 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { - platform::CPUDeviceContext ctx; - auto* scope = new framework::Scope(); - auto h = AsyncDistributeNotify(ep, ctx, *scope, BATCH_BARRIER_MESSAGE); - delete scope; + const auto ch = GetChannel(ep); + + BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); + const std::string method = kBatchBarrierRPC; + VarHandlePtr h( + new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr)); + s->Prepare(h, time_out); + + sendrecv::VariableMessage req; + req.set_varname(BATCH_BARRIER_MESSAGE); + + platform::RecordRPCEvent record_event(method); + + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; + + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + return h; } VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { - platform::CPUDeviceContext ctx; - auto* scope = new framework::Scope(); - auto h = AsyncDistributeNotify(ep, ctx, *scope, FETCH_BARRIER_MESSAGE); - delete scope; + const auto ch = GetChannel(ep); + FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); + const std::string method = kFetchBarrierRPC; + VarHandlePtr h( + new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr)); + s->Prepare(h, time_out); + + sendrecv::VariableMessage req; + req.set_varname(FETCH_BARRIER_MESSAGE); + + platform::RecordRPCEvent record_event(method); + + auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; + + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + return h; } @@ -351,10 +384,27 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { - platform::CPUDeviceContext ctx; - auto* scope = new framework::Scope(); - auto h = AsyncDistributeNotify(ep, ctx, *scope, COMPLETE_MESSAGE); - delete scope; + const auto ch = GetChannel(ep); + + BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); + const std::string method = kSendCompleteRPC; + VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr)); + s->Prepare(h, time_out); + + sendrecv::VariableMessage req; + req.set_trainer_id(trainer_id_); + req.set_varname(COMPLETE_MESSAGE); + + platform::RecordRPCEvent record_event(method); + + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; + + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + return h; } @@ -404,21 +454,10 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify( s->Prepare(h, time_out); framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] { - ::grpc::ByteBuffer buf; + auto* var = p_scope->FindVar(var_name_val); - if (var_name_val == BATCH_BARRIER_MESSAGE || - var_name_val == FETCH_BARRIER_MESSAGE || - var_name_val == COMPLETE_MESSAGE) { - // prepare input - sendrecv::VariableMessage req; - req.set_varname(var_name_val); - req.set_out_varname(var_name_val); - req.set_trainer_id(trainer_id_); - RequestToByteBuffer(req, &buf); - } else { - auto* var = p_scope->FindVar(var_name_val); - SerializeToByteBuffer(var_name_val, var, *p_ctx, &buf, "", trainer_id_); - } + ::grpc::ByteBuffer req; + SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; @@ -428,7 +467,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify( platform::RecordRPCEvent record_event(method); auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", buf, + s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req, &cq_); call->StartCall(); call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); @@ -448,6 +487,18 @@ bool GRPCClient::Wait() { return ok_; } +inline bool ShouldRetry(const std::string& method, int error_code) { + if (method == kPrefetchRPC) { + return true; + } + + if (error_code == grpc::StatusCode::DEADLINE_EXCEEDED) { + return true; + } + + return false; +} + void GRPCClient::Proceed() { void* tag = nullptr; bool ok = false; @@ -461,19 +512,9 @@ void GRPCClient::Proceed() { if (c->status_.ok()) { VLOG(3) << c->GetVarHandlePtr()->String() << " process"; c->Process(); - } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { - PADDLE_THROW(platform::errors::External( - "%s meets grpc error, error_code is %d, error message is %s, error " - "details is %s.", - c->GetVarHandlePtr()->String(), c->status_.error_code(), - c->status_.error_message(), c->status_.error_details())); - { - std::lock_guard lk(sync_mutex_); - ok_ = false; - } - c->Finish(false); - } else if (c->status_.error_code() == grpc::StatusCode::UNAVAILABLE) { - VLOG(3) << c->GetVarHandlePtr()->String() + } else if (ShouldRetry(c->GetVarHandlePtr()->method(), + c->status_.error_code())) { + VLOG(0) << c->GetVarHandlePtr()->String() << " meets grpc error, error_code:" << c->status_.error_code() << " error_message:" << c->status_.error_message() << " error_details:" << c->status_.error_details() diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 74d9fc78cedc25ea64f684b6aed830021fbbd5cc..7cccf259b596f2116d14b23d19dba6df229d3cd7 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -57,6 +57,7 @@ constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC"; constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC"; constexpr char kSendCompleteRPC[] = "SendCompleteRPC"; constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; +constexpr int64_t kPrefetchTimeout = 60000; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 5871bd14fc8033ea50c829e99b40fc2322033b16..0205bab0504d75df4e2b8bf15326a8aec9127544 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -28,7 +28,6 @@ #include "paddle/fluid/string/split.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" -#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h" namespace paddle { @@ -39,130 +38,161 @@ namespace distributed { // to directory specified. constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath"; -bool RequestSendHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, +bool RequestSendHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar, const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { + const std::string& out_var_name, + const std::string& table_name) { VLOG(4) << "RequestSendHandler:" << varname; - if (invar == nullptr) { - PADDLE_THROW(platform::errors::NotFound( - "sync: Can not find server side var: %s", varname)); - return false; - } + // Sync + if (varname == BATCH_BARRIER_MESSAGE) { + VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; + rpc_server_->IncreaseBatchBarrier(kRequestSend); + } else if (varname == COMPLETE_MESSAGE) { + VLOG(3) << "sync: recv complete message"; - if (distributed_mode_ == DistributedMode::kSync) { - return true; - } + if (HeartBeatMonitor::GetInstance() != nullptr) { + HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED); + } - HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING); + rpc_server_->Complete(); + } else { + // Async + if (distributed_mode_ != DistributedMode::kSync) { + VLOG(3) << "async process var: " << varname; + if (varname == BATCH_BARRIER_MESSAGE) { + PADDLE_THROW( + "async mode should not recv BATCH_BARRIER_MESSAGE or " + "COMPLETE_MESSAGE"); + } + HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING); - std::string run_varname = varname; - string::Piece part_piece("@PIECE"); - string::Piece var_name_piece = string::Piece(varname); + std::string run_varname = varname; - if (string::Contains(var_name_piece, part_piece)) { - auto varname_splits = paddle::string::Split(varname, '@'); - run_varname = varname_splits[0]; - scope->Rename(varname, run_varname); - } + string::Piece part_piece("@PIECE"); + string::Piece var_name_piece = string::Piece(varname); - if (distributed_mode_ == DistributedMode::kGeo && - AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) { - auto &grad_slr = - scope->FindVar(run_varname)->Get(); - AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname, - grad_slr.rows()); - } + if (string::Contains(var_name_piece, part_piece)) { + auto varname_splits = paddle::string::Split(varname, '@'); + PADDLE_ENFORCE_EQ(varname_splits.size(), 3); + run_varname = varname_splits[0]; + scope->Rename(varname, run_varname); + } - executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(), - scope); + if (distributed_mode_ == DistributedMode::kGeo && + AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) { + auto& grad_slr = + scope->FindVar(run_varname)->Get(); + AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname, + grad_slr.rows()); + } + executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(), + scope); + return true; + } else { // sync + rpc_server_->WaitCond(kRequestSend); + VLOG(3) << "sync: processing received var: " << varname; + PADDLE_ENFORCE_NOT_NULL( + invar, platform::errors::NotFound( + "sync: Can not find server side var %s.", varname)); + } + } return true; } -bool RequestGetHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, +bool RequestGetHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar, const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { + const std::string& out_var_name, + const std::string& table_name) { VLOG(3) << "RequestGetHandler:" << varname << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id << " table_name: " << table_name; if (distributed_mode_ == DistributedMode::kSync) { - *outvar = scope_->FindVar(varname); - } else { - if (enable_dc_asgd_) { - // NOTE: the format is determined by distribute_transpiler.py - std::string param_bak_name = - string::Sprintf("%s.trainer_%d_bak", varname, trainer_id); - VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id; - auto var = scope_->FindVar(varname); - auto t_orig = var->Get(); - auto param_bak = scope_->Var(param_bak_name); - auto t = param_bak->GetMutable(); - t->mutable_data(dev_ctx_->GetPlace(), t_orig.type()); - VLOG(3) << "copying " << varname << " to " << param_bak_name; - framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); + if (varname == FETCH_BARRIER_MESSAGE) { + VLOG(3) << "sync: recv fetch barrier message"; + rpc_server_->IncreaseBatchBarrier(kRequestGet); + } else { + rpc_server_->WaitCond(kRequestGet); + *outvar = scope_->FindVar(varname); } - - if (distributed_mode_ == DistributedMode::kGeo && - AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && - !table_name.empty()) { - std::vector updated_rows; - AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( - varname, trainer_id, &updated_rows); - if (VLOG_IS_ON(3)) { - std::ostringstream sstream; - sstream << "["; - for (auto &row_id : updated_rows) { - sstream << row_id << ", "; - } - sstream << "]"; - VLOG(3) << "updated_rows size: " << updated_rows.size() << " " - << sstream.str(); + } else { + if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) { + if (enable_dc_asgd_) { + // NOTE: the format is determined by distribute_transpiler.py + std::string param_bak_name = + string::Sprintf("%s.trainer_%d_bak", varname, trainer_id); + VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id; + auto var = scope_->FindVar(varname); + auto t_orig = var->Get(); + auto param_bak = scope_->Var(param_bak_name); + auto t = param_bak->GetMutable(); + t->mutable_data(dev_ctx_->GetPlace(), t_orig.type()); + VLOG(3) << "copying " << varname << " to " << param_bak_name; + framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); } - auto &origin_tensor = - scope_->FindVar(varname)->Get(); - auto *origin_tensor_data = origin_tensor.data(); - auto &dims = origin_tensor.dims(); - *outvar = scope->Var(); - auto *out_slr = (*outvar)->GetMutable(); - out_slr->set_rows(updated_rows); - out_slr->set_height(dims[0]); - auto out_dims = framework::make_ddim( - {static_cast(updated_rows.size()), dims[1]}); - auto *data = out_slr->mutable_value()->mutable_data( - out_dims, origin_tensor.place()); - auto width = dims[1]; - for (size_t i = 0; i < updated_rows.size(); ++i) { - PADDLE_ENFORCE_LT(updated_rows[i], dims[0], - platform::errors::OutOfRange( - "expected >= 0 and < %ld, but got %ld.", dims[0], - updated_rows[i])); - memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, - sizeof(float) * width); + VLOG(1) << "Table name empty? " << table_name.empty(); + if (distributed_mode_ == DistributedMode::kGeo) { + VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist " + << AsyncSparseParamUpdateRecorder::GetInstance()->HasParam( + varname); + } + if (distributed_mode_ == DistributedMode::kGeo && + AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && + !table_name.empty()) { + std::vector updated_rows; + AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( + varname, trainer_id, &updated_rows); + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& row_id : updated_rows) { + sstream << row_id << ", "; + } + sstream << "]"; + VLOG(3) << "updated_rows size: " << updated_rows.size() << " " + << sstream.str(); + } + auto& origin_tensor = + scope_->FindVar(varname)->Get(); + auto* origin_tensor_data = origin_tensor.data(); + auto& dims = origin_tensor.dims(); + *outvar = scope->Var(); + auto* out_slr = (*outvar)->GetMutable(); + out_slr->set_rows(updated_rows); + out_slr->set_height(dims[0]); + auto out_dims = framework::make_ddim( + {static_cast(updated_rows.size()), dims[1]}); + auto* data = out_slr->mutable_value()->mutable_data( + out_dims, origin_tensor.place()); + auto width = dims[1]; + for (size_t i = 0; i < updated_rows.size(); ++i) { + PADDLE_ENFORCE_LT(updated_rows[i], dims[0]); + memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, + sizeof(float) * width); + } + } else { + *outvar = scope_->FindVar(varname); } - } else { - *outvar = scope_->FindVar(varname); } } return true; } -bool RequestGetNoBarrierHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, +bool RequestGetNoBarrierHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar, const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { + const std::string& out_var_name, + const std::string& table_name) { VLOG(4) << "RequestGetNoBarrierHandler:" << varname << " out_var_name: " << out_var_name; @@ -177,19 +207,18 @@ bool RequestGetNoBarrierHandler::Handle(const std::string &varname, *outvar = scope_->FindVar(var_name_piece.ToString()); return true; } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE)); + PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE); } return true; } -bool RequestPrefetchHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, +bool RequestPrefetchHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar, const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { + const std::string& out_var_name, + const std::string& table_name) { VLOG(4) << "RequestPrefetchHandler " << varname; if (table_name.empty()) { @@ -207,20 +236,19 @@ bool RequestPrefetchHandler::Handle(const std::string &varname, return true; } -bool RequestCheckpointHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, +bool RequestCheckpointHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar, const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { - PADDLE_ENFORCE_NE( - checkpoint_notify_id, -1, - platform::errors::Unavailable( - "when checkpoint_notify_id = -1, there should be no RPC invoke.")); + const std::string& out_var_name, + const std::string& table_name) { + PADDLE_ENFORCE( + checkpoint_notify_id != -1, + "when checkpoint_notify_id = -1, there should be no RPC invoke."); // TODO(tangwei12): find out why scope will be error. - auto *lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); + auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: " @@ -229,56 +257,33 @@ bool RequestCheckpointHandler::Handle(const std::string &varname, return true; } -bool RequestNotifyHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, +bool RequestNotifyHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar, const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { + const std::string& out_var_name, + const std::string& table_name) { + VLOG(4) << "RequestNotifyHandler: " << varname; VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id; string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER); - string::Piece batch_piece(BATCH_BARRIER_MESSAGE); - string::Piece fetch_piece(FETCH_BARRIER_MESSAGE); - string::Piece complete_piece(COMPLETE_MESSAGE); - string::Piece var_name_piece = string::Piece(varname); - - if (string::Contains(var_name_piece, batch_piece)) { - return BarrierMonitor::GetInstance()->IncreaseBarrier( - trainer_id, BATCH_BARRIER_MESSAGE); - } else if (string::Contains(var_name_piece, fetch_piece)) { - return BarrierMonitor::GetInstance()->IncreaseBarrier( - trainer_id, FETCH_BARRIER_MESSAGE); - } else if (string::Contains(var_name_piece, complete_piece)) { - if (HeartBeatMonitor::GetInstance() != nullptr) { - HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED); - } - rpc_server_->Complete(); - BarrierMonitor::GetInstance()->DecreaseWorker(); - return true; - } else if (string::Contains(var_name_piece, decay_piece)) { + if (string::Contains(var_name_piece, decay_piece)) { VLOG(3) << "LearningRate Decay Counter Update"; PADDLE_ENFORCE_NE( lr_decay_block_id, -1, - platform::errors::InvalidArgument( - "when lr_decay_block_id = -1, there should be no RPC invoke.")); - auto *origin_var = scope_->FindVar(varname); + "when lr_decay_block_id = -1, there should be no RPC invoke."); + auto* origin_var = scope_->FindVar(varname); auto origin_var_tensor = origin_var->Get(); - auto *send_var = scope->FindVar(varname); + auto* send_var = scope->FindVar(varname); auto send_var_tensor = send_var->Get(); - int64_t *origin_value = + int64_t* origin_value = origin_var_tensor.mutable_data(origin_var_tensor.place()); - int64_t *send_value = + int64_t* send_value = send_var_tensor.mutable_data(send_var_tensor.place()); origin_value[0] += send_value[0]; executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_); - - return true; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "unkown varname %s with RequestNotifyHandler", varname)); } return true; } diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index bc17c84645116df7868107a6acf3de620dd9f798..d36a433db7dda89b5a9edb6fb8db8552ecce7854 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" @@ -117,7 +119,6 @@ void StartServer(const std::string& rpc_name) { g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); distributed::HeartBeatMonitor::Init(2, true, "w@grad"); - distributed::BarrierMonitor::Init(2); g_req_handler->SetRPCServer(g_rpc_service.get()); @@ -163,9 +164,6 @@ TEST(PREFETCH, CPU) { } } - auto* barrier = distributed::BarrierMonitor::GetInstance(); - barrier->Stop(); - g_rpc_service->ShutDown(); server_thread.join(); LOG(INFO) << "begin reset"; @@ -176,24 +174,20 @@ TEST(PREFETCH, CPU) { TEST(COMPLETE, CPU) { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); - g_req_handler.reset(new distributed::RequestNotifyHandler( - distributed::DistributedMode::kSync, -1)); + g_req_handler.reset( + new distributed::RequestSendHandler(distributed::DistributedMode::kSync)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2)); - distributed::RPCClient* client = distributed::RPCClient::GetInstance(0); PADDLE_ENFORCE(client != nullptr); - std::thread server_thread(StartServer, distributed::kRequestNotify); + std::thread server_thread(StartServer, distributed::kRequestSend); g_rpc_service->WaitServerReady(); int port = g_rpc_service->GetSelectedPort(); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); client->AsyncSendComplete(ep); client->Wait(); - auto* barrier = distributed::BarrierMonitor::GetInstance(); - EXPECT_EQ(barrier->GetWorkerNum(), 1); - - barrier->Stop(); + EXPECT_EQ(g_rpc_service->GetClientNum(), 1); g_rpc_service->ShutDown(); server_thread.join(); diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index 244d3ece48ecc201465a6badeb5cd44bbf71f4a8..79f14d75d279d0ae1a68bf857ab9f46d6b71c42f 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -2,9 +2,9 @@ include(operators) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node) else() - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node) if(WITH_BRPC_RDMA) find_library(IBVERBS_LIBRARY NAMES ibverbs) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) diff --git a/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc b/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc index cf8322905297156ba5e36c5b21e009739daa194f..e63f882478351cde16bde969b86e020181d6d4e5 100644 --- a/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/platform/nccl_helper.h" @@ -28,16 +30,16 @@ namespace operators { class GenNCCLIdOp : public framework::OperatorBase { public: - GenNCCLIdOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) + GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); // put nccl id in CPUPlace - auto &dev_ctx = *pool.Get(platform::CPUPlace()); + auto& dev_ctx = *pool.Get(platform::CPUPlace()); int trainer_id = Attr("trainer_id"); std::vector trainers = @@ -53,7 +55,7 @@ class GenNCCLIdOp : public framework::OperatorBase { std::string endpoint = trainers[trainer_id]; - framework::Scope &local_scope = scope.NewScope(); + framework::Scope& local_scope = scope.NewScope(); int nccl_comm_num = Attr("nccl_comm_num"); int use_hierarchical_allreduce = Attr("use_hierarchical_allreduce"); @@ -169,10 +171,10 @@ class GenNCCLIdOp : public framework::OperatorBase { } private: - void GenerateAndSend(framework::Scope *scope, - const platform::DeviceContext &dev_ctx, - const std::string &nccl_id_name, - const std::vector &endpoint_list) const { + void GenerateAndSend(framework::Scope* scope, + const platform::DeviceContext& dev_ctx, + const std::string& nccl_id_name, + const std::vector& endpoint_list) const { auto var = scope->FindVar(nccl_id_name); PADDLE_ENFORCE_NOT_NULL( var, platform::errors::NotFound("Variable with name %s is not found", @@ -180,96 +182,76 @@ class GenNCCLIdOp : public framework::OperatorBase { auto id = var->GetMutable(); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(id)); - distributed::RPCClient *client = + distributed::RPCClient* client = distributed::RPCClient::GetInstance(0); - for (auto &ep : endpoint_list) { + for (auto& ep : endpoint_list) { VLOG(3) << "sending nccl_id_var:" << nccl_id_name << " to " << ep; client->AsyncSendVar(ep, dev_ctx, *scope, nccl_id_name); } client->Wait(); - for (auto &ep : endpoint_list) { + for (auto& ep : endpoint_list) { client->AsyncSendBatchBarrier(ep); } client->Wait(); VLOG(3) << "sending completed..."; } - void GetIdByServer(const std::string &endpoint, framework::Scope *scope, - const platform::DeviceContext &dev_ctx, int nccl_comm_num, + void GetIdByServer(const std::string& endpoint, framework::Scope* scope, + const platform::DeviceContext& dev_ctx, int nccl_comm_num, bool use_hierarchical_allreduce, int trainer_id, int inter_trainer_id, int exter_trainer_id) const { // std::string endpoint = Attr("endpoint"); // NOTE: Can not use unique_ptr here because the default // deleter will call GRPC Server's base class's dtor and // that will cause a wired crash. - + distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync); std::unique_ptr rpc_service( new RPCSERVER_T(endpoint, 1)); - distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync); - - distributed::RequestNotifyHandler notify_h( - distributed::DistributedMode::kSync, -1); - rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); - rpc_service->RegisterRPC(distributed::kRequestNotify, ¬ify_h); + rpc_h.SetRPCServer(rpc_service.get()); framework::ProgramDesc empty_program; framework::Executor executor(dev_ctx.GetPlace()); - - rpc_h.SetRPCServer(rpc_service.get()); rpc_h.SetScope(scope); rpc_h.SetDevCtx(&dev_ctx); rpc_h.SetProgram(&empty_program); rpc_h.SetExecutor(&executor); - notify_h.SetRPCServer(rpc_service.get()); - notify_h.SetScope(scope); - notify_h.SetDevCtx(&dev_ctx); - notify_h.SetProgram(&empty_program); - notify_h.SetExecutor(&executor); - - distributed::BarrierMonitor::Init(1); - auto *barrier = distributed::BarrierMonitor::GetInstance(); - barrier->Reset(1, distributed::BarrierType::kSendBarrier); - std::thread server_thread( std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); for (int i = 0; i < nccl_comm_num; i++) { - barrier->WaitServerWeakup(); - barrier->Reset(1, distributed::BarrierType::kSendBarrier); - barrier->ServerWeakup(); - + rpc_service->SetCond(distributed::kRequestSend); VLOG(3) << "trainer_id:" << trainer_id << " start getting nccl id from trainer 0, nccl_comm_no:" << i; + rpc_service->WaitBarrier(distributed::kRequestSend); + rpc_service->ResetBarrierCounter(); } if (use_hierarchical_allreduce) { if (inter_trainer_id > 0) { for (int i = 0; i < nccl_comm_num; i++) { - barrier->WaitServerWeakup(); - barrier->Reset(1, distributed::BarrierType::kSendBarrier); - barrier->ServerWeakup(); - + rpc_service->SetCond(distributed::kRequestSend); VLOG(3) << "trainer_id:" << trainer_id << ", inter_trainer_id:" << inter_trainer_id << " start getting nccl id from inter_trainer:" << i; + rpc_service->WaitBarrier(distributed::kRequestSend); + rpc_service->ResetBarrierCounter(); } } if (exter_trainer_id > 0) { for (int i = 0; i < nccl_comm_num; i++) { - barrier->WaitServerWeakup(); - barrier->Reset(1, distributed::BarrierType::kSendBarrier); - barrier->ServerWeakup(); - + rpc_service->SetCond(distributed::kRequestSend); VLOG(3) << "trainer_id:" << trainer_id << ", exter_trainer_id:" << exter_trainer_id << " start getting nccl id from exter_trainer 0, nccl_comm_no:" << i; + rpc_service->WaitBarrier(distributed::kRequestSend); + rpc_service->ResetBarrierCounter(); } } } @@ -278,7 +260,6 @@ class GenNCCLIdOp : public framework::OperatorBase { << ", inter_trainer_id:" << inter_trainer_id << ", exter_trainer_id:" << exter_trainer_id << " got nccl id and stop server..."; - barrier->Stop(); rpc_service->ShutDown(); VLOG(3) << "rpc server stopped"; server_thread.join(); @@ -291,6 +272,7 @@ class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces."); AddComment(R"DOC( GenNCCLId operator + For trainer 0: generate a new UniqueId and send it to all the other trainers. For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. )DOC"); diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index c8c0316e74739622f46cb577ae051fc88dd39bb7..d40df6f9de0c1e22ea892993d66a2cdfa808b1c7 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -22,7 +25,6 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" -#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" @@ -36,13 +38,10 @@ DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch"); namespace paddle { namespace operators { -volatile sig_atomic_t gSignalStatus; - void RunServer(std::shared_ptr service) { service->StartServer(); VLOG(4) << "RunServer thread end"; } - static void split(const std::string &str, char sep, std::vector *pieces) { pieces->clear(); @@ -127,7 +126,6 @@ void ListenAndServOp::RunSyncLoop( for (size_t i = 1; i < program->Size(); ++i) { optimize_blocks_list.push_back(i); } - auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list); // Insert placeholder for block0 which holds current op itself, // NOTE the first block in `optimize_prepared` should never be ran. @@ -137,15 +135,21 @@ void ListenAndServOp::RunSyncLoop( // Trainers will get all parameters from pserver in the // startup program, so we will wait RequestGet first - auto *barrier = distributed::BarrierMonitor::GetInstance(); + rpc_service_->SetCond(distributed::kRequestGet); + rpc_service_->WaitBarrier(distributed::kRequestGet); + rpc_service_->ResetBarrierCounter(); while (true) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. - barrier->WaitServerWeakup(); + VLOG(3) << "wait all clients to send gradient"; + rpc_service_->SetCond(distributed::kRequestSend); + VLOG(3) << "wait all clients to send send_barrier"; + rpc_service_->WaitBarrier(distributed::kRequestSend); - if (gSignalStatus != 0) { + if (rpc_service_->IsExit()) { LOG(WARNING) << "get exit!rpc_processor break!"; + rpc_service_->SetCond(distributed::kRequestGet); break; } @@ -176,8 +180,12 @@ void ListenAndServOp::RunSyncLoop( VLOG(3) << "ResetReceivedVars"; ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars()); - barrier->ServerWeakup(); - VLOG(3) << "kRecvBarrier to push params to trainers"; + VLOG(3) << "wait all clients to get parameters back"; + rpc_service_->SetCond(distributed::kRequestGet); + VLOG(3) << "wait all clients to send fetch_barrier"; + rpc_service_->WaitBarrier(distributed::kRequestGet); + VLOG(3) << "ResetBarrierCounter"; + rpc_service_->ResetBarrierCounter(); } // while(true) } @@ -273,7 +281,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); while (true) { - if (gSignalStatus != 0) { + if (rpc_service_->IsExit()) { VLOG(4) << "get exit!rpc_processor break!"; break; } @@ -383,7 +391,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, request_get_no_barrier_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestNotify, - request_notify_handler_.get(), fan_in * 2); + request_notify_handler_.get(), rpc_send_thread_num); auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -432,7 +440,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, std::unordered_map> prefetch_var_name_to_prepared_ctx; - for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) { auto block_id = prefetch_block_id_list[i]; auto prefetch_var_name = block_id_to_prefetch_var_name[block_id]; @@ -441,10 +448,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, // parse attr of kSparseGradToParam sparse_grad_name -> param_name std::unordered_map sparse_grad_name_to_param_name; - auto sparse_grad_name_to_param_name_str = Attr>(kSparseGradToParam); - for (const auto &sparse_grad_name_and_param_name : sparse_grad_name_to_param_name_str) { std::vector pieces; @@ -472,18 +477,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, signal(SIGINT, SignalHandler::StopAndExit); signal(SIGTERM, SignalHandler::StopAndExit); - distributed::BarrierMonitor::Init(fan_in); - if (distributed_mode == distributed::DistributedMode::kSync) { // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); VLOG(3) << "wait server thread to become ready..."; rpc_service_->WaitServerReady(); - // Write to a file of server selected port for python use. - SavePort(); CacheVarsType(inputs, recv_scope); + // Write to a file of server selected port for python use. + SavePort(); + RunSyncLoop(&executor, program, &recv_scope, &dev_ctx, prefetch_block_id_list, checkpoint_block_id); } else { @@ -570,8 +574,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { void SignalHandler::StopAndExit(int signal_num) { // Do not use VLOG here for the device for printing maybe already released. // exit will release interal allocated resoureces. - distributed::BarrierMonitor::GetInstance()->Stop(); - gSignalStatus = signal_num; + auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); + remove(file_path.c_str()); + exit(0); } } // namespace operators diff --git a/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc b/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc index 3f5fdadc22342bf17f54d86e39bdc5114915c001..b65621a0886b02fd8d3c029c979348469014cadc 100644 --- a/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc +++ b/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,7 +20,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/distributed/barrier_monitor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" @@ -40,7 +42,6 @@ namespace string = paddle::string; std::unique_ptr g_rpc_service; std::unique_ptr g_req_handler; -std::unique_ptr g_notify_handler; void StartServer() { f::Scope scope; @@ -51,35 +52,21 @@ void StartServer() { f::ProgramDesc empty_program; f::Executor executor(dev_ctx.GetPlace()); - g_req_handler->SetScope(&scope); g_req_handler->SetDevCtx(&dev_ctx); g_req_handler->SetProgram(&empty_program); g_req_handler->SetExecutor(&executor); - g_req_handler->SetRPCServer(g_rpc_service.get()); - - g_notify_handler.SetRPCServer(rpc_service.get()); - g_notify_handler.SetScope(scope); - g_notify_handler.SetDevCtx(&dev_ctx); - g_notify_handler.SetProgram(&empty_program); - g_notify_handler.SetExecutor(&executor); g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get()); - g_rpc_service->RegisterRPC(distributed::RequestNotifyHandler, - g_notify_handler.get()); - - distributed::BarrierMonitor::Init(1); - auto* barrier = distributed::BarrierMonitor::GetInstance(); - barrier->Reset(1, distributed::BarrierType::kSendBarrier); + g_req_handler->SetRPCServer(g_rpc_service.get()); std::thread server_thread( std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); - barrier->WaitServerWeakup(); - barrier->ServerWeakup(); + g_rpc_service->SetCond(distributed::kRequestSend); + g_rpc_service->WaitBarrier(distributed::kRequestSend); LOG(INFO) << "got nccl id and stop server..."; - barrier->Stop(); g_rpc_service->ShutDown(); server_thread.join(); } @@ -87,10 +74,6 @@ void StartServer() { TEST(SendNcclId, RPCServer) { g_req_handler.reset( new distributed::RequestSendHandler(distributed::DistributedMode::kSync)); - - g_notify_handler.reset(new distributed::RequestNotifyHandler( - distributed::DistributedMode::kSync, -1)); - g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); std::thread server_thread(StartServer); @@ -121,5 +104,4 @@ TEST(SendNcclId, RPCServer) { server_thread.join(); g_rpc_service.reset(nullptr); g_req_handler.reset(nullptr); - g_notify_handler.reset(nullptr); } diff --git a/python/paddle/fluid/tests/unittests/test_communicator_half_async.py b/python/paddle/fluid/tests/unittests/test_communicator_half_async.py index 31b476eac0566f962ea452bf1fde62f5cb3c5169..8a7904db95f7a1b8088197fdf16969e1ccfefae2 100644 --- a/python/paddle/fluid/tests/unittests/test_communicator_half_async.py +++ b/python/paddle/fluid/tests/unittests/test_communicator_half_async.py @@ -94,7 +94,7 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase): current_id=0, role=role_maker.Role.WORKER if training_role == "TRAINER" else role_maker.Role.SERVER, - worker_num=1, + worker_num=2, server_endpoints=["127.0.0.1:6002"]) if training_role == "TRAINER": diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 381c8146431c73b63e34c8805145627981fc239a..ba292f2d87c376ace317fc3fb9b81ce5c5596eb2 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -937,11 +937,6 @@ class TestDistBase(unittest.TestCase): need_envs={}, log_name=""): - print( - "disable distributed unittests temporary, will enable it soon. (tangwei)" - ) - return - required_envs = self._get_required_envs(check_error_log, need_envs) local_losses \ @@ -982,11 +977,6 @@ class TestDistBase(unittest.TestCase): need_envs={}, log_name=""): - print( - "disable distributed unittests temporary, will enable it soon. (tangwei)" - ) - return - # need open p2p or shm otherwise multi cards mode will hang need_envs.update({"NCCL_P2P_DISABLE": "0", "NCCL_SHM_DISABLE": "0"}) diff --git a/python/paddle/fluid/tests/unittests/test_fl_listen_and_serv_op.py b/python/paddle/fluid/tests/unittests/test_fl_listen_and_serv_op.py deleted file mode 100644 index de6b48e2cec602f8f73d7cf5f3f9b1fc66d55be6..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/test_fl_listen_and_serv_op.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""test f1 listen and serv_op.""" - -from __future__ import print_function - -import paddle -import paddle.fluid as fluid -from paddle.fluid import Program -import os -import signal -import subprocess -import time -import unittest -from multiprocessing import Process -from op_test import OpTest -import numpy -import urllib -import sys -from dist_test_utils import * - -cache_path = os.path.expanduser('~/.cache/paddle/dataset') - - -def run_trainer(use_cuda, sync_mode, ip, port, trainers, trainer_id): - ''' - This function is run trainer. - Args: - use_cuda (bool): whether use cuda. - sync_mode (nouse): specify sync mode. - ip (string): the ip address. - port (string): the port for listening. - trainers (int): the count of trainer. - trainer_id (int): the id of trainer. - - Returns: - None - ''' - x = fluid.layers.data(name='x', shape=[1], dtype='float32') - y_predict = fluid.layers.fc(input=x, size=1, act=None) - y = fluid.layers.data(name='y', shape=[1], dtype='float32') - # loss function - cost = fluid.layers.square_error_cost(input=y_predict, label=y) - avg_cost = fluid.layers.mean(cost) - # optimizer - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - sgd_optimizer.minimize(avg_cost) - with open("{}/trainer_recv_program.dms".format(cache_path), "rb") as f: - trainer_recv_program_desc_str = f.read() - with open("{}/trainer_main_program.dms".format(cache_path), "rb") as f: - trainer_main_program_desc_str = f.read() - with open("{}/trainer_send_program.dms".format(cache_path), "rb") as f: - trainer_send_program_desc_str = f.read() - recv_program = Program.parse_from_string(trainer_recv_program_desc_str) - main_program = Program.parse_from_string(trainer_main_program_desc_str) - send_program = Program.parse_from_string(trainer_send_program_desc_str) - - trainer_startup_program = fluid.default_startup_program() - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - - exe.run(trainer_startup_program) - for i in range(5): - exe.run(recv_program) - exe.run(fluid.default_main_program(), - feed={ - "x": numpy.array([1, 2]).astype('float32').reshape(2, 1), - "y": numpy.array([2, 3]).astype('float32').reshape(2, 1) - }) - exe.run(send_program) - - -def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id): - ''' - This function is run trainer. - Args: - use_cuda (bool): whether use cuda. - sync_mode (nouse): specify sync mode. - ip (string): the ip address. - port (string): the port for listening. - trainers (int): the count of trainer. - trainer_id (int): the id of trainer. - - Returns: - None - ''' - remove_ps_flag(os.getpid()) - x = fluid.layers.data(name='x', shape=[1], dtype='float32') - y_predict = fluid.layers.fc(input=x, size=1, act=None) - y = fluid.layers.data(name='y', shape=[1], dtype='float32') - # loss function - cost = fluid.layers.square_error_cost(input=y_predict, label=y) - avg_cost = fluid.layers.mean(cost) - # optimizer - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - sgd_optimizer.minimize(avg_cost) - with open("{}/pserver_startup_program.dms".format(cache_path), "rb") as f: - pserver_startup_program_desc_str = f.read() - with open("{}/pserver_main_program.dms".format(cache_path), "rb") as f: - pserver_main_program_desc_str = f.read() - - startup_program = Program.parse_from_string( - pserver_startup_program_desc_str) - main_program = Program.parse_from_string(pserver_main_program_desc_str) - - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(startup_program) - exe.run(main_program) - - -class TestFlListenAndServOp(unittest.TestCase): - """This class is Test Fl Listen And ServOp.""" - - def setUp(self): - """This function si set Up.""" - self.ps_timeout = 5 - self.ip = "127.0.0.1" - self.port = "6000" - self.trainers = 2 - self.trainer_id = 0 - - def _start_pserver(self, use_cuda, sync_mode, pserver_func): - """This function is start pserver.""" - p = Process( - target=pserver_func, - args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, - self.trainer_id)) - p.daemon = True - p.start() - return p - - def _start_trainer0(self, use_cuda, sync_mode, pserver_func): - """This function is start trainer0.""" - p = Process( - target=pserver_func, - args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, 0)) - p.daemon = True - p.start() - return p - - def _start_trainer1(self, use_cuda, sync_mode, pserver_func): - """This function is start trainer1.""" - p = Process( - target=pserver_func, - args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, 1)) - p.daemon = True - p.start() - return p - - def _wait_ps_ready(self, pid): - """This function is wait ps ready.""" - start_left_time = self.ps_timeout - sleep_time = 0.5 - while True: - assert start_left_time >= 0, "wait ps ready failed" - time.sleep(sleep_time) - try: - os.stat("/tmp/paddle.%d.port" % pid) - return - except os.error: - start_left_time -= sleep_time - - def test_rpc_interfaces(self): - """TODO(Yancey1989): need to make sure the rpc interface correctly.""" - # TODO(Yancey1989): need to make sure the rpc interface correctly. - pass - - def test_handle_signal_in_serv_op(self): - """run pserver on CPU in sync mode.""" - # run pserver on CPU in sync mode - if sys.platform == 'win32' or sys.platform == 'sys.platform': - pass - else: - print(sys.platform) - file_list = [ - 'pserver_startup_program.dms', 'pserver_main_program.dms', - 'trainer_recv_program.dms', 'trainer_main_program.dms', - 'trainer_send_program.dms' - ] - if not os.path.exists(cache_path): - os.makedirs(cache_path) - prefix = 'wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/' - for f in file_list: - if not os.path.exists('{}/{}'.format(cache_path, f)): - cmd = "wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/{} -P {}/".format( - f, cache_path) - os.system(cmd) - p1 = self._start_pserver(False, True, run_pserver) - self._wait_ps_ready(p1.pid) - time.sleep(5) - t1 = self._start_trainer0(False, True, run_trainer) - time.sleep(2) - t2 = self._start_trainer1(False, True, run_trainer) - # raise SIGTERM to pserver - time.sleep(2) - cmd_del = "rm trainer*dms* pserver*dms*" - os.system(cmd_del) - os.kill(p1.pid, signal.SIGINT) - p1.join() - os.kill(t1.pid, signal.SIGINT) - t1.join() - os.kill(t2.pid, signal.SIGINT) - t2.join() - - -if __name__ == '__main__': - unittest.main()