提交 56964946 编写于 作者: Y Yancey1989

polish sparse update logic

上级 259e63d4
...@@ -64,13 +64,21 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -64,13 +64,21 @@ bool RequestSendHandler::Handle(const std::string& varname,
return false; return false;
} }
if (invar->IsType<framework::SelectedRows>()) { if (invar->IsType<framework::SelectedRows>()) {
rpc_server_->RecordSparseVar(invar); std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
} }
} }
return true; return true;
} }
void RequestSendHandler::ResetSparseVarRecorder() {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
bool RequestGetHandler::Handle(const std::string& varname, bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
......
...@@ -41,6 +41,11 @@ class RequestSendHandler final : public RequestHandler { ...@@ -41,6 +41,11 @@ class RequestSendHandler final : public RequestHandler {
virtual ~RequestSendHandler() {} virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override; framework::Variable* var, framework::Variable** outvar) override;
void ResetSparseVarRecorder();
private:
std::mutex mutex_sparse_vars_;
std::vector<framework::Variable*> sparse_vars_;
}; };
class RequestGetHandler final : public RequestHandler { class RequestGetHandler final : public RequestHandler {
......
...@@ -73,19 +73,6 @@ void RPCServer::ResetBarrierCounter() { ...@@ -73,19 +73,6 @@ void RPCServer::ResetBarrierCounter() {
t.second = 0; t.second = 0;
} }
} }
void RPCServer::RecordSparseVar(framework::Variable* sparse_var) {
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
sparse_vars_.push_back(sparse_var);
}
void RPCServer::ResetSparseVarsRecorder() {
VLOG(3) << "RPCServer reset sparse vars recorder.";
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
void RPCServer::RegisterRPC(const std::string& rpc_name, void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) { RequestHandler* handler, int thread_num) {
......
...@@ -62,8 +62,6 @@ class RPCServer { ...@@ -62,8 +62,6 @@ class RPCServer {
void IncreaseBatchBarrier(const std::string rpc_name); void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter(); void ResetBarrierCounter();
void RecordSparseVar(framework::Variable* sparse_var);
void ResetSparseVarsRecorder();
protected: protected:
virtual void ShutDownImpl() = 0; virtual void ShutDownImpl() = 0;
...@@ -77,9 +75,6 @@ class RPCServer { ...@@ -77,9 +75,6 @@ class RPCServer {
std::atomic<int> cur_cond_; std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_; std::condition_variable rpc_cond_;
std::vector<framework::Variable*> sparse_vars_;
std::mutex mutex_sparse_var_recorder_;
protected: protected:
std::string bind_address_; std::string bind_address_;
std::atomic<int> exit_flag_; std::atomic<int> exit_flag_;
......
...@@ -146,7 +146,9 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -146,7 +146,9 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
rpc_service_->ResetSparseVarsRecorder(); // reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<detail::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder();
} // while(true) } // while(true)
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册