提交 241e980f 编写于 作者: C cristoval

Graceful shutdown in parameter-server (PS) mode

上级 7be664fa
...@@ -70,6 +70,7 @@ class ParameterServer { ...@@ -70,6 +70,7 @@ class ParameterServer {
handler_(nullptr), handler_(nullptr),
func_graph_(nullptr), func_graph_(nullptr),
sess_(nullptr), sess_(nullptr),
running_(true),
thread_(nullptr) {} thread_(nullptr) {}
~ParameterServer() = default; ~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete; ParameterServer(const ParameterServer &) = delete;
...@@ -106,6 +107,7 @@ class ParameterServer { ...@@ -106,6 +107,7 @@ class ParameterServer {
void InitGrad(const Key &key, const GradPtr &grad); void InitGrad(const Key &key, const GradPtr &grad);
void InitEmbeddingTable(const Key &key, void InitEmbeddingTable(const Key &key,
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes); const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
void Finalize();
void UpdateWeights(); void UpdateWeights();
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
WeightPtr weight(const Key &key); WeightPtr weight(const Key &key);
...@@ -123,6 +125,7 @@ class ParameterServer { ...@@ -123,6 +125,7 @@ class ParameterServer {
std::unique_ptr<ServerHandler> handler_; std::unique_ptr<ServerHandler> handler_;
FuncGraphPtr func_graph_; FuncGraphPtr func_graph_;
std::shared_ptr<session::SessionBasic> sess_; std::shared_ptr<session::SessionBasic> sess_;
bool running_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_; std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_; std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
...@@ -261,7 +264,7 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta ...@@ -261,7 +264,7 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
// Server-side handler for a finalize request (side-by-side diff: left column is the
// code before this commit, right column the code after it).
// Before: the handler called ::ps::Finalize(0, false) directly.
// After: it only calls ps_->Finalize(); ps_ is presumably the owning ParameterServer,
// whose Finalize() clears running_ and notifies apply_grads_cv_.
// req_meta, req_data and res are not used by this body.
template <typename T> template <typename T>
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) { ::ps::KVPairs<T> *res) {
::ps::Finalize(0, false); ps_->Finalize();
} }
template <typename T> template <typename T>
...@@ -381,11 +384,20 @@ void ParameterServer<T>::InitEmbeddingTable( ...@@ -381,11 +384,20 @@ void ParameterServer<T>::InitEmbeddingTable(
grads_accum_counter_[key] = 0; grads_accum_counter_[key] = 0;
} }
// Requests shutdown of the weight-update loop.
//
// Clears running_ and wakes the thread blocked in UpdateWeights(), whose wait predicate
// is `ReadyForUpdateWeights() || !running_`; that loop then breaks out.
//
// Must not be called while the caller already holds mutex_.
template <typename T>
void ParameterServer<T>::Finalize() {
  {
    // running_ is part of the predicate that UpdateWeights() evaluates under mutex_.
    // Changing it without the mutex is a data race, and the notification below could
    // fire between the waiter's predicate check and its going to sleep, leaving the
    // update thread blocked forever.
    std::lock_guard<std::mutex> lock(mutex_);
    running_ = false;
  }
  // Notify after releasing the lock so the woken thread can acquire it immediately.
  apply_grads_cv_.notify_one();
}
template <typename T> template <typename T>
void ParameterServer<T>::UpdateWeights() { void ParameterServer<T>::UpdateWeights() {
while (true) { while (true) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); }); apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
if (!running_) {
break;
}
for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
Key key = iter->first; Key key = iter->first;
...@@ -550,6 +562,8 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) { ...@@ -550,6 +562,8 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
} }
Init(func_graph); Init(func_graph);
thread_->join(); thread_->join();
::ps::Finalize(0, true);
exit(1);
} }
} // namespace ps } // namespace ps
} // namespace parallel } // namespace parallel
......
...@@ -23,9 +23,8 @@ namespace parallel { ...@@ -23,9 +23,8 @@ namespace parallel {
namespace ps { namespace ps {
// Scheduler entry point (side-by-side diff: left column is the code before this commit,
// right column the code after it).
// Before: called ::ps::Start(0) and then looped forever on sleep(1), so it never returned.
// After: calls ::ps::Start(0), then ::ps::Finalize(0, true), then terminates the process.
// NOTE(review): the second argument of ::ps::Finalize is presumably ps-lite's do_barrier
// flag, i.e. the call waits for the other nodes to finalize — confirm against ps-lite.
// NOTE(review): exit(1) reports a non-zero (failure) status on what looks like the normal
// shutdown path — confirm whether exit(0) was intended.
void Scheduler::Run() { void Scheduler::Run() {
::ps::Start(0); ::ps::Start(0);
while (true) { ::ps::Finalize(0, true);
sleep(1); exit(1);
}
}
} // namespace ps } // namespace ps
} // namespace parallel } // namespace parallel
......
...@@ -54,7 +54,7 @@ class Worker { ...@@ -54,7 +54,7 @@ class Worker {
private: private:
Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {} Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {}
~Worker() { ::ps::Finalize(0, true); } ~Worker() = default;
Worker(const Worker &) = delete; Worker(const Worker &) = delete;
Worker &operator=(const Worker &) = delete; Worker &operator=(const Worker &) = delete;
...@@ -81,7 +81,6 @@ void Worker<T>::Run() { ...@@ -81,7 +81,6 @@ void Worker<T>::Run() {
MS_LOG(INFO) << "'Worker is already running."; MS_LOG(INFO) << "'Worker is already running.";
return; return;
} }
::ps::Start(0); ::ps::Start(0);
if (!::ps::IsWorker()) { if (!::ps::IsWorker()) {
MS_LOG(EXCEPTION) << "The role is not worker."; MS_LOG(EXCEPTION) << "The role is not worker.";
...@@ -121,7 +120,11 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const : ...@@ -121,7 +120,11 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
// Shuts down the worker side (side-by-side diff: left column is the code before this
// commit, right column the code after it).
// Before: unconditionally called kv_worker_->Finalize().
// After: acts only while running_ is set; it calls kv_worker_->Finalize(), releases
// kv_worker_ via reset(), and clears running_, so a second call does nothing.
template <typename T> template <typename T>
void Worker<T>::Finalize() { void Worker<T>::Finalize() {
if (running_) {
kv_worker_->Finalize(); kv_worker_->Finalize();
kv_worker_.reset();
running_ = false;
}
}
template <typename T> template <typename T>
......
...@@ -155,7 +155,7 @@ void WorkerProxy<T>::Finalize() { ...@@ -155,7 +155,7 @@ void WorkerProxy<T>::Finalize() {
kvs.vals.push_back(0.0f); kvs.vals.push_back(0.0f);
Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_); Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
obj_->WaitRequest(ts); obj_->WaitRequest(ts);
::ps::Finalize(0, false); ::ps::Finalize(0, true);
} }
template <typename T> template <typename T>
......
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "frontend/parallel/ps/common.h" #include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/util.h" #include "frontend/parallel/ps/util.h"
#include "frontend/parallel/ps/worker.h"
#endif #endif
#if (ENABLE_GE || ENABLE_D) #if (ENABLE_GE || ENABLE_D)
...@@ -949,7 +950,13 @@ void ClearResAtexit() { ...@@ -949,7 +950,13 @@ void ClearResAtexit() {
pynative::ClearPyNativeSession(); pynative::ClearPyNativeSession();
session::ClearPythonParasMap(); session::ClearPythonParasMap();
device::KernelRuntimeManager::Instance().ClearRuntimeResource(); device::KernelRuntimeManager::Instance().ClearRuntimeResource();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (mindspore::parallel::ps::Util::IsParamServerMode()) {
if (parallel::ps::Util::IsRoleOfWorker()) {
parallel::ps::Worker<float>::GetInstance().Finalize();
}
}
#endif
ad::g_k_prims.clear(); ad::g_k_prims.clear();
abstract::ClearPrimEvaluatorMap(); abstract::ClearPrimEvaluatorMap();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册