From 4ff418084c9a98a8ecc6292b1ce1ac5f02bd782d Mon Sep 17 00:00:00 2001 From: lirongzhen Date: Wed, 22 Apr 2020 19:23:45 +0800 Subject: [PATCH] enable/disable allreduce_fusion --- .../allreduce_fusion/step_allreduce_fusion.cc | 3 ++- mindspore/ccsrc/parallel/context.cc | 1 + mindspore/ccsrc/parallel/context.h | 5 +++++ mindspore/ccsrc/pipeline/init.cc | 4 ++++ mindspore/parallel/_auto_parallel_context.py | 17 +++++++++++++++++ mindspore/parallel/_utils.py | 5 +++++ tests/ut/python/parallel/__init__.py | 2 ++ .../ut/python/parallel/test_allreduce_fusion.py | 3 ++- 8 files changed, 38 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc index 8ab089521..23ec9da87 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc @@ -31,10 +31,11 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); + bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion(); // assume no change to graph bool changes = false; // control whether use model_parallel mode - if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || + if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) { return changes; } diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index bc4aca896..4eb79772d 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -55,6 +55,7 @@ void ParallelContext::Reset() { parallel_mode_ = STAND_ALONE; parameter_broadcast_ = false; parameter_broadcast_is_set_ = false; + enable_all_reduce_fusion_ = false; } void ParallelContext::set_device_num(int32_t device_num) { diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index 64261cb96..095a50f7b 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -80,6 +80,10 @@ class ParallelContext { const std::vector all_reduce_fusion_split_indices() const; void set_all_reduce_fusion_split_sizes(const std::vector sizes); const std::vector all_reduce_fusion_split_sizes() const; + void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) { + enable_all_reduce_fusion_ = enable_all_reduce_fusion; + } + bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; } void Reset(); @@ -98,6 +102,7 @@ class ParallelContext { bool device_num_is_set_; bool global_rank_is_set_; bool parameter_broadcast_is_set_; + bool enable_all_reduce_fusion_; std::vector all_reduce_fusion_split_indices_; std::vector all_reduce_fusion_split_sizes_; }; diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 86e6d436b..98874f857 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -183,6 +183,10 @@ PYBIND11_MODULE(_c_expression, m) { "Set all reduce fusion split sizes.") .def("get_all_reduce_fusion_split_sizes", &ParallelContext::all_reduce_fusion_split_sizes, "Get all reduce fusion split sizes.") + .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion, + "Set enable/disable all reduce fusion.") + .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion, + "Get enable/disable all reduce fusion.") .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.") .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set, "Get parameter broadcast is set.") diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index bf4b99085..0608989d9 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -259,6 +259,23 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_all_reduce_fusion_split_sizes() + def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): + """ + Set enable/disable all reduce fusion. + + Args: + enable_all_reduce_fusion (bool): Enable/disable all reduce fusion. + """ + self.check_context_handle() + if not isinstance(enable_all_reduce_fusion, bool): + raise TypeError('enable_all_reduce_fusion is invalid type') + self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion) + + def get_enable_all_reduce_fusion(self): + """Get all reduce fusion flag.""" + self.check_context_handle() + return self._context_handle.get_enable_all_reduce_fusion() + def get_device_num_is_set(self): """Get device number is set or not.""" self.check_context_handle() diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index 3ce5463ed..cb3a0c0ac 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -117,6 +117,7 @@ _cast_before_mirror = None _loss_repeated_mean = None _communication_backend = None _has_checkpointed = False +_enable_all_reduce_fusion = None def _checkpoint_auto_parallel_context(): @@ -133,6 +134,7 @@ def _checkpoint_auto_parallel_context(): global _cast_before_mirror global _loss_repeated_mean global _communication_backend + global _enable_all_reduce_fusion _parallel_mode = auto_parallel_context().get_parallel_mode() _device_num = _get_device_num() _global_rank = _get_global_rank() @@ -141,6 +143,7 @@ def _checkpoint_auto_parallel_context(): _cast_before_mirror = auto_parallel_context().get_cast_before_mirror() _loss_repeated_mean = auto_parallel_context().get_loss_repeated_mean() _communication_backend = auto_parallel_context().get_communication_backend() + _enable_all_reduce_fusion = auto_parallel_context().get_enable_all_reduce_fusion() _has_checkpointed = True @@ -154,10 +157,12 @@ def _restore_auto_parallel_context(): global _cast_before_mirror global _loss_repeated_mean global _communication_backend + global _enable_all_reduce_fusion _set_auto_parallel_context(parallel_mode=_parallel_mode, device_num=_device_num, global_rank=_global_rank, parameter_broadcast=_parameter_broadcast, mirror_mean=_mirror_mean, cast_before_mirror=_cast_before_mirror, loss_repeated_mean=_loss_repeated_mean) auto_parallel_context().set_communication_backend(_communication_backend) + auto_parallel_context().set_enable_all_reduce_fusion(_enable_all_reduce_fusion) def _reset_checkpoint_auto_parallel_context(): diff --git a/tests/ut/python/parallel/__init__.py b/tests/ut/python/parallel/__init__.py index c08f8e247..b26962bc3 100644 --- a/tests/ut/python/parallel/__init__.py +++ b/tests/ut/python/parallel/__init__.py @@ -13,10 +13,12 @@ # limitations under the License. import mindspore.context as context +from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._utils import _reset_op_id def setup_module(module): + auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) _reset_op_id() diff --git a/tests/ut/python/parallel/test_allreduce_fusion.py b/tests/ut/python/parallel/test_allreduce_fusion.py index fcbee1058..b8bf9ccc0 100644 --- a/tests/ut/python/parallel/test_allreduce_fusion.py +++ b/tests/ut/python/parallel/test_allreduce_fusion.py @@ -23,7 +23,7 @@ from tests.dataset_mock import MindData from mindspore import context from mindspore.common.api import _executor from mindspore.parallel import _cost_model_context as cost_model_context - +from mindspore.parallel._auto_parallel_context import auto_parallel_context class Dataset(MindData): @@ -105,6 +105,7 @@ def train_common(net): epoch_size = 2 device_num=4 context.reset_auto_parallel_context() + auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True) context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num, parameter_broadcast=False) context.set_context(mode=context.GRAPH_MODE) -- GitLab