diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py
index 10ff0d36fcee9554d3a398013b4b3de554fd0972..ba63a71643466bdd1e424889962aae1bc1b4d514 100644
--- a/python/paddle/distributed/auto_parallel/process_group.py
+++ b/python/paddle/distributed/auto_parallel/process_group.py
@@ -21,7 +21,7 @@ from ..collective import _get_global_env
 from ..collective import _new_ring_id
 from ...fluid.framework import _non_static_mode
 from ...fluid.layers.tensor import fill_constant
-from paddle.fluid.framework import _enable_legacy_dygraph
+from paddle import _legacy_C_ops
 
 
 def get_all_process_groups():
@@ -145,14 +145,15 @@ class ProcessGroup:
         # TODO(shenliang03): This is a temporary solution to solve the problem of
         # hang caused by cross-creation of new_group
         paddle.disable_static()
-        _enable_legacy_dygraph()
         paddle.set_device('gpu:%d' %
                           paddle.distributed.ParallelEnv().dev_id)
         tmp = paddle.to_tensor(
             [1], dtype="int32") if _non_static_mode() else fill_constant(
                 [0], dtype="int32", value="1")
-        paddle.distributed.all_reduce(tmp, sync_op=True, group=self)
-        paddle.distributed.wait(tmp, group=self)
+        # use legacy ops
+        _legacy_C_ops.c_allreduce_sum_(tmp, 'use_calc_stream', True,
+                                       'ring_id', self.id)
+        _legacy_C_ops.c_sync_calc_stream(tmp, tmp)
         paddle.enable_static()
 
         self._is_instantiate = True
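
For context, a minimal sketch of how the modified instantiate() path is typically driven (not part of this diff). It assumes a multi-process GPU launch (e.g. via paddle.distributed.launch) and that the process groups have already been registered through this module; the rank-membership check via pg.ranks is an assumption made here for illustration.

    import paddle
    from paddle.distributed.auto_parallel.process_group import get_all_process_groups

    # Warm up every group the current rank belongs to so that later
    # collectives on these groups do not hang.
    rank = paddle.distributed.get_rank()
    for pg in get_all_process_groups():
        if rank not in pg.ranks:  # `ranks` is assumed to list the member ranks
            continue
        pg.instantiate()  # triggers the legacy c_allreduce_sum_/c_sync_calc_stream warm-up above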