Unverified · Commit 3b5064d6 authored by caozhou, committed by GitHub

update instantiate for auto parallel (#46883)
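Summary of the change, as read from the diff below: `ProcessGroup.instantiate()` no longer depends on `_enable_legacy_dygraph()`. Instead of issuing its warm-up collective through `paddle.distributed.all_reduce` and `paddle.distributed.wait`, it now calls the legacy C operators `_legacy_C_ops.c_allreduce_sum_` and `_legacy_C_ops.c_sync_calc_stream` directly, and the import is updated accordingly.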

Parent: 0773639a
```diff
@@ -21,7 +21,7 @@ from ..collective import _get_global_env
 from ..collective import _new_ring_id
 from ...fluid.framework import _non_static_mode
 from ...fluid.layers.tensor import fill_constant
-from paddle.fluid.framework import _enable_legacy_dygraph
+from paddle import _legacy_C_ops


 def get_all_process_groups():
@@ -145,14 +145,15 @@ class ProcessGroup:
             # TODO(shenliang03): This is a temporary solution to solve the problem of
             # hang caused by cross-creation of new_group
             paddle.disable_static()
-            _enable_legacy_dygraph()
             paddle.set_device('gpu:%d' %
                               paddle.distributed.ParallelEnv().dev_id)
             tmp = paddle.to_tensor(
                 [1], dtype="int32") if _non_static_mode() else fill_constant(
                     [0], dtype="int32", value="1")
-            paddle.distributed.all_reduce(tmp, sync_op=True, group=self)
-            paddle.distributed.wait(tmp, group=self)
+            # use legacy ops
+            _legacy_C_ops.c_allreduce_sum_(tmp, 'use_calc_stream', True,
+                                           'ring_id', self.id)
+            _legacy_C_ops.c_sync_calc_stream(tmp, tmp)
             paddle.enable_static()

         self._is_instantiate = True
```
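For readers outside the Paddle tree, here is a minimal sketch of the new barrier logic pulled out of `ProcessGroup.instantiate()`. The standalone wrapper `_warmup_allreduce` is hypothetical, introduced only for illustration; the communicator setup that precedes this code in the real method is collapsed in the diff above and is omitted here as well. The sketch assumes a GPU build of Paddle from this era that still ships `paddle._legacy_C_ops`, and it only does anything meaningful inside a multi-process distributed launch.

```python
import paddle
from paddle import _legacy_C_ops
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.layers.tensor import fill_constant


def _warmup_allreduce(ring_id):
    """Hypothetical extraction of the barrier at the end of instantiate()."""
    paddle.disable_static()
    paddle.set_device('gpu:%d' % paddle.distributed.ParallelEnv().dev_id)
    # A one-element tensor is enough; its value is irrelevant, since the
    # point is to run one synchronous collective on this ring as a warm-up.
    tmp = paddle.to_tensor(
        [1], dtype="int32") if _non_static_mode() else fill_constant(
            [0], dtype="int32", value="1")
    # Run the sum all-reduce on the calculation stream via the legacy C op,
    # then synchronize that stream so the call behaves like a barrier.
    _legacy_C_ops.c_allreduce_sum_(tmp, 'use_calc_stream', True,
                                   'ring_id', ring_id)
    _legacy_C_ops.c_sync_calc_stream(tmp, tmp)
    paddle.enable_static()
```

As the TODO in the hunk says, this eager warm-up collective is a temporary workaround for a hang caused by cross-creation of `new_group`; routing it through `_legacy_C_ops` lets the code drop the global `_enable_legacy_dygraph()` switch while keeping the same barrier semantics.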