diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc
index 913ae76b38dc663d6fb4102f795ac713fd8a6bdf..2ed468331899e0e81d0285ad12937c31e5e50075 100644
--- a/paddle/fluid/operators/distributed/request_handler_impl.cc
+++ b/paddle/fluid/operators/distributed/request_handler_impl.cc
@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
   // Async
   if (!sync_mode_) {
     VLOG(3) << "async process var: " << varname;
+    if (varname == BATCH_BARRIER_MESSAGE || varname == COMPLETE_MESSAGE) {
+      PADDLE_THROW(
+          "async mode should not recv BATCH_BARRIER_MESSAGE or "
+          "COMPLETE_MESSAGE");
+    }
     try {
       executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
                                     scope);
@@ -95,21 +100,25 @@ bool RequestGetHandler::Handle(const std::string& varname,
     }
   } else {
-    if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
-      if (enable_dc_asgd_) {
-        // NOTE: the format is determined by distributed_transpiler.py
-        std::string param_bak_name =
-            string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
-        VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
-        auto var = scope_->FindVar(varname);
-        auto t_orig = var->Get<framework::LoDTensor>();
-        auto param_bak = scope_->Var(param_bak_name);
-        auto t = param_bak->GetMutable<framework::LoDTensor>();
-        t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
-        VLOG(3) << "copying " << varname << " to " << param_bak_name;
-        framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
-      }
-      *outvar = scope_->FindVar(varname);
+    if (varname == FETCH_BARRIER_MESSAGE || varname == COMPLETE_MESSAGE) {
+      PADDLE_THROW(
+          "async mode should not send FETCH_BARRIER_MESSAGE or "
+          "COMPLETE_MESSAGE");
+    }
+
+    if (enable_dc_asgd_) {
+      // NOTE: the format is determined by distributed_transpiler.py
+      std::string param_bak_name =
+          string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
+      VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
+      auto var = scope_->FindVar(varname);
+      auto t_orig = var->Get<framework::LoDTensor>();
+      auto param_bak = scope_->Var(param_bak_name);
+      auto t = param_bak->GetMutable<framework::LoDTensor>();
+      t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
+      VLOG(3) << "copying " << varname << " to " << param_bak_name;
+      framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
     }
+    *outvar = scope_->FindVar(varname);
   }
   return true;
 }
diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc
index baf6b73b0debea57de670e3033c2d411f5759c79..90733fd09096bb1a4767fdd79dc3e35fdf499cab 100644
--- a/paddle/fluid/operators/distributed/rpc_server.cc
+++ b/paddle/fluid/operators/distributed/rpc_server.cc
@@ -43,15 +43,15 @@ void RPCServer::SavePort() const {
 }
 
 void RPCServer::WaitBarrier(const std::string& rpc_name) {
-  VLOG(3) << "WaitBarrier: " << rpc_name;
+  VLOG(3) << "WaitBarrier in: " << rpc_name;
   std::unique_lock<std::mutex> lock(this->mutex_);
   barrier_cond_.wait(lock, [this, &rpc_name] {
     return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
             exit_flag_.load());
   });
 
-  VLOG(3) << "batch_barrier_: " << rpc_name << " "
-          << barrier_counter_[rpc_name];
+  VLOG(3) << "WaitBarrier out: " << rpc_name
+          << " counter: " << barrier_counter_[rpc_name];
 }
 
 void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
@@ -59,8 +59,11 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
   int b = 0;
   std::unique_lock<std::mutex> lock(mutex_);
   b = ++barrier_counter_[rpc_name];
+  VLOG(3) << rpc_name << " barrier_counter: " << b;
   if (b >= client_num_) {
     lock.unlock();
+    VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for "
+            << rpc_name;
     barrier_cond_.notify_all();
     lock.lock();
   }