diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
index ea22397a3ee65dba6e2f7a43caa01cc88e8617ba..cfc22b1a075b9d77b200b42e40776fec6c8a4d63 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
@@ -41,8 +41,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
                << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
   std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())};
-  const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
-  if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
+  if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
     parallel::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape[axis]);
     parallel::ps::Worker::GetInstance().InitPSEmbeddingTable(keys, values, lens);
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ps/common.h b/mindspore/ccsrc/frontend/parallel/ps/common.h
index a021dff1ad8c4468c8c377fd0acd867032e01810..3678c6a1c113fe6e9e90f00ed18893c38c2b7511 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/common.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/common.h
@@ -32,11 +32,6 @@
 constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
 constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
 constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
-constexpr char kEnvRole[] = "MS_ROLE";
-constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
-constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
-constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
-
 constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
 constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
 constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";
diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
index 03c0b768ca3d0e56787ddaa23a70621e65182df5..831c7243eab5e81e6979bf8219131aa6a552360c 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
@@ -39,6 +39,7 @@
 #include "frontend/parallel/ps/optimizer_info.h"
 #include "frontend/parallel/ps/optimizer_info_builder.h"
 #include "frontend/parallel/ps/util.h"
+#include "frontend/parallel/ps/ps_context.h"
 #include "runtime/device/cpu/kernel_select_cpu.h"
 #include "utils/ms_context.h"
 #include "backend/kernel_compiler/kernel.h"
@@ -741,7 +742,7 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
     return;
   }
   Init(func_graph);
-  Util::SetRankId(rank_id_);
+  PSContext::instance()->SetPSRankId(rank_id_);
   thread_->join();
   ::ps::Finalize(0, true);
 }
diff --git a/mindspore/ccsrc/frontend/parallel/ps/ps_context.cc b/mindspore/ccsrc/frontend/parallel/ps/ps_context.cc
new file mode 100644
index 0000000000000000000000000000000000000000..054ca585e37178089123d5082cc5ecbc583bcc0b
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ps/ps_context.cc
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ps/ps_context.h"
+#include "utils/log_adapter.h"
+#include "utils/ms_utils.h"
+
+namespace mindspore {
+namespace parallel {
+namespace ps {
+std::shared_ptr<PSContext> PSContext::instance() {
+  static std::shared_ptr<PSContext> ps_instance = nullptr;
+  if (ps_instance == nullptr) {
+    ps_instance.reset(new (std::nothrow) PSContext());
+  }
+  return ps_instance;
+}
+
+void PSContext::SetPSEnable(bool enabled) {
+  ps_enabled_ = enabled;
+  if (ps_enabled_) {
+    std::string ms_role = common::GetEnv(kEnvRole);
+    MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
+    if (ms_role == kEnvRoleOfWorker) {
+      is_worker_ = true;
+    } else if (ms_role == kEnvRoleOfPServer) {
+      is_pserver_ = true;
+    } else if (ms_role == kEnvRoleOfScheduler) {
+      is_sched_ = true;
+    } else {
+      MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
+    }
+  } else {
+    MS_LOG(INFO) << "PS mode is disabled.";
+    is_worker_ = false;
+    is_pserver_ = false;
+    is_sched_ = false;
+  }
+}
+
+bool PSContext::is_ps_enabled() const { return ps_enabled_; }
+
+void PSContext::Reset() {
+  ps_enabled_ = false;
+  is_worker_ = false;
+  is_pserver_ = false;
+  is_sched_ = false;
+}
+
+std::string PSContext::ms_role() const {
+  if (is_worker_) {
+    return kEnvRoleOfWorker;
+  } else if (is_pserver_) {
+    return kEnvRoleOfPServer;
+  } else if (is_sched_) {
+    return kEnvRoleOfScheduler;
+  } else {
+    return kEnvRoleOfNotPS;
+  }
+}
+
+bool PSContext::is_role_worker() const { return is_worker_; }
+
+bool PSContext::is_role_pserver() const { return is_pserver_; }
+
+bool PSContext::is_role_sched() const { return is_sched_; }
+
+void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }
+
+int PSContext::ps_rank_id() const { return rank_id_; }
+}  // namespace ps
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ps/ps_context.h b/mindspore/ccsrc/frontend/parallel/ps/ps_context.h
new file mode 100644
index 0000000000000000000000000000000000000000..bb9f734d6b3a5d8c8a72cc06dae0450b410438e9
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ps/ps_context.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
+
+#include <memory>
+#include <string>
+
+namespace mindspore {
+namespace parallel {
+namespace ps {
+constexpr char kEnvRole[] = "MS_ROLE";
+constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
+constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
+constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
+constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
+
+class PSContext {
+ public:
+  ~PSContext() = default;
+  PSContext(PSContext const &) = delete;
+  PSContext &operator=(const PSContext &) = delete;
+  static std::shared_ptr<PSContext> instance();
+
+  void SetPSEnable(bool enabled);
+  bool is_ps_enabled() const;
+  void Reset();
+  std::string ms_role() const;
+  bool is_role_worker() const;
+  bool is_role_pserver() const;
+  bool is_role_sched() const;
+  void SetPSRankId(int rank_id);
+  int ps_rank_id() const;
+
+ private:
+  PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
+  bool ps_enabled_;
+  bool is_worker_;
+  bool is_pserver_;
+  bool is_sched_;
+  int rank_id_;
+};
+}  // namespace ps
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.cc b/mindspore/ccsrc/frontend/parallel/ps/util.cc
index ae698a7bec6214c42c1f83f5f3a528995cb2645e..123ee7746317914fa2cc1515607650f0ddbfa520 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/util.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/util.cc
@@ -16,7 +16,9 @@
 #include "frontend/parallel/ps/util.h"
 #include
+#include
 #include "frontend/parallel/ps/common.h"
+#include "frontend/parallel/ps/ps_context.h"
 #include "utils/ms_utils.h"
 
 namespace mindspore {
@@ -45,34 +47,13 @@ std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{
   {3, kSparseFtrlOp},
 };
 
-bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); }
+bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_enabled(); }
 
-bool Util::IsRoleOfWorker() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfWorker() { return PSContext::instance()->is_role_worker(); }
 
-bool Util::IsRoleOfPServer() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfPServer() { return PSContext::instance()->is_role_pserver(); }
 
-bool Util::IsRoleOfScheduler() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_role_sched(); }
 
 void Util::SetInternalEnvVar() {
   if (IsParamServerMode()) {
@@ -163,10 +144,6 @@ std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int serve
   return shard_dims;
 }
 
-void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
-
-int Util::GetRankId() { return rank_id_; }
-
 void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                 const size_t first_dim_size, const size_t outer_dim_size,
                                 mindspore::kernel::SparseGradient *unique_sparse_grad) {
diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.h b/mindspore/ccsrc/frontend/parallel/ps/util.h
index 9974482e2bb5ea610ca163b854c24a79015acb18..ce621ec3da0e31033df8b7e5aae4ac11122e6aa8 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/util.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/util.h
@@ -40,8 +40,6 @@ class Util {
   static bool is_optimizer(std::string name);
   static int LocalShard(int first_dim, int rank_id, int server_num);
   static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num);
-  static void SetRankId(int rank_id);
-  static int GetRankId();
   static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                    const size_t first_dim_size, const size_t outer_dim_size,
                                    mindspore::kernel::SparseGradient *unique_sparse_grad);
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
index 0c4f77f96299c034c2f9395cec1eab948315c93a..3ed245bfce923b8b7cc6abd0acbba0ac3de54377 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
@@ -27,6 +27,7 @@
 #include "ps/ps.h"
 #include "frontend/parallel/ps/util.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "frontend/parallel/ps/ps_context.h"
 
 namespace mindspore {
 namespace parallel {
@@ -43,7 +44,7 @@ class WorkerProxy : public ::ps::KVWorker {
   explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id)
       : Worker(app_id, customer_id) {
     server_num_ = ::ps::NumServers();
-    Util::SetRankId(::ps::MyRank());
+    PSContext::instance()->SetPSRankId(::ps::MyRank());
     using std::placeholders::_1;
     using std::placeholders::_2;
    using std::placeholders::_3;
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index cc43ec6bf22058f30d6bf70c342c60ab39e3a1b1..401d3862845f3399a39a9103a96e2ce1afc84116 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -36,6 +36,7 @@
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "frontend/parallel/ps/util.h"
 #endif
+#include "frontend/parallel/ps/ps_context.h"
 
 namespace py = pybind11;
 using EnvInstance = mindspore::EnvInstance;
@@ -49,6 +50,7 @@ using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
 using ParallelContext = mindspore::parallel::ParallelContext;
 using CostModelContext = mindspore::parallel::CostModelContext;
 using mindspore::MsCtxParam;
+using PSContext = mindspore::parallel::ps::PSContext;
 
 // Interface with python
 PYBIND11_MODULE(_c_expression, m) {
@@ -276,9 +278,15 @@
                       "Finalize gpu collective communication mode.");
 #endif
 
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  (void)m.def("get_ps_mode_rank", &mindspore::parallel::ps::Util::GetRankId, "Get Worker and PServer rank id.");
-#endif
+  (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
+    .def_static("get_instance", &PSContext::instance, "Get PS context instance.")
+    .def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
+    .def("is_ps_enabled", &PSContext::is_ps_enabled, "Get PS mode enable-disable status.")
+    .def("reset", &PSContext::Reset, "Reset PS context attributes.")
+    .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
+    .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
+    .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
+    .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.");
 
   (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
     .def(py::init<>())
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 128c8d87e5b7fc77e0f965a76be1faeea2b9feb1..0d5241f4acdd2c3e379417f03b1df1a83eea539a 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 # ============================================================================
 """Providing interface methods."""
-import os
 import types
 from collections import OrderedDict
 from functools import wraps
@@ -25,6 +24,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynativ
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
+from ..parallel._ps_context import _is_role_pserver
 
 # store ms_function class compiled pipeline cache
 ms_compile_cache = {}
@@ -469,7 +469,7 @@ class _Executor:
         return self._executor.has_compiled(phase)
 
     def __call__(self, obj, *args, phase='predict'):
-        if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER":
+        if context.get_context("precompile_only") or _is_role_pserver():
             return None
         return self.run(obj, *args, phase=phase)
 
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index b025e9e95b943d62bcbc9cce8c64b3df3ed63b0c..67c3936a699b42aa7fd8505c49e887e82495df11 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -22,6 +22,7 @@ from .tensor import Tensor, MetaTensor
 from .._checkparam import _check_str_by_regular
 from ..parallel._tensor import _get_slice_index
 from ..parallel._auto_parallel_context import auto_parallel_context
+from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
 
 __all__ = ['Parameter', 'ParameterTuple']
 
@@ -168,8 +169,13 @@
         """For parse check."""
 
     def set_param_ps(self, init_in_server=False):
-        self.is_param_ps = True
-        self.init_in_server = init_in_server
+        if _is_role_worker() or _is_role_pserver() or _is_role_sched():
+            self.is_param_ps = True
+            self.init_in_server = init_in_server
+        else:
+            raise RuntimeError("Must complete the following two steps before calling set_param_ps: "
+                               "1. set_ps_context(enable_ps=True); "
+                               "2. export the MS_ROLE environment variable.")
 
     @property
diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py
index 920488cee403a72e742d06aff9e344c25069a05e..f70616ce0a8a228cd073571ebccf6743a53717ef 100644
--- a/mindspore/communication/_comm_helper.py
+++ b/mindspore/communication/_comm_helper.py
@@ -14,7 +14,7 @@
 # ============================================================================
 """comm_helper"""
 
-import os
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from ._hccl_management import load_lib as hccl_load_lib
 
 _HCCL_AVAILABLE = False
@@ -44,7 +44,6 @@ else:
 
 HCCL_WORLD_COMM_GROUP = "hccl_world_group"
 NCCL_WORLD_COMM_GROUP = "nccl_world_group"
-MS_ROLE = os.getenv("MS_ROLE")
 
 class Backend:
     """
@@ -113,7 +112,7 @@ def check_parameter_available(func):
         Wrapper. If not available, raise Error.
     """
     def wrapper(*args, **kargs):
-        if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+        if _is_role_pserver() or _is_role_sched():
             return func(*args, **kargs)
         group = None
         if "group" in kargs.keys():
@@ -154,7 +153,7 @@ def _get_rank_helper(group, backend):
         Integer. The local rank id of the calling process.
     """
     rank_id = None
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         rank_id = 0
         return rank_id
     if backend == Backend.HCCL:
@@ -213,7 +212,7 @@ def _get_size_helper(group, backend):
         Integer. The rank size of specified group.
     """
     size = None
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         size = 1
         return size
     if backend == Backend.HCCL:
diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py
index dd0f56e2036181979109c6c779f1634e4b371ab0..53da7291d8c1a90b7070184ab391e8197828eaec 100755
--- a/mindspore/communication/management.py
+++ b/mindspore/communication/management.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Communication management API"""
-import os
 from mindspore import context
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
     _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
     _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
@@ -29,7 +29,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
 
 DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
 DEFAULT_BACKEND = Backend("hccl")
-MS_ROLE = os.getenv("MS_ROLE")
 
 
 def _get_group(group):
@@ -61,7 +60,7 @@ def init(backend_name=None):
         RuntimeError: If device target is invalid.
         RuntimeError: If backend is invalid or distributed init fails.
     """
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         return
     if backend_name is None:
         device_target = context.get_context("device_target")
diff --git a/mindspore/context.py b/mindspore/context.py
index ea08960182f103c2db9cae7d36bebfb6f7b998cb..9661102d15ef06d1d952172ad71ce2c810f543d3 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -26,9 +26,11 @@ from mindspore._c_expression import MSContext, ms_ctx_param
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
+from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
 
 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
-           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
+           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
+           'get_ps_context', 'reset_ps_context']
 
 GRAPH_MODE = 0
 PYNATIVE_MODE = 1
@@ -569,3 +571,58 @@ class ParallelMode:
     SEMI_AUTO_PARALLEL = "semi_auto_parallel"
     AUTO_PARALLEL = "auto_parallel"
     MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
+
+@args_type_check(enable_ps=bool)
+def set_ps_context(**kwargs):
+    """
+    Set parameter server training mode context.
+
+    Note:
+        Some other environment variables should also be set for parameter server training mode.
+        These environment variables are listed below:
+        MS_SERVER_NUM  # Server number
+        MS_WORKER_NUM  # Worker number
+        MS_SCHED_HOST  # Scheduler IP address
+        MS_SCHED_PORT  # Scheduler port
+        MS_ROLE        # The role of this process:
+                       MS_SCHED represents the scheduler,
+                       MS_WORKER represents the worker,
+                       MS_PSERVER represents the Server
+
+
+    Args:
+        enable_ps (bool): Whether to enable parameter server training mode.
+                          The environment variables only take effect after enable_ps is set to True.
+                          Default: False.
+
+    Raises:
+        ValueError: If the input key is not an attribute of the parameter server training mode context.
+
+    Examples:
+        >>> context.set_ps_context(enable_ps=True)
+    """
+    _set_ps_context(**kwargs)
+
+
+def get_ps_context(attr_key):
+    """
+    Get parameter server training mode context attribute value according to the key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Returns attribute value according to the key.
+
+    Raises:
+        ValueError: If the input key is not an attribute of the parameter server training mode context.
+    """
+    return _get_ps_context(attr_key)
+
+def reset_ps_context():
+    """
+    Reset parameter server training mode context attributes to the default values:
+
+    - enable_ps: False.
+    """
+    _reset_ps_context()
diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2e62f6530d9c148076e9a86bf3f240f1cce89f9
--- /dev/null
+++ b/mindspore/parallel/_ps_context.py
@@ -0,0 +1,115 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Context for parameter server training mode"""
+
+from mindspore._c_expression import PSContext
+
+_ps_context = None
+
+
+def ps_context():
+    """
+    Get the global _ps_context, if it is not created, create a new one.
+
+    Returns:
+        _ps_context, the global parameter server training mode context.
+    """
+    global _ps_context
+    if _ps_context is None:
+        _ps_context = PSContext.get_instance()
+    return _ps_context
+
+_set_ps_context_func_map = {
+    "enable_ps": ps_context().set_ps_enable
+}
+
+_get_ps_context_func_map = {
+    "enable_ps": ps_context().is_ps_enabled
+}
+
+def _get_ps_mode_rank():
+    ps_rank = ps_context().ps_rank_id()
+    if ps_rank == -1:
+        raise RuntimeError("The parameter server mode training is not enabled yet.")
+    return ps_rank
+
+def _set_ps_context(**kwargs):
+    """
+    Set parameter server training mode context.
+
+    Note:
+        Some other environment variables should also be set for parameter server training mode.
+        These environment variables are listed below:
+        MS_SERVER_NUM  # Server number
+        MS_WORKER_NUM  # Worker number
+        MS_SCHED_HOST  # Scheduler IP address
+        MS_SCHED_PORT  # Scheduler port
+        MS_ROLE        # The role of this process:
+                       MS_SCHED represents the scheduler,
+                       MS_WORKER represents the worker,
+                       MS_PSERVER represents the Server
+
+
+    Args:
+        enable_ps (bool): Whether to enable parameter server training mode.
+                          The environment variables only take effect after enable_ps is set to True.
+                          Default: False.
+
+    Raises:
+        ValueError: If the input key is not an attribute of the parameter server training mode context.
+
+    Examples:
+        >>> context.set_ps_context(enable_ps=True)
+    """
+    for key, value in kwargs.items():
+        if key not in _set_ps_context_func_map:
+            raise ValueError("Set PS context keyword %s is not recognized!" % key)
+        set_func = _set_ps_context_func_map[key]
+        set_func(value)
+
+def _get_ps_context(attr_key):
+    """
+    Get parameter server training mode context attribute value according to the key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Returns attribute value according to the key.
+
+    Raises:
+        ValueError: If the input key is not an attribute of the parameter server training mode context.
+    """
+    if attr_key not in _get_ps_context_func_map:
+        raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
+    get_func = _get_ps_context_func_map[attr_key]
+    return get_func()
+
+def _reset_ps_context():
+    """
+    Reset parameter server training mode context attributes to the default values:
+
+    - enable_ps: False.
+    """
+    ps_context().reset()
+
+def _is_role_worker():
+    return ps_context().is_role_worker()
+
+def _is_role_pserver():
+    return ps_context().is_role_pserver()
+
+def _is_role_sched():
+    return ps_context().is_role_sched()
diff --git a/mindspore/parallel/_ps_utils.py b/mindspore/parallel/_ps_utils.py
deleted file mode 100644
index 7a62d41f45cc40bf49c496aa57baa3053213b98e..0000000000000000000000000000000000000000
--- a/mindspore/parallel/_ps_utils.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Utils for parameter server training mode"""
-
-from mindspore._c_expression import get_ps_mode_rank
-
-def _get_ps_mode_rank():
-    ps_rank = get_ps_mode_rank()
-    if ps_rank == -1:
-        raise RuntimeError("The parameter server mode training is not launched yet.")
-    return ps_rank
diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py
index a76c73e1122ebc0d701e5caf559bee2b2ca973df..ce3892a368efa771345cb6aec8939ce333300d96 100644
--- a/mindspore/train/callback/_checkpoint.py
+++ b/mindspore/train/callback/_checkpoint.py
@@ -24,6 +24,7 @@ from mindspore import log as logger
 from mindspore._checkparam import check_bool, check_int_non_negative
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
+from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
 from ._callback import Callback, set_cur_net
 
 
@@ -280,8 +281,7 @@ class ModelCheckpoint(Callback):
         if save_ckpt:
             cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                 + str(step_num_in_epoch) + ".ckpt"
-            if os.getenv("MS_ROLE") == "MS_PSERVER":
-                from mindspore.parallel._ps_utils import _get_ps_mode_rank
+            if _is_role_pserver():
                 cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
             # update checkpoint file list.
             self._manager.update_ckpoint_filelist(self._directory, self._prefix)
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 3a1286d13060c27b9808e300884dd51614d53a7c..af8ecb9164426f69d06c38be37b028047a4bbbdd 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -27,6 +27,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
+from ..parallel._ps_context import _is_role_pserver, _is_role_sched
 from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
@@ -378,8 +379,7 @@
         cb_params.list_callback = self._transform_callbacks(callbacks)
         cb_params.train_dataset_element = None
         cb_params.network = self._network
-        ms_role = os.getenv("MS_ROLE")
-        if ms_role in ("MS_PSERVER", "MS_SCHED"):
+        if _is_role_pserver() or _is_role_sched():
             epoch = 1
 
         # build callback list
@@ -516,7 +516,7 @@
                     self._loss_scale_manager.update_loss_scale(overflow)
 
                 list_callback.step_end(run_context)
-                if os.getenv("MS_ROLE") == "MS_PSERVER":
+                if _is_role_pserver():
                     os._exit(0)
                 should_stop = should_stop or run_context.get_stop_requested()
                 if should_stop:
diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py
index 0f22a261929d705b88a37f573d5fde3940716459..0cc619b9bad9dd06c7d597ca9c5679e9f4ff1f95 100755
--- a/model_zoo/official/cv/resnet/train.py
+++ b/model_zoo/official/cv/resnet/train.py
@@ -70,6 +70,7 @@ if __name__ == '__main__':
 
     # init context
     context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    context.set_ps_context(enable_ps=True)
     if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
diff --git a/model_zoo/official/nlp/bert_thor/src/model_thor.py b/model_zoo/official/nlp/bert_thor/src/model_thor.py
index f47e8c368979a725ffd35018a4ae2a2fa2e15aff..710bd7a9f0bb5203953784d2a587a1bb21fa847f 100644
--- a/model_zoo/official/nlp/bert_thor/src/model_thor.py
+++ b/model_zoo/official/nlp/bert_thor/src/model_thor.py
@@ -14,7 +14,6 @@
 # ============================================================================
 """Model."""
 import math
-import os
 from collections.abc import Iterable
 
 import numpy as np
@@ -405,9 +404,6 @@
         cb_params.list_callback = self._transform_callbacks(callbacks)
         cb_params.train_dataset_element = None
         cb_params.network = self._network
-        ms_role = os.getenv("MS_ROLE")
-        if ms_role in ("MS_PSERVER", "MS_SCHED"):
-            epoch = 1
 
         # build callback list
         with _CallbackManager(callbacks) as list_callback:
diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
index 232488ceffd0cebc968c4fe9b8bc0b82e44fc5b4..056ea404a32d76811879d628e3d7b08cdbcd77e2 100644
--- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
+++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
@@ -118,6 +118,7 @@ if __name__ == "__main__":
     wide_deep_config.argparse_init()
 
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
+    context.set_ps_context(enable_ps=True)
     init()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                       device_num=get_group_size())
diff --git a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
index aecf8d781d577396b3a842bcc94f341349320787..1a4b0765cb88fb2519ad3ca957091fcf15d87626 100644
--- a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
+++ b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
@@ -26,6 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam
 from mindspore.ops import operations as P
 from mindspore.common.initializer import TruncatedNormal
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
 
 parser = argparse.ArgumentParser(description="test_sparse_embedding")
 parser.add_argument("--device_target", type=str, default="Ascend")
@@ -34,6 +35,7 @@ device_target = args.device_target
 context.set_context(
     mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True
 )
+context.set_ps_context(enable_ps=True)
 
 
 def fc_with_initialize(input_channels, out_channels):
@@ -81,7 +83,7 @@ def do_sparse_embedding(ps=False):
     for _ in range(epoch):
         data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
         label = Tensor(np.random.randint(0, 9, (32), np.int32))
-        if envs.get("MS_ROLE") == "MS_PSERVER":
+        if _is_role_pserver():
             train_network(data, label)
             sys.exit()
         else:
@@ -96,10 +98,10 @@ if __name__ == "__main__":
     np.random.seed(0)
     ps_loss = do_sparse_embedding(True)
 
-    if envs.get("MS_ROLE") == "MS_WORKER":
-        envs["MS_ROLE"] = ""
+    if _is_role_worker():
+        context.reset_ps_context()
     np.random.seed(0)
     no_ps_loss = do_sparse_embedding()
-    envs["MS_ROLE"] = "MS_WORKER"
+    context.set_ps_context(enable_ps=True)
 
     assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)
diff --git a/tests/st/ps/full_ps/test_full_ps_lenet.py b/tests/st/ps/full_ps/test_full_ps_lenet.py
index aca875f6fccf0e1791bc8f7d4d0e4de3bc0f9fa8..2ee3caef94301dc4f74d306030097bb42fe5583a 100644
--- a/tests/st/ps/full_ps/test_full_ps_lenet.py
+++ b/tests/st/ps/full_ps/test_full_ps_lenet.py
@@ -35,6 +35,7 @@ args, _ = parser.parse_known_args()
 device_target = args.device_target
 dataset_path = args.dataset_path
 context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+context.set_ps_context(enable_ps=True)
 
 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
     """weight initial for conv layer"""
diff --git a/tests/st/ps/multi_full_ps/test_multi_full_ps.py b/tests/st/ps/multi_full_ps/test_multi_full_ps.py
index 99e8bb2322a973814c2dd12901fdfe40aadfbed6..ae5437417add3bd760185d27d2a751f747fb0a70 100644
--- a/tests/st/ps/multi_full_ps/test_multi_full_ps.py
+++ b/tests/st/ps/multi_full_ps/test_multi_full_ps.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
+import sys
 import argparse
 
 import numpy as np
 
@@ -22,6 +23,7 @@ from mindspore.common.initializer import TruncatedNormal
 from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.communication.management import init, get_group_size
+from mindspore.parallel._ps_context import _is_role_pserver
 # from resnet import resnet50
 
 parser = argparse.ArgumentParser(description="test_ps_lenet")
@@ -29,6 +31,7 @@ parser.add_argument("--device_target", type=str, default="Ascend")
 args, _ = parser.parse_known_args()
 device_target = args.device_target
 context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+context.set_ps_context(enable_ps=True)
 if device_target == "GPU":
     init()
 
@@ -106,6 +109,10 @@ if __name__ == "__main__":
     for _ in range(epoch):
         data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
         label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32))
-        loss = train_network(data, label).asnumpy()
-        losses.append(loss)
+        if _is_role_pserver():
+            train_network(data, label)
+            sys.exit()
+        else:
+            loss = train_network(data, label).asnumpy()
+            losses.append(loss)
     print(losses)
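
End-to-end usage, for reference: the sketch below mirrors the pattern used by the tests/st/ps scripts in this patch. It assumes the MS_SERVER_NUM, MS_WORKER_NUM, MS_SCHED_HOST, MS_SCHED_PORT and MS_ROLE environment variables are exported before launch; the Dense network, loss and optimizer choices are purely illustrative, and _ps_context is an internal module (as in the tests), not a public API.

# hypothetical train.py -- one process per MS_ROLE, launched by an external script
import sys
import numpy as np
import mindspore.context as context
from mindspore import Tensor, nn
from mindspore.parallel._ps_context import _is_role_pserver

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# Must be called before set_param_ps(); reads MS_ROLE to decide this process's role.
context.set_ps_context(enable_ps=True)

network = nn.Dense(32, 10)
for param in network.trainable_params():
    param.set_param_ps()  # raises RuntimeError if no PS role was configured

loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optimizer = nn.Adam(network.trainable_params())
train_network = nn.TrainOneStepCell(nn.WithLossCell(network, loss_fn), optimizer)

data = Tensor(np.random.rand(32, 32).astype(np.float32))
label = Tensor(np.random.randint(0, 10, (32,)).astype(np.int32))
if _is_role_pserver():
    train_network(data, label)  # the server process stays inside the PS runtime
    sys.exit()
loss = train_network(data, label).asnumpy()  # worker process computes the loss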