diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
index 2a32904ab055f076b805f0212943aaf4e8b016ac..429cf59c3c1130accf3e391170455a3bb306b873 100755
--- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
@@ -70,6 +70,7 @@ class ParameterServer {
         handler_(nullptr),
         func_graph_(nullptr),
         sess_(nullptr),
+        running_(true),
         thread_(nullptr) {}
   ~ParameterServer() = default;
   ParameterServer(const ParameterServer &) = delete;
@@ -106,6 +107,7 @@ class ParameterServer {
   void InitGrad(const Key &key, const GradPtr &grad);
   void InitEmbeddingTable(const Key &key,
                           const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
+  void Finalize();
   void UpdateWeights();
   void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
   WeightPtr weight(const Key &key);
@@ -123,6 +125,7 @@ class ParameterServer {
   std::unique_ptr<ServerHandler> handler_;
   FuncGraphPtr func_graph_;
   std::shared_ptr<session::SessionBasic> sess_;
+  bool running_;
 
   std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
   std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
@@ -261,7 +264,7 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMet
 template <typename T>
 void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
                                                        ::ps::KVPairs<T> *res) {
-  ::ps::Finalize(0, false);
+  ps_->Finalize();
 }
 
 template <typename T>
@@ -381,11 +384,20 @@ void ParameterServer<T>::InitEmbeddingTable(
   grads_accum_counter_[key] = 0;
 }
 
+template <typename T>
+void ParameterServer<T>::Finalize() {
+  running_ = false;
+  apply_grads_cv_.notify_one();
+}
+
 template <typename T>
 void ParameterServer<T>::UpdateWeights() {
   while (true) {
     std::unique_lock<std::mutex> lock(mutex_);
-    apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); });
+    apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
+    if (!running_) {
+      break;
+    }
 
     for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
       Key key = iter->first;
@@ -550,6 +562,8 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
   }
   Init(func_graph);
   thread_->join();
+  ::ps::Finalize(0, true);
+  exit(1);
 }
 }  // namespace ps
 }  // namespace parallel
diff --git a/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc
index 274b7259b0909ca90878b432e873a4c356c42b46..04c259487fa73532ee761d518559187a9dd2a85a 100755
--- a/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc
@@ -23,9 +23,8 @@ namespace parallel {
 namespace ps {
 void Scheduler::Run() {
   ::ps::Start(0);
-  while (true) {
-    sleep(1);
-  }
+  ::ps::Finalize(0, true);
+  exit(1);
 }
 }  // namespace ps
 }  // namespace parallel
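Note on the server-side shutdown above: HandleFinalize() no longer calls ::ps::Finalize() from inside the RPC handler. It flips running_ and notifies apply_grads_cv_; the UpdateWeights() thread wakes, sees !running_, and breaks out of its loop, so thread_->join() in Run() returns and the real ::ps::Finalize(0, true) happens on the main thread. The standalone C++ sketch below shows the same stop-flag/condition-variable idiom in isolation; it is not MindSpore code, and UpdateLoop, Submit(), and the other names are invented for illustration.

// Stop-flag + condition-variable shutdown, mirroring ParameterServer<T>::Finalize()
// and UpdateWeights() above. Illustrative names only.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class UpdateLoop {
 public:
  // Runs on a dedicated thread, like UpdateWeights().
  void Run() {
    while (true) {
      std::unique_lock<std::mutex> lock(mutex_);
      // Wake when work arrives or when shutdown is requested.
      cv_.wait(lock, [this] { return ready_ || !running_; });
      if (!running_) {
        break;  // Leave the loop so the owner can join the thread.
      }
      ready_ = false;
      std::cout << "applying gradients" << std::endl;
    }
  }

  // Hands one unit of work to the loop.
  void Submit() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      ready_ = true;
    }
    cv_.notify_one();
  }

  // Requests shutdown; in this toy, pending work may be skipped.
  void Finalize() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      running_ = false;
    }
    cv_.notify_one();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool ready_ = false;
  bool running_ = true;
};

int main() {
  UpdateLoop loop;
  std::thread updater([&loop] { loop.Run(); });
  loop.Submit();
  loop.Finalize();
  updater.join();  // Returns promptly because Run() observed running_ == false.
  return 0;
}

One difference from the patch: the sketch flips the flag while holding the mutex, so the notification cannot be lost between the waiter's predicate check and its sleep.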
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h
index abdc046ffe08a02d817a29ae2b1ee1b3f74ad398..d7f0bb6df52ce9d5121b7f206c8df058c0db9c68 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h
@@ -54,7 +54,7 @@ class Worker {
 
  private:
   Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {}
-  ~Worker() { ::ps::Finalize(0, true); }
+  ~Worker() = default;
   Worker(const Worker &) = delete;
   Worker &operator=(const Worker &) = delete;
 
@@ -81,7 +81,6 @@ void Worker<T>::Run() {
     MS_LOG(INFO) << "'Worker is already running.";
     return;
   }
-
   ::ps::Start(0);
   if (!::ps::IsWorker()) {
     MS_LOG(EXCEPTION) << "The role is not worker.";
@@ -121,7 +120,11 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
 
 template <typename T>
 void Worker<T>::Finalize() {
-  kv_worker_->Finalize();
+  if (running_) {
+    kv_worker_->Finalize();
+    kv_worker_.reset();
+    running_ = false;
+  }
 }
 
 template <typename T>
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
index c8bd27c067a56a961d89826eb3d3deb6ddd8c5f4..6ac7f6322daf8751ac737337dacb2af3fd2f183b 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
@@ -155,7 +155,7 @@ void WorkerProxy<T>::Finalize() {
   kvs.vals.push_back(0.0f);
   Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
   obj_->WaitRequest(ts);
-  ::ps::Finalize(0, false);
+  ::ps::Finalize(0, true);
 }
 
 template <typename T>
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index b1c3f2db036634bac73310e9a9ca5df407868ae7..dee864d085ec3d22e0c86da3829e9751103b93e2 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -45,6 +45,7 @@
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "frontend/parallel/ps/common.h"
 #include "frontend/parallel/ps/util.h"
+#include "frontend/parallel/ps/worker.h"
 #endif
 
 #if (ENABLE_GE || ENABLE_D)
@@ -949,7 +950,13 @@ void ClearResAtexit() {
   pynative::ClearPyNativeSession();
   session::ClearPythonParasMap();
   device::KernelRuntimeManager::Instance().ClearRuntimeResource();
-
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  if (mindspore::parallel::ps::Util::IsParamServerMode()) {
+    if (parallel::ps::Util::IsRoleOfWorker()) {
+      parallel::ps::Worker<float>::GetInstance().Finalize();
+    }
+  }
+#endif
   ad::g_k_prims.clear();
 
   abstract::ClearPrimEvaluatorMap();
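Worker-side note: teardown moves out of ~Worker() into an explicit Finalize() that ClearResAtexit() calls on workers, and the running_ guard makes repeated calls safe. Below is a minimal sketch of that one-shot-teardown idiom; it assumes nothing about ps-lite, and Connection and the other names are invented for illustration.

// One-shot teardown guarded by a running flag, mirroring the new
// Worker<T>::Finalize(). Connection stands in for the kv_worker_ handle.
#include <iostream>
#include <memory>

struct Connection {
  void Finalize() { std::cout << "finalize sent once" << std::endl; }
};

class Worker {
 public:
  void Run() {
    conn_ = std::make_unique<Connection>();
    running_ = true;
  }

  void Finalize() {
    if (running_) {    // Only the first call does real work.
      conn_->Finalize();
      conn_.reset();   // Drop the handle so later calls cannot touch it.
      running_ = false;
    }
  }

 private:
  std::unique_ptr<Connection> conn_;
  bool running_ = false;
};

int main() {
  Worker w;
  w.Run();
  w.Finalize();
  w.Finalize();  // No-op: the guard makes finalization idempotent.
  return 0;
}

Calling Finalize() twice performs the teardown once; resetting the unique_ptr also means a stray late call takes the guarded branch instead of touching a dead handle.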