未验证 提交 ae45d981 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] fix collective_global_gather (#43090)

* [Eager] fix collective_global_gather

* fix eager_ode = 1
上级 2785f876
......@@ -23,6 +23,7 @@ import unittest
import paddle.fluid.layers as layers
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import pickle
from paddle.fluid.framework import _enable_legacy_dygraph
paddle.enable_static()
......@@ -74,6 +75,9 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase):
world_size = 2
tot_expert = n_expert * world_size
paddle.disable_static()
# Call paddle.distributed.alltoall() under legacy dygraph
_enable_legacy_dygraph()
np.random.seed(os.getpid())
local_expert_count = np.random.randint(
1, 4, size=tot_expert).astype("int64")
......
......@@ -219,9 +219,9 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOO_LOG_LEVEL"] = "TRACE"
if eager_mode:
required_envs["FLAGS_enable_eager_mode"] = "%d" % 0
else:
required_envs["FLAGS_enable_eager_mode"] = "%d" % 1
else:
required_envs["FLAGS_enable_eager_mode"] = "%d" % 0
tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file,
required_envs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册