未验证 提交 259e63d4 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #11248 from panyx0718/dist

Fix sparse vars usage for dist train
...@@ -80,7 +80,6 @@ class RequestHandler { ...@@ -80,7 +80,6 @@ class RequestHandler {
} }
framework::ProgramDesc* program() { return program_; } framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; } framework::Executor* executor() { return executor_; }
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
// This function processes user's rpc request. // This function processes user's rpc request.
// The implemention is in request_handler_impl. // The implemention is in request_handler_impl.
...@@ -113,13 +112,7 @@ class RequestHandler { ...@@ -113,13 +112,7 @@ class RequestHandler {
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>* std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_; grad_to_prepared_ctx_;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable*> sparse_vars_;
RPCServer* rpc_server_; RPCServer* rpc_server_;
std::mutex sparse_var_mutex_;
}; };
} // namespace detail } // namespace detail
......
...@@ -63,10 +63,8 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -63,10 +63,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
PADDLE_THROW("sync: Can not find server side var"); PADDLE_THROW("sync: Can not find server side var");
return false; return false;
} }
if (invar->IsType<framework::SelectedRows>()) { if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(sparse_var_mutex_); rpc_server_->RecordSparseVar(invar);
sparse_vars_.push_back(invar);
} }
} }
......
...@@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() { ...@@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() {
t.second = 0; t.second = 0;
} }
} }
void RPCServer::RecordSparseVar(framework::Variable* sparse_var) {
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
sparse_vars_.push_back(sparse_var);
}
void RPCServer::ResetSparseVarsRecorder() {
VLOG(3) << "RPCServer reset sparse vars recorder.";
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
void RPCServer::RegisterRPC(const std::string& rpc_name, void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) { RequestHandler* handler, int thread_num) {
......
...@@ -60,7 +60,10 @@ class RPCServer { ...@@ -60,7 +60,10 @@ class RPCServer {
void SetCond(const std::string& rpc_name); void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name); void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter(); void ResetBarrierCounter();
void RecordSparseVar(framework::Variable* sparse_var);
void ResetSparseVarsRecorder();
protected: protected:
virtual void ShutDownImpl() = 0; virtual void ShutDownImpl() = 0;
...@@ -74,6 +77,9 @@ class RPCServer { ...@@ -74,6 +77,9 @@ class RPCServer {
std::atomic<int> cur_cond_; std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_; std::condition_variable rpc_cond_;
std::vector<framework::Variable*> sparse_vars_;
std::mutex mutex_sparse_var_recorder_;
protected: protected:
std::string bind_address_; std::string bind_address_;
std::atomic<int> exit_flag_; std::atomic<int> exit_flag_;
......
...@@ -108,9 +108,6 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -108,9 +108,6 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr)); std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (true) { while (true) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
...@@ -146,18 +143,10 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -146,18 +143,10 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (framework::Variable *var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
rpc_service_->ResetSparseVarsRecorder();
} // while(true) } // while(true)
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册