Commit fc8bd0dd authored by mindspore-ci-bot, committed by Gitee

!5933 Fix pserver error and optimize worker and server log.

Merge pull request !5933 from ZPaC/master-fix-error-when-pserver-finish-training
@@ -736,7 +736,9 @@ void ParameterServer<T>::SyncEmbeddingTables() {
 template <typename T>
 void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
+  MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
   ::ps::Start(0);
+  MS_LOG(INFO) << "PServer connected successfully.";
   if (!::ps::IsServer()) {
     std::cout << "This is not the Server" << std::endl;
     return;
@@ -744,7 +746,9 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
   Init(func_graph);
   PSContext::instance()->SetPSRankId(rank_id_);
   thread_->join();
+  MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
   ::ps::Finalize(0, true);
+  MS_LOG(INFO) << "PServer finalized successfully.";
 }
 }  // namespace ps
 }  // namespace parallel
...
@@ -86,7 +86,9 @@ void Worker<T>::Run() {
     MS_LOG(INFO) << "Worker is already running.";
     return;
   }
+  MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
   ::ps::Start(0);
+  MS_LOG(INFO) << "Worker connected successfully.";
   if (!::ps::IsWorker()) {
     MS_LOG(EXCEPTION) << "The role is not worker.";
   }
@@ -176,9 +178,11 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
 template <typename T>
 void Worker<T>::Finalize() {
   if (running_) {
+    MS_LOG(INFO) << "Worker starts finalizing...";
     kv_worker_->Finalize();
     kv_worker_.reset();
     running_ = false;
+    MS_LOG(INFO) << "Worker finalized successfully.";
   }
 }
@@ -315,7 +319,7 @@ void Worker<T>::InitPSParamAndOptim(const std::string &param_name, tensor::Tenso
   size_t param_key = GetParamKey(param_name);
   if (param_key == kInvalidKey) {
-    MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
+    MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
     return;
   }
   bool init_in_server = false;
...
@@ -36,6 +36,7 @@ class LossCallBack(Callback):
     Note:
         If per_print_times is 0, do NOT print loss.
+        If this process is MS_PSERVER role, do not run callbacks.
     Args:
         per_print_times (int): Print the loss every per_print_times steps. Default: 1.
@@ -50,6 +51,8 @@ class LossCallBack(Callback):
     def step_end(self, run_context):
         """Monitor the loss in training."""
         cb_params = run_context.original_args()
+        if cb_params.net_outputs is None:
+            return
         wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
         cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
         cur_num = cb_params.cur_step_num
...
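Why the guard matters: a process launched in the MS_PSERVER role never executes training steps, so `cb_params.net_outputs` stays `None` when `step_end` fires, and indexing it would raise on the server. Below is a minimal, self-contained sketch of the same early-return pattern; `_CbParams` and `_RunContext` are simplified stand-ins for illustration, not the actual MindSpore classes.

    class _CbParams:
        """Stand-in for MindSpore's callback parameters object (assumption)."""
        def __init__(self, net_outputs=None, cur_step_num=1, batch_num=1):
            self.net_outputs = net_outputs
            self.cur_step_num = cur_step_num
            self.batch_num = batch_num

    class _RunContext:
        """Stand-in for the run context passed to callbacks (assumption)."""
        def __init__(self, cb_params):
            self._cb_params = cb_params

        def original_args(self):
            return self._cb_params

    def step_end(run_context):
        """Mirrors the guarded step_end: skip reporting when no outputs exist."""
        cb_params = run_context.original_args()
        if cb_params.net_outputs is None:
            # An MS_PSERVER process runs no training steps, so there is
            # nothing to report; returning avoids indexing None below.
            return
        wide_loss, deep_loss = cb_params.net_outputs[0], cb_params.net_outputs[1]
        print("wide_loss:", wide_loss, "deep_loss:", deep_loss)

    # A pserver-like context produces no outputs and must not crash:
    step_end(_RunContext(_CbParams(net_outputs=None)))
    # A worker-like context reports normally:
    step_end(_RunContext(_CbParams(net_outputs=(0.5, 0.7))))

The same guard composes with the Note added to the docstring above: servers simply see every callback become a no-op rather than needing a separate callback list.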