未验证 提交 3206b179 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #13224 from jacquesqiao/cherry-pick-async-handle-complete-message

fix async mode handle COMPLETE_MESSAGE (#13212)
...@@ -39,19 +39,6 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -39,19 +39,6 @@ bool RequestSendHandler::Handle(const std::string& varname,
const std::string& out_var_name) { const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname; VLOG(4) << "RequestSendHandler:" << varname;
// Async
if (!sync_mode_) {
rpc_server_->Profiler().OneStep();
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
} catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true;
}
// Sync // Sync
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
...@@ -60,17 +47,31 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -60,17 +47,31 @@ bool RequestSendHandler::Handle(const std::string& varname,
VLOG(3) << "sync: recv complete message"; VLOG(3) << "sync: recv complete message";
rpc_server_->Complete(); rpc_server_->Complete();
} else { } else {
VLOG(3) << "sync: received var_name: " << varname; // Async
rpc_server_->WaitCond(kRequestSend); if (!sync_mode_) {
VLOG(3) << "sync: processing received var: " << varname; VLOG(3) << "async process var: " << varname;
rpc_server_->Profiler().OneStep();
if (invar == nullptr) { try {
LOG(FATAL) << "sync: Can not find server side var: " << varname; executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
return false; scope);
} } catch (std::exception& e) {
if (invar->IsType<framework::SelectedRows>()) { LOG(ERROR) << "async: run sub program error " << e.what();
std::unique_lock<std::mutex> lock(mutex_sparse_vars_); return false;
sparse_vars_.push_back(invar); }
return true;
} else { // sync
rpc_server_->WaitCond(kRequestSend);
VLOG(3) << "sync: processing received var: " << varname;
if (invar == nullptr) {
LOG(FATAL) << "sync: Can not find server side var: " << varname;
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
}
} }
} }
return true; return true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册