提交 699166e5 编写于 作者: K kswang

default fusion group for ge

上级 961af9fe
...@@ -274,10 +274,7 @@ class _AutoParallelContext: ...@@ -274,10 +274,7 @@ class _AutoParallelContext:
self._context_handle.set_all_reduce_fusion_split_indices(indices, group) self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
if context.get_context("device_target") == "Ascend": if context.get_context("device_target") == "Ascend":
if group == "": _set_fusion_strategy_by_idx(indices)
_set_fusion_strategy_by_idx(indices)
else:
_set_fusion_strategy_by_idx(indices, group)
def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"):
""" """
...@@ -330,10 +327,7 @@ class _AutoParallelContext: ...@@ -330,10 +327,7 @@ class _AutoParallelContext:
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
if context.get_context("device_target") == "Ascend": if context.get_context("device_target") == "Ascend":
if group == "": _set_fusion_strategy_by_size(sizes)
_set_fusion_strategy_by_size(sizes)
else:
_set_fusion_strategy_by_size(sizes, group)
def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册