提交 241e980f 编写于 作者: C cristoval

graceful shutdown in ps mode

上级 7be664fa
......@@ -70,6 +70,7 @@ class ParameterServer {
handler_(nullptr),
func_graph_(nullptr),
sess_(nullptr),
running_(true),
thread_(nullptr) {}
~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete;
......@@ -106,6 +107,7 @@ class ParameterServer {
void InitGrad(const Key &key, const GradPtr &grad);
void InitEmbeddingTable(const Key &key,
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
void Finalize();
void UpdateWeights();
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
WeightPtr weight(const Key &key);
......@@ -123,6 +125,7 @@ class ParameterServer {
std::unique_ptr<ServerHandler> handler_;
FuncGraphPtr func_graph_;
std::shared_ptr<session::SessionBasic> sess_;
bool running_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
......@@ -261,7 +264,7 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
template <typename T>
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
::ps::Finalize(0, false);
ps_->Finalize();
}
template <typename T>
......@@ -381,11 +384,20 @@ void ParameterServer<T>::InitEmbeddingTable(
grads_accum_counter_[key] = 0;
}
template <typename T>
void ParameterServer<T>::Finalize() {
running_ = false;
apply_grads_cv_.notify_one();
}
template <typename T>
void ParameterServer<T>::UpdateWeights() {
while (true) {
std::unique_lock<std::mutex> lock(mutex_);
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); });
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
if (!running_) {
break;
}
for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
Key key = iter->first;
......@@ -550,6 +562,8 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
}
Init(func_graph);
thread_->join();
::ps::Finalize(0, true);
exit(1);
}
} // namespace ps
} // namespace parallel
......
......@@ -23,9 +23,8 @@ namespace parallel {
namespace ps {
void Scheduler::Run() {
::ps::Start(0);
while (true) {
sleep(1);
}
::ps::Finalize(0, true);
exit(1);
}
} // namespace ps
} // namespace parallel
......
......@@ -54,7 +54,7 @@ class Worker {
private:
Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {}
~Worker() { ::ps::Finalize(0, true); }
~Worker() = default;
Worker(const Worker &) = delete;
Worker &operator=(const Worker &) = delete;
......@@ -81,7 +81,6 @@ void Worker<T>::Run() {
MS_LOG(INFO) << "'Worker is already running.";
return;
}
::ps::Start(0);
if (!::ps::IsWorker()) {
MS_LOG(EXCEPTION) << "The role is not worker.";
......@@ -121,7 +120,11 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
template <typename T>
void Worker<T>::Finalize() {
kv_worker_->Finalize();
if (running_) {
kv_worker_->Finalize();
kv_worker_.reset();
running_ = false;
}
}
template <typename T>
......
......@@ -155,7 +155,7 @@ void WorkerProxy<T>::Finalize() {
kvs.vals.push_back(0.0f);
Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
obj_->WaitRequest(ts);
::ps::Finalize(0, false);
::ps::Finalize(0, true);
}
template <typename T>
......
......@@ -45,6 +45,7 @@
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/util.h"
#include "frontend/parallel/ps/worker.h"
#endif
#if (ENABLE_GE || ENABLE_D)
......@@ -949,7 +950,13 @@ void ClearResAtexit() {
pynative::ClearPyNativeSession();
session::ClearPythonParasMap();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (mindspore::parallel::ps::Util::IsParamServerMode()) {
if (parallel::ps::Util::IsRoleOfWorker()) {
parallel::ps::Worker<float>::GetInstance().Finalize();
}
}
#endif
ad::g_k_prims.clear();
abstract::ClearPrimEvaluatorMap();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册