diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
index 20445f9c6a75e970d023f7b5e2261b0c5bbddecf..6624e8e5667d26ca4bcf3a321b306c05221a6f66 100644
--- a/mindspore/ccsrc/frontend/parallel/context.cc
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -42,15 +42,12 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
   return inst_context_;
 }
 
-ParallelContext::ParallelContext() {
-  communication_backend_ = HCCL_BACKEND;
-  Reset();
-}
+ParallelContext::ParallelContext() { Reset(); }
 
 void ParallelContext::Reset() {
   mirror_mean_ = false;
   full_batch_ = false;
-  cast_before_mirror_ = true;
+  gradient_fp32_sync_ = true;
   loss_repeated_mean_ = true;
   device_num_ = 1;
   global_rank_ = 0;
@@ -81,14 +78,10 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_
 
 void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
 
-void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
+void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
 
 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
 
-void ParallelContext::set_communication_backend(const std::string &communication_backend) {
-  communication_backend_ = communication_backend;
-}
-
 bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
   auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
   if (iter == PARALLEL_MODE_LIST.end()) {
diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
index 34363726411dd9b154e6e03522679eccaeaa6659..828300af1ccb06c0702e93d92d03df2cd1002604 100644
--- a/mindspore/ccsrc/frontend/parallel/context.h
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -58,8 +58,8 @@ class ParallelContext {
   void set_full_batch(bool full_batch);
   bool full_batch() const { return full_batch_; }
 
-  void set_cast_before_mirror(bool cast_before_mirror);
-  bool cast_before_mirror() const { return cast_before_mirror_; }
+  void set_gradient_fp32_sync(bool gradient_fp32_sync);
+  bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
 
   void set_loss_repeated_mean(bool loss_repeated_mean);
   bool loss_repeated_mean() const { return loss_repeated_mean_; }
@@ -70,9 +70,6 @@
   void set_global_rank(int32_t global_rank);
   int32_t global_rank() const { return global_rank_; }
 
-  void set_communication_backend(const std::string &communication_backend);
-  std::string communication_backend() const { return communication_backend_; }
-
   bool set_parallel_mode(const std::string &parallel_mode);
   std::string parallel_mode() const { return parallel_mode_; }
 
@@ -112,11 +109,10 @@
   static std::shared_ptr<ParallelContext> inst_context_;
   bool mirror_mean_;
   bool full_batch_;
-  bool cast_before_mirror_;
+  bool gradient_fp32_sync_;
   bool loss_repeated_mean_;
   int32_t device_num_;
   int32_t global_rank_;
-  std::string communication_backend_;
   std::string parallel_mode_;
   std::string strategy_search_mode_;
   bool parameter_broadcast_;
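Note (not part of the patch): after this rename, the flag formerly exposed as cast_before_mirror is reachable from Python as gradient_fp32_sync through the auto-parallel context wrapper updated later in this patch. A minimal usage sketch, assuming the rebuilt _c_expression bindings from the init.cc hunk below:

    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    # Disable the fp16 -> fp32 cast of gradients before the mirror allreduce.
    auto_parallel_context().set_gradient_fp32_sync(False)
    assert not auto_parallel_context().get_gradient_fp32_sync()
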
"frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include "utils/comm_manager.h" #include "utils/symbolic.h" +#include "utils/ms_context.h" using mindspore::tensor::Tensor; @@ -869,8 +870,8 @@ std::pair FindCNode(const AnfNodePtr &anode, const std::string & } bool IsCastBeforMirror(const CNodePtr &node, size_t index) { - // only if cast_before_mirror is true, pre node is cast and type is not float32 return true - if (!ParallelContext::GetInstance()->cast_before_mirror()) { + // only if gradient_fp32_sync is true, pre node is cast and type is not float32 return true + if (!ParallelContext::GetInstance()->gradient_fp32_sync()) { return false; } auto pre_node = node->input(index); @@ -2421,13 +2422,17 @@ Status ParallelInit() { MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); int32_t device_num = ParallelContext::GetInstance()->device_num(); int32_t global_rank = ParallelContext::GetInstance()->global_rank(); - std::string backend = ParallelContext::GetInstance()->communication_backend(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + std::string backend = ms_context->get_param(MS_CTX_DEVICE_TARGET); std::string world_group; - - if (backend == HCCL_BACKEND) { + std::string communication_backend; + if (backend == kAscendDevice || backend == kDavinciDevice) { world_group = HCCL_WORLD_GROUP; - } else if (backend == NCCL_BACKEND) { + communication_backend = HCCL_BACKEND; + } else if (backend == kGPUDevice) { world_group = NCCL_WORLD_GROUP; + communication_backend = NCCL_BACKEND; } else { MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; } @@ -2450,14 +2455,14 @@ Status ParallelInit() { MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; } - if (!InitDevice(device_num, global_rank, backend)) { + if (!InitDevice(device_num, global_rank, communication_backend)) { MS_LOG(ERROR) << "Init device failed"; return FAILED; } MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean() - << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror(); + << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync(); return SUCCESS; } diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 2ee67882f44bbddb2b4211bdf93f7077b90ff626..2a5d88d65c71c412752326ff13203c7604328364 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -209,12 +209,10 @@ PYBIND11_MODULE(_c_expression, m) { .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.") .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.") .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.") - .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.") - .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.") + .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.") + .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.") .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.") .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") - 
.def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.") - .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.") .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py index be07b538610d6813c49b8cfa17a51416922e512d..dd0f56e2036181979109c6c779f1634e4b371ab0 100755 --- a/mindspore/communication/management.py +++ b/mindspore/communication/management.py @@ -15,7 +15,6 @@ """Communication management API""" import os from mindspore import context -from mindspore.parallel._auto_parallel_context import auto_parallel_context from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ @@ -86,9 +85,6 @@ def init(backend_name=None): else: raise RuntimeError("Backend name {} is not supported.".format(backend_name)) - auto_parallel_context().set_communication_backend(backend_name) - - def release(): """ Release distributed resource. e.g., hccl/nccl. diff --git a/mindspore/context.py b/mindspore/context.py index 7788a4f367b089f8f5ff11dc8265d6cb593b3c86..4bfe75c8cef191f436473c570667f2782c5462b7 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -434,7 +434,7 @@ def _context(): return _k_context -@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, +@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) def set_auto_parallel_context(**kwargs): @@ -454,9 +454,9 @@ def set_auto_parallel_context(**kwargs): global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. "stand_alone" do not support mirror_mean. Default: False. - cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. + gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.. "stand_alone", "data_parallel" and "hybrid_parallel" do not support - cast_before_mirror. Default: True. + gradient_fp32_sync. Default: True. parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone". 
@@ -492,7 +492,7 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(device_num=8)
         >>> context.set_auto_parallel_context(global_rank=0)
         >>> context.set_auto_parallel_context(mirror_mean=True)
-        >>> context.set_auto_parallel_context(cast_before_mirror=False)
+        >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
@@ -524,7 +524,7 @@ def reset_auto_parallel_context():
     - device_num: 1.
     - global_rank: 0.
     - mirror_mean: False.
-    - cast_before_mirror: True.
+    - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: "".
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index e2369c4aa69ddeddfc31206b42a6fef0a1001311..0cd11d7fb8e764046213076db5fe750a739c4fcf 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -113,24 +113,24 @@
         self.check_context_handle()
         return self._context_handle.get_mirror_mean()
 
-    def set_cast_before_mirror(self, cast_before_mirror):
+    def set_gradient_fp32_sync(self, gradient_fp32_sync):
         """
-        Set cast_before_mirror.
+        Set gradient_fp32_sync.
 
         Note:
-            If cast_before_mirror is true,
+            If gradient_fp32_sync is true,
             it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
 
         Args:
-            cast_before_mirror (bool): The cast_before_mirror flag.
+            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
         """
         self.check_context_handle()
-        self._context_handle.set_cast_before_mirror(cast_before_mirror)
+        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
 
-    def get_cast_before_mirror(self):
-        """Get cast_before_mirror flag."""
+    def get_gradient_fp32_sync(self):
+        """Get gradient_fp32_sync flag."""
         self.check_context_handle()
-        return self._context_handle.get_cast_before_mirror()
+        return self._context_handle.get_gradient_fp32_sync()
 
     def set_loss_repeated_mean(self, loss_repeated_mean):
         """
@@ -152,21 +152,6 @@
         self.check_context_handle()
         return self._context_handle.get_loss_repeated_mean()
 
-    def set_communication_backend(self, communication_backend):
-        """
-        Set communication backend.
-
-        Args:
-            communication_backend (str): The communication backend.
-        """
-        self.check_context_handle()
-        self._context_handle.set_communication_backend(communication_backend)
-
-    def get_communication_backend(self):
-        """Get communication backend."""
-        self.check_context_handle()
-        return self._context_handle.get_communication_backend()
-
     def set_parallel_mode(self, parallel_mode):
         """
         Set parallel mode for auto parallel.
@@ -469,7 +454,7 @@ _set_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().set_device_num,
     "global_rank": auto_parallel_context().set_global_rank,
     "mirror_mean": auto_parallel_context().set_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
@@ -484,7 +469,7 @@ _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
     "global_rank": auto_parallel_context().get_global_rank,
     "mirror_mean": auto_parallel_context().get_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
@@ -495,7 +480,7 @@
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
+@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
@@ -512,8 +497,9 @@ def _set_auto_parallel_context(**kwargs):
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
         mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
         loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
-                        calculations. Default: True.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
+                       calculations. Default: True.
+        gradient_fp32_sync (bool): Gradients are allreduced in fp32 even if they are fp16 when this flag is True.
+                       Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
@@ -577,7 +563,7 @@ def _reset_auto_parallel_context():
     - device_num: 1.
     - global_rank: 0.
     - mirror_mean: False.
-    - cast_before_mirror: True.
+    - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ""
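Note (not part of the patch): at the Python API level only the keyword changes; a minimal sketch of the new spelling, mirroring the updated tests below:

    from mindspore import context

    # cast_before_mirror is no longer a recognized key; use gradient_fp32_sync instead.
    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False)
    assert not context.get_auto_parallel_context("gradient_fp32_sync")
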
- strategy_ckpt_load_file: "" diff --git a/tests/ut/python/hccl_test/manage/api.py b/tests/ut/python/hccl_test/manage/api.py index e44f824ce2c277913e7b07c79107baeee65fb47f..01d17e910b9ba630456d0fed77d73098b4d241f8 100644 --- a/tests/ut/python/hccl_test/manage/api.py +++ b/tests/ut/python/hccl_test/manage/api.py @@ -61,7 +61,7 @@ def get_rank_id(group=None): def get_rank_size(group=None): hccl = Hccl() - if group is None: + if group is None or "nccl_world_group" in group: return hccl.rank_size if isinstance(group, str): return int(group.split("-")[0]) diff --git a/tests/ut/python/parallel/test_element_wise_function.py b/tests/ut/python/parallel/test_element_wise_function.py index 9226e3e43c6e072bd5727797a6cffc10caf38b56..120cd8c3daa61dd1d7c791b524a2bb8f06907170 100644 --- a/tests/ut/python/parallel/test_element_wise_function.py +++ b/tests/ut/python/parallel/test_element_wise_function.py @@ -830,7 +830,7 @@ def test_matmul_cast(): compile_net(net, x, y, b) -def test_cast_before_mirror(): +def test_gradient_fp32_sync(): class Net(nn.Cell): def __init__(self, strategy1): super().__init__() @@ -843,7 +843,7 @@ def test_cast_before_mirror(): out = self.matmul(out, b) return out - context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True) + context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True) strategy1 = ((2, 2), (2, 2)) net = GradWrap(NetWithLoss(Net(strategy1))) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") @@ -854,7 +854,7 @@ def test_cast_before_mirror(): compile_net(net, x, y, b) -def test_cast_before_mirror1(): +def test_gradient_fp32_sync1(): class Net(nn.Cell): def __init__(self, strategy1): super().__init__() @@ -867,7 +867,7 @@ def test_cast_before_mirror1(): out = self.matmul(out, b) return out - context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True) + context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True) strategy1 = ((2, 2), (2, 2)) net = GradWrap(NetWithLoss(Net(strategy1))) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") @@ -878,7 +878,7 @@ def test_cast_before_mirror1(): compile_net(net, x, y, b) -def test_cast_before_mirror2(): +def test_gradient_fp32_sync2(): class Net(nn.Cell): def __init__(self, strategy1): super().__init__() @@ -891,7 +891,7 @@ def test_cast_before_mirror2(): out = self.matmul(out, b) return out - context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=False) + context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False) strategy1 = ((2, 2), (2, 2)) net = GradWrap(NetWithLoss(Net(strategy1))) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") @@ -902,7 +902,7 @@ def test_cast_before_mirror2(): compile_net(net, x, y, b) -def test_cast_before_mirror3(): +def test_gradient_fp32_sync3(): class Net(nn.Cell): def __init__(self, strategy1): super().__init__() diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py index 19187cb262c10b76f8ed22304c1317888ef94c54..ff69b3bee84883d9117370f8f2a71d60693d9be4 100644 --- a/tests/ut/python/parallel/test_set_auto_parallel_context.py +++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py @@ -20,25 +20,21 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context def test_set_auto_parallel_context(): - context.set_auto_parallel_context(device_num=4, global_rank=3, 
diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py
index 19187cb262c10b76f8ed22304c1317888ef94c54..ff69b3bee84883d9117370f8f2a71d60693d9be4 100644
--- a/tests/ut/python/parallel/test_set_auto_parallel_context.py
+++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py
@@ -20,25 +20,21 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 
 
 def test_set_auto_parallel_context():
-    context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, cast_before_mirror=False,
+    context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, gradient_fp32_sync=False,
                                       parallel_mode="auto_parallel", parameter_broadcast=False)
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     mirror_mean = context.get_auto_parallel_context("mirror_mean")
-    cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
+    gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
     parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
     assert device_num == 4
     assert global_rank == 3
     assert mirror_mean
-    assert not cast_before_mirror
+    assert not gradient_fp32_sync
     assert parallel_mode == "auto_parallel"
     assert not parameter_broadcast
 
-    auto_parallel_context().set_communication_backend("hccl")
-    backend = auto_parallel_context().get_communication_backend()
-    assert backend == "hccl"
-
     auto_parallel_context().set_device_num(4)
     device_num = auto_parallel_context().get_device_num()
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
@@ -53,9 +49,9 @@
     mirror_mean = auto_parallel_context().get_mirror_mean()
     assert mirror_mean
 
-    auto_parallel_context().set_cast_before_mirror(False)
-    cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
-    assert not cast_before_mirror
+    auto_parallel_context().set_gradient_fp32_sync(False)
+    gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
+    assert not gradient_fp32_sync
 
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     assert parameter_broadcast_is_set
@@ -91,7 +87,7 @@ def test_reset_auto_parallel_context():
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     mirror_mean = context.get_auto_parallel_context("mirror_mean")
-    cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
+    gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
     parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
@@ -99,7 +95,7 @@
     assert device_num == 1
     assert global_rank == 0
     assert not mirror_mean
-    assert cast_before_mirror
+    assert gradient_fp32_sync
     assert parallel_mode == "stand_alone"
     assert not parameter_broadcast
     assert not device_num_is_set
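
Note (not part of the patch): the default restored by reset_auto_parallel_context() follows the C++ Reset() above (gradient_fp32_sync_ = true), which the reset test also asserts. A minimal sketch:

    from mindspore import context

    context.reset_auto_parallel_context()
    # gradient_fp32_sync falls back to its default of True after a reset.
    assert context.get_auto_parallel_context("gradient_fp32_sync")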