Commit fc8bd0dd authored by mindspore-ci-bot, committed by Gitee

!5933 Fix pserver error and optimize worker and server logs.

Merge pull request !5933 from ZPaC/master-fix-error-when-pserver-finish-training
@@ -736,7 +736,9 @@ void ParameterServer<T>::SyncEmbeddingTables() {
 template <typename T>
 void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
+  MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
   ::ps::Start(0);
+  MS_LOG(INFO) << "PServer connected successfully.";
   if (!::ps::IsServer()) {
     std::cout << "This is not the Server" << std::endl;
     return;
@@ -744,7 +746,9 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
   Init(func_graph);
   PSContext::instance()->SetPSRankId(rank_id_);
   thread_->join();
+  MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
   ::ps::Finalize(0, true);
+  MS_LOG(INFO) << "PServer finalized successfully.";
 }
 }  // namespace ps
 }  // namespace parallel
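For context, both roles follow the same ps-lite lifecycle that these new log lines bracket: `::ps::Start()` joins the cluster, the role does its work, and `::ps::Finalize()` tears the connection down. Below is a minimal standalone sketch of that bracket, assuming the `ps/ps.h` API already used in this diff; the `main` scaffold and log strings are illustrative, not MindSpore code.

```cpp
// Minimal sketch of the Start/Finalize bracket instrumented in this commit.
// Assumes ps-lite's "ps/ps.h" (the header behind ::ps::Start/::ps::Finalize above).
#include <iostream>
#include "ps/ps.h"

int main() {
  std::cout << "Connecting to scheduler..." << std::endl;
  ::ps::Start(0);  // join the cluster; returns once this node is connected
  if (::ps::IsServer()) {
    std::cout << "Server connected successfully." << std::endl;
    // ... serve pull/push requests here ...
  } else if (::ps::IsWorker()) {
    std::cout << "Worker connected successfully." << std::endl;
    // ... push gradients and pull weights here ...
  }
  std::cout << "Finalizing..." << std::endl;
  ::ps::Finalize(0, true);  // barrier with the other nodes, then disconnect
  std::cout << "Finalized successfully." << std::endl;
  return 0;
}
```

Logging immediately before and after each of these calls makes hangs attributable: if the last line seen is "starts connecting", the process is stuck waiting on the scheduler rather than in training.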
@@ -86,7 +86,9 @@ void Worker<T>::Run() {
     MS_LOG(INFO) << "Worker is already running.";
     return;
   }
+  MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
   ::ps::Start(0);
+  MS_LOG(INFO) << "Worker connected successfully.";
   if (!::ps::IsWorker()) {
     MS_LOG(EXCEPTION) << "The role is not worker.";
   }
@@ -176,9 +178,11 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
 template <typename T>
 void Worker<T>::Finalize() {
   if (running_) {
+    MS_LOG(INFO) << "Worker starts finalizing...";
     kv_worker_->Finalize();
     kv_worker_.reset();
     running_ = false;
+    MS_LOG(INFO) << "Worker finalized successfully.";
   }
 }
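The `running_` flag above is what keeps `Finalize()` from running its teardown twice. Here is a self-contained sketch of the same guard, with hypothetical names and `std::atomic` used for illustration (the MindSpore member is a plain flag):

```cpp
// Hypothetical sketch of an idempotent Finalize() guarded by a running flag,
// mirroring the if (running_) check in Worker<T>::Finalize() above.
#include <atomic>
#include <iostream>

class Lifecycle {
 public:
  void Run() { running_ = true; }

  void Finalize() {
    // exchange() returns the previous value, so only the first caller
    // performs the teardown; any later call is a no-op.
    if (!running_.exchange(false)) {
      return;
    }
    std::cout << "finalizing..." << std::endl;
    // ... release connections and reset members here ...
    std::cout << "finalized successfully." << std::endl;
  }

 private:
  std::atomic<bool> running_{false};
};

int main() {
  Lifecycle worker;
  worker.Run();
  worker.Finalize();  // performs the teardown once
  worker.Finalize();  // safe: the guard turns this into a no-op
  return 0;
}
```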
@@ -315,7 +319,7 @@ void Worker<T>::InitPSParamAndOptim(const std::string &param_name, tensor::Tenso
   size_t param_key = GetParamKey(param_name);
   if (param_key == kInvalidKey) {
-    MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
+    MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
     return;
   }
   bool init_in_server = false;
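The only change in this hunk is the log level: the message fires for every parameter without a server-side key, so at INFO it spams ordinary runs. Under a conventional leveled-logging scheme, DEBUG output is dropped unless verbosity is raised. The sketch below illustrates that filtering; it is a toy logger, not MindSpore's MS_LOG implementation, and the parameter name is made up.

```cpp
// Toy leveled logger (illustrative only, not MS_LOG): messages below the
// active threshold are discarded, so DEBUG is silent in a default run.
#include <iostream>
#include <string>

enum Level { DEBUG = 0, INFO = 1, WARNING = 2 };

Level g_threshold = INFO;  // default threshold filters out DEBUG

void Log(Level level, const std::string &msg) {
  if (level < g_threshold) {
    return;  // below threshold: dropped
  }
  std::cout << msg << std::endl;
}

int main() {
  Log(DEBUG, "Parameter fc1.weight has no key assigned.");  // suppressed by default
  Log(INFO, "Worker connected successfully.");              // printed
  return 0;
}
```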
@@ -36,6 +36,7 @@ class LossCallBack(Callback):
     Note:
         If per_print_times is 0, do NOT print loss.
+        If the role of this process is MS_PSERVER, do not run callbacks.
 
     Args:
         per_print_times (int): Print loss every per_print_times steps. Default: 1.
@@ -50,6 +51,8 @@ class LossCallBack(Callback):
     def step_end(self, run_context):
         """Monitor the loss in training."""
         cb_params = run_context.original_args()
+        if cb_params.net_outputs is None:
+            return
         wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
         cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
         cur_num = cb_params.cur_step_num