diff --git a/python/paddle/fluid/tests/unittests/collective/collective_global_gather.py b/python/paddle/fluid/tests/unittests/collective/collective_global_gather.py index 6de3183eb489db9bf38d63159751d1d52815a33a..1dc853865b2ead544b69cf50ec421a1a205c930c 100644 --- a/python/paddle/fluid/tests/unittests/collective/collective_global_gather.py +++ b/python/paddle/fluid/tests/unittests/collective/collective_global_gather.py @@ -22,7 +22,6 @@ from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main import paddle import paddle.distributed.utils.moe_utils as moe_utils import paddle.fluid as fluid -from paddle.fluid.framework import _enable_legacy_dygraph paddle.enable_static() @@ -61,8 +60,8 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase): endpoints = args["endpoints"].split(",") rank = args["trainerid"] current_endpoint = args["currentendpoint"] - nranks = 2 paddle.distributed.init_parallel_env() + nranks = 2 if args['backend'] == 'nccl': device_id = int(os.getenv("FLAGS_selected_gpus", "0")) place = fluid.CUDAPlace( @@ -78,33 +77,37 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase): n_expert = 2 world_size = 2 tot_expert = n_expert * world_size - paddle.disable_static() - # Call paddle.distributed.alltoall() under legacy dygraph - _enable_legacy_dygraph() + tmp_main_prog = fluid.Program() + with fluid.program_guard(tmp_main_prog, fluid.Program()): + local_expert_count = paddle.static.data( + name="local_expert_count", shape=[tot_expert], dtype="int64" + ) + global_expert_count = [] + paddle.distributed.alltoall( + paddle.split(local_expert_count, 2, axis=0), global_expert_count + ) + global_expert_count = paddle.concat(global_expert_count, axis=0) + exe = fluid.Executor(place) + exe.run(startup_prog) np.random.seed(os.getpid()) local_expert_count = np.random.randint(1, 4, size=tot_expert).astype( "int64" ) - local_expert_count = paddle.to_tensor(local_expert_count) - global_expert_count = [] - paddle.distributed.alltoall( - paddle.split(local_expert_count, 2, axis=0), global_expert_count + (global_expert_count,) = exe.run( + tmp_main_prog, + feed={"local_expert_count": local_expert_count}, + fetch_list=[global_expert_count.name], ) - global_expert_count = paddle.concat(global_expert_count, axis=0) - global_expert_count = global_expert_count.numpy() - local_expert_count = local_expert_count.numpy() + fwd_expert_count = sum(global_expert_count) np.random.seed(os.getpid()) local_input_buf = np.random.rand(fwd_expert_count, in_feat).astype( "float32" ) - paddle.enable_static() if args['static_mode']: result = self.get_model(train_prog, startup_prog, rank) - exe = fluid.Executor(place) - exe.run(startup_prog) fetch_list = [] for elem in result: fetch_list.append(elem.name)