From 84220765a73e68c1a817fb1fd3c7806814a83e7e Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Fri, 25 Jan 2019 14:48:58 +0800
Subject: [PATCH] refine code, add more log

---
 .../distributed/request_handler_impl.cc       | 37 ++++++++++++-------
 .../fluid/operators/distributed/rpc_server.cc |  9 +++--
 2 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc
index 913ae76b3..2ed468331 100644
--- a/paddle/fluid/operators/distributed/request_handler_impl.cc
+++ b/paddle/fluid/operators/distributed/request_handler_impl.cc
@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
   // Async
   if (!sync_mode_) {
     VLOG(3) << "async process var: " << varname;
+    if (varname == BATCH_BARRIER_MESSAGE || varname == COMPLETE_MESSAGE) {
+      PADDLE_THROW(
+          "async mode should not recv BATCH_BARRIER_MESSAGE or "
+          "COMPLETE_MESSAGE");
+    }
     try {
       executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
                                     scope);
@@ -95,21 +100,25 @@ bool RequestGetHandler::Handle(const std::string& varname,
     }
   } else {
     if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
-      if (enable_dc_asgd_) {
-        // NOTE: the format is determined by distributed_transpiler.py
-        std::string param_bak_name =
-            string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
-        VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
-        auto var = scope_->FindVar(varname);
-        auto t_orig = var->Get<framework::LoDTensor>();
-        auto param_bak = scope_->Var(param_bak_name);
-        auto t = param_bak->GetMutable<framework::LoDTensor>();
-        t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
-        VLOG(3) << "copying " << varname << " to " << param_bak_name;
-        framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
-      }
-      *outvar = scope_->FindVar(varname);
+      PADDLE_THROW(
+          "async mode should not send FETCH_BARRIER_MESSAGE or "
+          "COMPLETE_MESSAGE");
+    }
+
+    if (enable_dc_asgd_) {
+      // NOTE: the format is determined by distributed_transpiler.py
+      std::string param_bak_name =
+          string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
+      VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
+      auto var = scope_->FindVar(varname);
+      auto t_orig = var->Get<framework::LoDTensor>();
+      auto param_bak = scope_->Var(param_bak_name);
+      auto t = param_bak->GetMutable<framework::LoDTensor>();
+      t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
+      VLOG(3) << "copying " << varname << " to " << param_bak_name;
+      framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
     }
+    *outvar = scope_->FindVar(varname);
   }
   return true;
 }
diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc
index baf6b73b0..90733fd09 100644
--- a/paddle/fluid/operators/distributed/rpc_server.cc
+++ b/paddle/fluid/operators/distributed/rpc_server.cc
@@ -43,15 +43,15 @@ void RPCServer::SavePort() const {
 }
 
 void RPCServer::WaitBarrier(const std::string& rpc_name) {
-  VLOG(3) << "WaitBarrier: " << rpc_name;
+  VLOG(3) << "WaitBarrier in: " << rpc_name;
   std::unique_lock<std::mutex> lock(this->mutex_);
   barrier_cond_.wait(lock, [this, &rpc_name] {
     return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
             exit_flag_.load());
   });
 
-  VLOG(3) << "batch_barrier_: " << rpc_name << " "
-          << barrier_counter_[rpc_name];
+  VLOG(3) << "WaitBarrier out: " << rpc_name
+          << " counter: " << barrier_counter_[rpc_name];
 }
 
 void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
@@ -59,8 +59,11 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
   int b = 0;
   std::unique_lock<std::mutex> lock(mutex_);
   b = ++barrier_counter_[rpc_name];
+  VLOG(3) << rpc_name << " barrier_counter: " << b;
   if (b >= client_num_) {
     lock.unlock();
+    VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for "
+            << rpc_name;
     barrier_cond_.notify_all();
     lock.lock();
   }
--
GitLab
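
For reference, the barrier that the new log lines instrument is a counting barrier built from a mutex and a condition variable: IncreaseBatchBarrier bumps a per-RPC counter and wakes all waiters once every client has checked in, while WaitBarrier blocks until the counter reaches client_num_. Below is a minimal standalone C++ sketch of that pattern, not Paddle's actual RPCServer; ToyBarrier and its method names are hypothetical stand-ins.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

class ToyBarrier {
 public:
  explicit ToyBarrier(int client_num) : client_num_(client_num) {}

  // Analogue of RPCServer::IncreaseBatchBarrier: bump the counter for
  // rpc_name and wake all waiters once every client has checked in.
  void Increase(const std::string& rpc_name) {
    std::unique_lock<std::mutex> lock(mutex_);
    int b = ++counter_[rpc_name];
    std::cout << rpc_name << " barrier_counter: " << b << "\n";
    if (b >= client_num_) {
      cond_.notify_all();
    }
  }

  // Analogue of RPCServer::WaitBarrier: block until all clients arrive.
  void Wait(const std::string& rpc_name) {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this, &rpc_name] {
      return counter_[rpc_name] >= client_num_;
    });
  }

 private:
  int client_num_;
  std::mutex mutex_;
  std::condition_variable cond_;
  std::unordered_map<std::string, int> counter_;  // per-RPC arrival count
};

int main() {
  const int client_num = 4;
  ToyBarrier barrier(client_num);

  // Simulate client_num trainers each sending a batch barrier.
  std::vector<std::thread> clients;
  for (int i = 0; i < client_num; ++i) {
    clients.emplace_back([&barrier] { barrier.Increase("RequestSend"); });
  }
  barrier.Wait("RequestSend");  // returns only after all clients arrive
  for (auto& t : clients) t.join();
  std::cout << "barrier released\n";
  return 0;
}

One difference worth noting: the patched IncreaseBatchBarrier drops the lock before calling notify_all() and re-acquires it afterward, so woken waiters do not immediately block on a mutex the notifier still holds; the sketch keeps the simpler notify-under-lock form, which is also correct.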