提交 2a9c4588 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5812 Add PS context.

Merge pull request !5812 from ZPaC/master-context-for-ps
......@@ -41,8 +41,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
<< ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())};
const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape[axis]);
parallel::ps::Worker<float>::GetInstance().InitPSEmbeddingTable(keys, values, lens);
}
......
......@@ -32,11 +32,6 @@ constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
constexpr char kEnvRole[] = "MS_ROLE";
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";
......
......@@ -39,6 +39,7 @@
#include "frontend/parallel/ps/optimizer_info.h"
#include "frontend/parallel/ps/optimizer_info_builder.h"
#include "frontend/parallel/ps/util.h"
#include "frontend/parallel/ps/ps_context.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
......@@ -741,7 +742,7 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
return;
}
Init(func_graph);
Util::SetRankId(rank_id_);
PSContext::instance()->SetPSRankId(rank_id_);
thread_->join();
::ps::Finalize(0, true);
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ps/ps_context.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace parallel {
namespace ps {
std::shared_ptr<PSContext> PSContext::instance() {
static std::shared_ptr<PSContext> ps_instance = nullptr;
if (ps_instance == nullptr) {
ps_instance.reset(new (std::nothrow) PSContext());
}
return ps_instance;
}
void PSContext::SetPSEnable(bool enabled) {
ps_enabled_ = enabled;
if (ps_enabled_) {
std::string ms_role = common::GetEnv(kEnvRole);
MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
if (ms_role == kEnvRoleOfWorker) {
is_worker_ = true;
} else if (ms_role == kEnvRoleOfPServer) {
is_pserver_ = true;
} else if (ms_role == kEnvRoleOfScheduler) {
is_sched_ = true;
} else {
MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
}
} else {
MS_LOG(INFO) << "PS mode is disabled.";
is_worker_ = false;
is_pserver_ = false;
is_sched_ = false;
}
}
bool PSContext::is_ps_enabled() const { return ps_enabled_; }
void PSContext::Reset() {
ps_enabled_ = false;
is_worker_ = false;
is_pserver_ = false;
is_sched_ = false;
}
std::string PSContext::ms_role() const {
if (is_worker_) {
return kEnvRoleOfWorker;
} else if (is_pserver_) {
return kEnvRoleOfPServer;
} else if (is_sched_) {
return kEnvRoleOfScheduler;
} else {
return kEnvRoleOfNotPS;
}
}
bool PSContext::is_role_worker() const { return is_worker_; }
bool PSContext::is_role_pserver() const { return is_pserver_; }
bool PSContext::is_role_sched() const { return is_sched_; }
void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }
int PSContext::ps_rank_id() const { return rank_id_; }
} // namespace ps
} // namespace parallel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
#include <string>
#include <memory>
namespace mindspore {
namespace parallel {
namespace ps {
constexpr char kEnvRole[] = "MS_ROLE";
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
class PSContext {
public:
~PSContext() = default;
PSContext(PSContext const &) = delete;
PSContext &operator=(const PSContext &) = delete;
static std::shared_ptr<PSContext> instance();
void SetPSEnable(bool enabled);
bool is_ps_enabled() const;
void Reset();
std::string ms_role() const;
bool is_role_worker() const;
bool is_role_pserver() const;
bool is_role_sched() const;
void SetPSRankId(int rank_id);
int ps_rank_id() const;
private:
PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
bool is_sched_;
int rank_id_;
};
} // namespace ps
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
......@@ -16,7 +16,9 @@
#include "frontend/parallel/ps/util.h"
#include <unordered_map>
#include <vector>
#include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/ps_context.h"
#include "utils/ms_utils.h"
namespace mindspore {
......@@ -45,34 +47,13 @@ std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{
{3, kSparseFtrlOp},
};
bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); }
bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_enabled(); }
bool Util::IsRoleOfWorker() {
auto role = common::GetEnv(kEnvRole);
if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) {
return true;
} else {
return false;
}
}
bool Util::IsRoleOfWorker() { return PSContext::instance()->is_role_worker(); }
bool Util::IsRoleOfPServer() {
auto role = common::GetEnv(kEnvRole);
if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) {
return true;
} else {
return false;
}
}
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_role_pserver(); }
bool Util::IsRoleOfScheduler() {
auto role = common::GetEnv(kEnvRole);
if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) {
return true;
} else {
return false;
}
}
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_role_sched(); }
void Util::SetInternalEnvVar() {
if (IsParamServerMode()) {
......@@ -163,10 +144,6 @@ std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int serve
return shard_dims;
}
void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
int Util::GetRankId() { return rank_id_; }
void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
const size_t first_dim_size, const size_t outer_dim_size,
mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
......
......@@ -40,8 +40,6 @@ class Util {
static bool is_optimizer(std::string name);
static int LocalShard(int first_dim, int rank_id, int server_num);
static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num);
static void SetRankId(int rank_id);
static int GetRankId();
static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
const size_t first_dim_size, const size_t outer_dim_size,
mindspore::kernel::SparseGradient<int> *unique_sparse_grad);
......
......@@ -27,6 +27,7 @@
#include "ps/ps.h"
#include "frontend/parallel/ps/util.h"
#include "backend/kernel_compiler/common_utils.h"
#include "frontend/parallel/ps/ps_context.h"
namespace mindspore {
namespace parallel {
......@@ -43,7 +44,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id)
: Worker(app_id, customer_id) {
server_num_ = ::ps::NumServers();
Util::SetRankId(::ps::MyRank());
PSContext::instance()->SetPSRankId(::ps::MyRank());
using std::placeholders::_1;
using std::placeholders::_2;
using std::placeholders::_3;
......
......@@ -36,6 +36,7 @@
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "frontend/parallel/ps/util.h"
#endif
#include "frontend/parallel/ps/ps_context.h"
namespace py = pybind11;
using EnvInstance = mindspore::EnvInstance;
......@@ -49,6 +50,7 @@ using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext;
using mindspore::MsCtxParam;
using PSContext = mindspore::parallel::ps::PSContext;
// Interface with python
PYBIND11_MODULE(_c_expression, m) {
......@@ -276,9 +278,15 @@ PYBIND11_MODULE(_c_expression, m) {
"Finalize gpu collective communication mode.");
#endif
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
(void)m.def("get_ps_mode_rank", &mindspore::parallel::ps::Util::GetRankId, "Get Worker and PServer rank id.");
#endif
(void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
.def_static("get_instance", &PSContext::instance, "Get PS context instance.")
.def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
.def("is_ps_enabled", &PSContext::is_ps_enabled, "Get PS mode enable-disable status.")
.def("reset", &PSContext::Reset, "Reset PS context attributes.")
.def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
.def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
.def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
.def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.");
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())
......
......@@ -15,7 +15,6 @@
# limitations under the License.
# ============================================================================
"""Providing interface methods."""
import os
import types
from collections import OrderedDict
from functools import wraps
......@@ -25,6 +24,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynativ
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
from ..parallel._ps_context import _is_role_pserver
# store ms_function class compiled pipeline cache
ms_compile_cache = {}
......@@ -469,7 +469,7 @@ class _Executor:
return self._executor.has_compiled(phase)
def __call__(self, obj, *args, phase='predict'):
if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER":
if context.get_context("precompile_only") or _is_role_pserver():
return None
return self.run(obj, *args, phase=phase)
......
......@@ -22,6 +22,7 @@ from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
from ..parallel._tensor import _get_slice_index
from ..parallel._auto_parallel_context import auto_parallel_context
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
__all__ = ['Parameter', 'ParameterTuple']
......@@ -168,8 +169,13 @@ class Parameter(MetaTensor):
"""For parse check."""
def set_param_ps(self, init_in_server=False):
self.is_param_ps = True
self.init_in_server = init_in_server
if _is_role_worker() or _is_role_pserver() or _is_role_sched():
self.is_param_ps = True
self.init_in_server = init_in_server
else:
raise RuntimeError("Must complete following two steps before calling set_param_ps: \
1. set_ps_context(enable_ps=True) \
2. export MS_ROLE environment variable.")
@property
......
......@@ -14,7 +14,7 @@
# ============================================================================
"""comm_helper"""
import os
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from ._hccl_management import load_lib as hccl_load_lib
_HCCL_AVAILABLE = False
......@@ -44,7 +44,6 @@ else:
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
MS_ROLE = os.getenv("MS_ROLE")
class Backend:
"""
......@@ -113,7 +112,7 @@ def check_parameter_available(func):
Wrapper. If not available, raise Error.
"""
def wrapper(*args, **kargs):
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs)
group = None
if "group" in kargs.keys():
......@@ -154,7 +153,7 @@ def _get_rank_helper(group, backend):
Integer. The local rank id of the calling process.
"""
rank_id = None
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
rank_id = 0
return rank_id
if backend == Backend.HCCL:
......@@ -213,7 +212,7 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group.
"""
size = None
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
size = 1
return size
if backend == Backend.HCCL:
......
......@@ -13,8 +13,8 @@
# limitations under the License.
# ============================================================================
"""Communication management API"""
import os
from mindspore import context
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
......@@ -29,7 +29,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
DEFAULT_BACKEND = Backend("hccl")
MS_ROLE = os.getenv("MS_ROLE")
def _get_group(group):
......@@ -61,7 +60,7 @@ def init(backend_name=None):
RuntimeError: If device target is invalid.
RuntimeError: If backend is invalid or distributed init fails.
"""
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
return
if backend_name is None:
device_target = context.get_context("device_target")
......
......@@ -26,9 +26,11 @@ from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
'get_ps_context', 'reset_ps_context']
GRAPH_MODE = 0
PYNATIVE_MODE = 1
......@@ -569,3 +571,58 @@ class ParallelMode:
SEMI_AUTO_PARALLEL = "semi_auto_parallel"
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
@args_type_check(enable_ps=bool)
def set_ps_context(**kwargs):
"""
Set parameter server training mode context.
Note:
Some other environment variables should also be set for parameter server training mode.
These environment variables are listed below:
MS_SERVER_NUM # Server number
MS_WORKER_NUM # Worker number
MS_SCHED_HOST # Scheduler IP address
MS_SCHED_PORT # Scheduler port
MS_ROLE # The role of this process:
MS_SCHED represents the scheduler,
MS_WORKER represents the worker,
MS_PSERVER represents the Server
Args:
enable_ps (bool): Whether to enable parameter server training mode.
Only after enable_ps is set True, the environment variables will be effective.
Default: False.
Raises:
ValueError: If input key is not the attribute in parameter server training mode context.
Examples:
>>> context.set_ps_context(enable_ps=True)
"""
_set_ps_context(**kwargs)
def get_ps_context(attr_key):
"""
Get parameter server training mode context attribute value according to the key.
Args:
attr_key (str): The key of the attribute.
Returns:
Returns attribute value according to the key.
Raises:
ValueError: If input key is not attribute in auto parallel context.
"""
return _get_ps_context(attr_key)
def reset_ps_context():
"""
Reset parameter server training mode context attributes to the default values:
- enable_ps: False.
"""
_reset_ps_context()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context for parameter server training mode"""
from mindspore._c_expression import PSContext
_ps_context = None
def ps_context():
"""
Get the global _ps_context, if it is not created, create a new one.
Returns:
_ps_context, the global parameter server training mode context.
"""
global _ps_context
if _ps_context is None:
_ps_context = PSContext.get_instance()
return _ps_context
_set_ps_context_func_map = {
"enable_ps": ps_context().set_ps_enable
}
_get_ps_context_func_map = {
"enable_ps": ps_context().is_ps_enabled
}
def _get_ps_mode_rank():
ps_rank = ps_context().ps_rank_id()
if ps_rank == -1:
raise RuntimeError("The parameter server mode training is not enabled yet.")
return ps_rank
def _set_ps_context(**kwargs):
"""
Set parameter server training mode context.
Note:
Some other environment variables should also be set for parameter server training mode.
These environment variables are listed below:
MS_SERVER_NUM # Server number
MS_WORKER_NUM # Worker number
MS_SCHED_HOST # Scheduler IP address
MS_SCHED_PORT # Scheduler port
MS_ROLE # The role of this process:
MS_SCHED represents the scheduler,
MS_WORKER represents the worker,
MS_PSERVER represents the Server
Args:
enable_ps (bool): Whether to enable parameter server training mode.
Only after enable_ps is set True, the environment variables will be effective.
Default: False.
Raises:
ValueError: If input key is not the attribute in parameter server training mode context.
Examples:
>>> context.set_ps_context(enable_ps=True)
"""
for key, value in kwargs.items():
if key not in _set_ps_context_func_map:
raise ValueError("Set PS context keyword %s is not recognized!" % key)
set_func = _set_ps_context_func_map[key]
set_func(value)
def _get_ps_context(attr_key):
"""
Get parameter server training mode context attribute value according to the key.
Args:
attr_key (str): The key of the attribute.
Returns:
Returns attribute value according to the key.
Raises:
ValueError: If input key is not attribute in auto parallel context.
"""
if key not in _get_ps_context_func_map:
raise ValueError("Get PS context keyword %s is not recognized!" % key)
get_func = _get_ps_context_func_map[attr_key]
get_func(attr_key)
def _reset_ps_context():
"""
Reset parameter server training mode context attributes to the default values:
- enable_ps: False.
"""
ps_context().reset()
def _is_role_worker():
return ps_context().is_role_worker()
def _is_role_pserver():
return ps_context().is_role_pserver()
def _is_role_sched():
return ps_context().is_role_sched()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for parameter server training mode"""
from mindspore._c_expression import get_ps_mode_rank
def _get_ps_mode_rank():
ps_rank = get_ps_mode_rank()
if ps_rank == -1:
raise RuntimeError("The parameter server mode training is not launched yet.")
return ps_rank
......@@ -24,6 +24,7 @@ from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
from ._callback import Callback, set_cur_net
......@@ -280,8 +281,7 @@ class ModelCheckpoint(Callback):
if save_ckpt:
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
+ str(step_num_in_epoch) + ".ckpt"
if os.getenv("MS_ROLE") == "MS_PSERVER":
from mindspore.parallel._ps_utils import _get_ps_mode_rank
if _is_role_pserver():
cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
# update checkpoint file list.
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
......
......@@ -27,6 +27,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
......@@ -378,8 +379,7 @@ class Model:
cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None
cb_params.network = self._network
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
epoch = 1
# build callback list
......@@ -516,7 +516,7 @@ class Model:
self._loss_scale_manager.update_loss_scale(overflow)
list_callback.step_end(run_context)
if os.getenv("MS_ROLE") == "MS_PSERVER":
if _is_role_pserver():
os._exit(0)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
......
......@@ -70,6 +70,7 @@ if __name__ == '__main__':
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
context.set_ps_context(enable_ps=True)
if args_opt.run_distribute:
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
......
......@@ -14,7 +14,6 @@
# ============================================================================
"""Model."""
import math
import os
from collections.abc import Iterable
import numpy as np
......@@ -405,9 +404,6 @@ class Model:
cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None
cb_params.network = self._network
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
epoch = 1
# build callback list
with _CallbackManager(callbacks) as list_callback:
......
......@@ -118,6 +118,7 @@ if __name__ == "__main__":
wide_deep_config.argparse_init()
context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
context.set_ps_context(enable_ps=True)
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=get_group_size())
......
......@@ -26,6 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
parser = argparse.ArgumentParser(description="test_sparse_embedding")
parser.add_argument("--device_target", type=str, default="Ascend")
......@@ -34,6 +35,7 @@ device_target = args.device_target
context.set_context(
mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True
)
context.set_ps_context(enable_ps=True)
def fc_with_initialize(input_channels, out_channels):
......@@ -81,7 +83,7 @@ def do_sparse_embedding(ps=False):
for _ in range(epoch):
data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
label = Tensor(np.random.randint(0, 9, (32), np.int32))
if envs.get("MS_ROLE") == "MS_PSERVER":
if _is_role_pserver():
train_network(data, label)
sys.exit()
else:
......@@ -96,10 +98,10 @@ if __name__ == "__main__":
np.random.seed(0)
ps_loss = do_sparse_embedding(True)
if envs.get("MS_ROLE") == "MS_WORKER":
envs["MS_ROLE"] = ""
if _is_role_worker():
context.reset_ps_context()
np.random.seed(0)
no_ps_loss = do_sparse_embedding()
envs["MS_ROLE"] = "MS_WORKER"
context.set_ps_context(enable_ps=True)
assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)
......@@ -35,6 +35,7 @@ args, _ = parser.parse_known_args()
device_target = args.device_target
dataset_path = args.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_ps_context(enable_ps=True)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import sys
import argparse
import numpy as np
......@@ -22,6 +23,7 @@ from mindspore.common.initializer import TruncatedNormal
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.communication.management import init, get_group_size
from mindspore.parallel._ps_context import _is_role_pserver
# from resnet import resnet50
parser = argparse.ArgumentParser(description="test_ps_lenet")
......@@ -29,6 +31,7 @@ parser.add_argument("--device_target", type=str, default="Ascend")
args, _ = parser.parse_known_args()
device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_ps_context(enable_ps=True)
if device_target == "GPU":
init()
......@@ -106,6 +109,10 @@ if __name__ == "__main__":
for _ in range(epoch):
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32))
loss = train_network(data, label).asnumpy()
losses.append(loss)
if _is_role_pserver():
train_network(data, label)
sys.exit()
else:
loss = train_network(data, label).asnumpy()
losses.append(loss)
print(losses)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册