diff --git a/mindspore/parallel/__init__.py b/mindspore/parallel/__init__.py index c79704f110431d5df5a6d3cfb8689f11b6d4bf38..79d8e67a8dcd2e4f5f1a9ea0c8d47bcd7651c8fc 100644 --- a/mindspore/parallel/__init__.py +++ b/mindspore/parallel/__init__.py @@ -15,9 +15,7 @@ """ This interface is ONLY used in Auto-parallel procedure. """ -from .dp_allreduce_fusion import set_fusion_strategy_by_idx, set_fusion_strategy_by_size from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \ set_algo_parameters -__all__ = ["set_fusion_strategy_by_idx", "set_fusion_strategy_by_size", "get_algo_parameters", - "reset_algo_parameters", "set_algo_parameters"] +__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 3564ad4395f7217e758ed61d7b9e0f232ffb35e8..c99ac4a3c7372f02f6c8eea6aab7fe564a2fef0d 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -14,6 +14,8 @@ # ============================================================================ """Context of auto parallel""" import threading +import mindspore.context as context +from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size from mindspore._c_expression import AutoParallelContext from mindspore._extends.pynative_helper import args_type_check @@ -219,13 +221,15 @@ class _AutoParallelContext: indices (list): Indices list. Raises: - ValueError: If type of indices item is not int. + TypeError: If type of indices item is not int. """ self.check_context_handle() for index in indices: if not isinstance(index, int): raise TypeError('indices has invalid value') - return self._context_handle.set_all_reduce_fusion_split_indices(indices) + self._context_handle.set_all_reduce_fusion_split_indices(indices) + if context.get_context("device_target") == "Ascend": + _set_fusion_strategy_by_idx(indices) def get_all_reduce_fusion_split_indices(self): """Get allreduce fusion split indices.""" @@ -240,13 +244,15 @@ class _AutoParallelContext: sizes (list): Sizes list. Raises: - ValueError: If type of sizes item is not int. + TypeError: If type of sizes item is not int. """ self.check_context_handle() for size in sizes: if not isinstance(size, int): raise TypeError('sizes has invalid value') - return self._context_handle.set_all_reduce_fusion_split_sizes(sizes) + self._context_handle.set_all_reduce_fusion_split_sizes(sizes) + if context.get_context("device_target") == "Ascend": + _set_fusion_strategy_by_size(sizes) def get_all_reduce_fusion_split_sizes(self): """Get allreduce fusion split sizes.""" diff --git a/mindspore/parallel/dp_allreduce_fusion.py b/mindspore/parallel/_dp_allreduce_fusion.py similarity index 94% rename from mindspore/parallel/dp_allreduce_fusion.py rename to mindspore/parallel/_dp_allreduce_fusion.py index 979823bd806ca5ccf918709861d42ea4f1aee67c..3c7039dbd6ddd408afe3ac54b49625385139852d 100644 --- a/mindspore/parallel/dp_allreduce_fusion.py +++ b/mindspore/parallel/_dp_allreduce_fusion.py @@ -43,7 +43,7 @@ def _c_array(ctype, values): return (ctype * len(values))(*values) -def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): +def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): """ A function set gradient segment strategy according to the index list. @@ -100,7 +100,7 @@ def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): raise RuntimeError('Allreduce split error') -def set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"): +def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"): """ A function set gradient segment strategy according to the data size percentage list.