提交 1239fce7 编写于 作者: Y Yancey1989

polish sparse update code

上级 e0895e49
......@@ -63,6 +63,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
PADDLE_THROW("sync: Can not find server side var");
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
rpc_server_->RecordSparseVar(invar);
}
}
return true;
......
......@@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() {
t.second = 0;
}
}
void RPCServer::RecordSparseVar(framework::Variable* sparse_var) {
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
sparse_vars_.push_back(sparse_var);
}
void RPCServer::ResetSparseVarsRecorder() {
VLOG(3) << "RPCServer reset sparse vars recorder.";
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) {
......
......@@ -60,7 +60,10 @@ class RPCServer {
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter();
void RecordSparseVar(framework::Variable* sparse_var);
void ResetSparseVarsRecorder();
protected:
virtual void ShutDownImpl() = 0;
......@@ -74,6 +77,9 @@ class RPCServer {
std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_;
std::vector<framework::Variable*> sparse_vars_;
std::mutex mutex_sparse_var_recorder_;
protected:
std::string bind_address_;
std::atomic<int> exit_flag_;
......
......@@ -146,6 +146,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter();
rpc_service_->ResetSparseVarsRecorder();
} // while(true)
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册