未验证 提交 259e63d4 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #11248 from panyx0718/dist

Fix sparse vars usage for dist train
......@@ -80,7 +80,6 @@ class RequestHandler {
}
framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; }
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
......@@ -113,13 +112,7 @@ class RequestHandler {
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable*> sparse_vars_;
RPCServer* rpc_server_;
std::mutex sparse_var_mutex_;
};
} // namespace detail
......
......@@ -63,10 +63,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
PADDLE_THROW("sync: Can not find server side var");
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(sparse_var_mutex_);
sparse_vars_.push_back(invar);
rpc_server_->RecordSparseVar(invar);
}
}
......
......@@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() {
t.second = 0;
}
}
void RPCServer::RecordSparseVar(framework::Variable* sparse_var) {
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
sparse_vars_.push_back(sparse_var);
}
void RPCServer::ResetSparseVarsRecorder() {
VLOG(3) << "RPCServer reset sparse vars recorder.";
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) {
......
......@@ -60,7 +60,10 @@ class RPCServer {
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter();
void RecordSparseVar(framework::Variable* sparse_var);
void ResetSparseVarsRecorder();
protected:
virtual void ShutDownImpl() = 0;
......@@ -74,6 +77,9 @@ class RPCServer {
std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_;
std::vector<framework::Variable*> sparse_vars_;
std::mutex mutex_sparse_var_recorder_;
protected:
std::string bind_address_;
std::atomic<int> exit_flag_;
......
......@@ -108,9 +108,6 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
rpc_service_->ResetBarrierCounter();
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
......@@ -146,18 +143,10 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (framework::Variable *var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter();
rpc_service_->ResetSparseVarsRecorder();
} // while(true)
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册