提交 84220765 编写于 作者: Q Qiao Longfei

Refine code and add more logging.

上级 c750be6d
...@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Async // Async
if (!sync_mode_) { if (!sync_mode_) {
VLOG(3) << "async process var: " << varname; VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE || varname == COMPLETE_MESSAGE) {
PADDLE_THROW(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE");
}
try { try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope); scope);
...@@ -95,21 +100,25 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -95,21 +100,25 @@ bool RequestGetHandler::Handle(const std::string& varname,
} }
} else { } else {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) { if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
if (enable_dc_asgd_) { PADDLE_THROW(
// NOTE: the format is determined by distributed_transpiler.py "async mode should not send FETCH_BARRIER_MESSAGE or "
std::string param_bak_name = "COMPLETE_MESSAGE");
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id); }
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
auto var = scope_->FindVar(varname); if (enable_dc_asgd_) {
auto t_orig = var->Get<framework::LoDTensor>(); // NOTE: the format is determined by distributed_transpiler.py
auto param_bak = scope_->Var(param_bak_name); std::string param_bak_name =
auto t = param_bak->GetMutable<framework::LoDTensor>(); string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type()); VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
VLOG(3) << "copying " << varname << " to " << param_bak_name; auto var = scope_->FindVar(varname);
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); auto t_orig = var->Get<framework::LoDTensor>();
} auto param_bak = scope_->Var(param_bak_name);
*outvar = scope_->FindVar(varname); auto t = param_bak->GetMutable<framework::LoDTensor>();
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
} }
*outvar = scope_->FindVar(varname);
} }
return true; return true;
} }
......
...@@ -43,15 +43,15 @@ void RPCServer::SavePort() const { ...@@ -43,15 +43,15 @@ void RPCServer::SavePort() const {
} }
void RPCServer::WaitBarrier(const std::string& rpc_name) { void RPCServer::WaitBarrier(const std::string& rpc_name) {
VLOG(3) << "WaitBarrier: " << rpc_name; VLOG(3) << "WaitBarrier in: " << rpc_name;
std::unique_lock<std::mutex> lock(this->mutex_); std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [this, &rpc_name] { barrier_cond_.wait(lock, [this, &rpc_name] {
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) || return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
exit_flag_.load()); exit_flag_.load());
}); });
VLOG(3) << "batch_barrier_: " << rpc_name << " " VLOG(3) << "WaitBarrier out: " << rpc_name
<< barrier_counter_[rpc_name]; << " counter: " << barrier_counter_[rpc_name];
} }
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
...@@ -59,8 +59,11 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { ...@@ -59,8 +59,11 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
int b = 0; int b = 0;
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name]; b = ++barrier_counter_[rpc_name];
VLOG(3) << rpc_name << " barrier_counter: " << b;
if (b >= client_num_) { if (b >= client_num_) {
lock.unlock(); lock.unlock();
VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for "
<< rpc_name;
barrier_cond_.notify_all(); barrier_cond_.notify_all();
lock.lock(); lock.lock();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册