提交 ecc168c7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!189 Integrate two allreduce fusion set interfaces into one

Merge pull request !189 from yao_yf/parallel_interface_organize
......@@ -15,9 +15,7 @@
"""
This interface is ONLY used in Auto-parallel procedure.
"""
from .dp_allreduce_fusion import set_fusion_strategy_by_idx, set_fusion_strategy_by_size
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
set_algo_parameters
__all__ = ["set_fusion_strategy_by_idx", "set_fusion_strategy_by_size", "get_algo_parameters",
"reset_algo_parameters", "set_algo_parameters"]
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
......@@ -14,6 +14,8 @@
# ============================================================================
"""Context of auto parallel"""
import threading
import mindspore.context as context
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore._c_expression import AutoParallelContext
from mindspore._extends.pynative_helper import args_type_check
......@@ -219,13 +221,15 @@ class _AutoParallelContext:
indices (list): Indices list.
Raises:
ValueError: If type of indices item is not int.
TypeError: If type of indices item is not int.
"""
self.check_context_handle()
for index in indices:
if not isinstance(index, int):
raise TypeError('indices has invalid value')
return self._context_handle.set_all_reduce_fusion_split_indices(indices)
self._context_handle.set_all_reduce_fusion_split_indices(indices)
if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_idx(indices)
def get_all_reduce_fusion_split_indices(self):
"""Get allreduce fusion split indices."""
......@@ -240,13 +244,15 @@ class _AutoParallelContext:
sizes (list): Sizes list.
Raises:
ValueError: If type of sizes item is not int.
TypeError: If type of sizes item is not int.
"""
self.check_context_handle()
for size in sizes:
if not isinstance(size, int):
raise TypeError('sizes has invalid value')
return self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_size(sizes)
def get_all_reduce_fusion_split_sizes(self):
"""Get allreduce fusion split sizes."""
......
......@@ -43,7 +43,7 @@ def _c_array(ctype, values):
return (ctype * len(values))(*values)
def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
"""
A function set gradient segment strategy according to the index list.
......@@ -100,7 +100,7 @@ def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
raise RuntimeError('Allreduce split error')
def set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
"""
A function set gradient segment strategy according to the data size percentage list.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册