提交 84220765 编写于 作者: Q Qiao Longfei

refine code, add more log

上级 c750be6d
......@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Async
if (!sync_mode_) {
VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE || varname == COMPLETE_MESSAGE) {
PADDLE_THROW(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE");
}
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
......@@ -95,21 +100,25 @@ bool RequestGetHandler::Handle(const std::string& varname,
}
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
if (enable_dc_asgd_) {
// NOTE: the format is determined by distributed_transpiler.py
std::string param_bak_name =
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
auto var = scope_->FindVar(varname);
auto t_orig = var->Get<framework::LoDTensor>();
auto param_bak = scope_->Var(param_bak_name);
auto t = param_bak->GetMutable<framework::LoDTensor>();
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
*outvar = scope_->FindVar(varname);
PADDLE_THROW(
"async mode should not send FETCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE");
}
if (enable_dc_asgd_) {
// NOTE: the format is determined by distributed_transpiler.py
std::string param_bak_name =
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
auto var = scope_->FindVar(varname);
auto t_orig = var->Get<framework::LoDTensor>();
auto param_bak = scope_->Var(param_bak_name);
auto t = param_bak->GetMutable<framework::LoDTensor>();
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
*outvar = scope_->FindVar(varname);
}
return true;
}
......
......@@ -43,15 +43,15 @@ void RPCServer::SavePort() const {
}
void RPCServer::WaitBarrier(const std::string& rpc_name) {
VLOG(3) << "WaitBarrier: " << rpc_name;
VLOG(3) << "WaitBarrier in: " << rpc_name;
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [this, &rpc_name] {
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
exit_flag_.load());
});
VLOG(3) << "batch_barrier_: " << rpc_name << " "
<< barrier_counter_[rpc_name];
VLOG(3) << "WaitBarrier out: " << rpc_name
<< " counter: " << barrier_counter_[rpc_name];
}
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
......@@ -59,8 +59,11 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
int b = 0;
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
VLOG(3) << rpc_name << " barrier_counter: " << b;
if (b >= client_num_) {
lock.unlock();
VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for "
<< rpc_name;
barrier_cond_.notify_all();
lock.lock();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册