From ae45d981181b44783c61a21d808b54cc5148dc02 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Tue, 31 May 2022 12:01:09 +0800 Subject: [PATCH] [Eager] fix collective_global_gather (#43090) * [Eager] fix collective_global_gather * fix eager_ode = 1 --- .../paddle/fluid/tests/unittests/collective_global_gather.py | 4 ++++ .../paddle/fluid/tests/unittests/test_collective_api_base.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/collective_global_gather.py b/python/paddle/fluid/tests/unittests/collective_global_gather.py index d3a6071ed04..164abe05934 100644 --- a/python/paddle/fluid/tests/unittests/collective_global_gather.py +++ b/python/paddle/fluid/tests/unittests/collective_global_gather.py @@ -23,6 +23,7 @@ import unittest import paddle.fluid.layers as layers from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main import pickle +from paddle.fluid.framework import _enable_legacy_dygraph paddle.enable_static() @@ -74,6 +75,9 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase): world_size = 2 tot_expert = n_expert * world_size paddle.disable_static() + + # Call paddle.distributed.alltoall() under legacy dygraph + _enable_legacy_dygraph() np.random.seed(os.getpid()) local_expert_count = np.random.randint( 1, 4, size=tot_expert).astype("int64") diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index dbd98294726..a4e71db3d38 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -219,9 +219,9 @@ class TestDistBase(unittest.TestCase): required_envs["GLOO_LOG_LEVEL"] = "TRACE" if eager_mode: - required_envs["FLAGS_enable_eager_mode"] = "%d" % 0 - else: required_envs["FLAGS_enable_eager_mode"] = "%d" % 1 + else: + required_envs["FLAGS_enable_eager_mode"] = "%d" % 0 tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, required_envs) -- GitLab