Unverified commit 0c7f3575, authored by LiYuRio, committed by GitHub

fix global gather (#48736)

Parent: c5a45cc6
@@ -22,7 +22,6 @@ from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
 import paddle
 import paddle.distributed.utils.moe_utils as moe_utils
 import paddle.fluid as fluid
-from paddle.fluid.framework import _enable_legacy_dygraph

 paddle.enable_static()
@@ -61,8 +60,8 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase):
         endpoints = args["endpoints"].split(",")
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
-        nranks = 2
         paddle.distributed.init_parallel_env()
+        nranks = 2
         if args['backend'] == 'nccl':
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
             place = fluid.CUDAPlace(
@@ -78,33 +77,37 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase):
         n_expert = 2
         world_size = 2
         tot_expert = n_expert * world_size
-        paddle.disable_static()
-        # Call paddle.distributed.alltoall() under legacy dygraph
-        _enable_legacy_dygraph()
-        np.random.seed(os.getpid())
-        local_expert_count = np.random.randint(1, 4, size=tot_expert).astype(
-            "int64"
-        )
-        local_expert_count = paddle.to_tensor(local_expert_count)
-        global_expert_count = []
-        paddle.distributed.alltoall(
-            paddle.split(local_expert_count, 2, axis=0), global_expert_count
-        )
-        global_expert_count = paddle.concat(global_expert_count, axis=0)
-        global_expert_count = global_expert_count.numpy()
-        local_expert_count = local_expert_count.numpy()
+        tmp_main_prog = fluid.Program()
+        with fluid.program_guard(tmp_main_prog, fluid.Program()):
+            local_expert_count = paddle.static.data(
+                name="local_expert_count", shape=[tot_expert], dtype="int64"
+            )
+            global_expert_count = []
+            paddle.distributed.alltoall(
+                paddle.split(local_expert_count, 2, axis=0), global_expert_count
+            )
+            global_expert_count = paddle.concat(global_expert_count, axis=0)
+        exe = fluid.Executor(place)
+        exe.run(startup_prog)
+        np.random.seed(os.getpid())
+        local_expert_count = np.random.randint(1, 4, size=tot_expert).astype(
+            "int64"
+        )
+        (global_expert_count,) = exe.run(
+            tmp_main_prog,
+            feed={"local_expert_count": local_expert_count},
+            fetch_list=[global_expert_count.name],
+        )
         fwd_expert_count = sum(global_expert_count)
         np.random.seed(os.getpid())
         local_input_buf = np.random.rand(fwd_expert_count, in_feat).astype(
             "float32"
         )
-        paddle.enable_static()
         if args['static_mode']:
             result = self.get_model(train_prog, startup_prog, rank)
             exe = fluid.Executor(place)
             exe.run(startup_prog)
             fetch_list = []
             for elem in result:
                 fetch_list.append(elem.name)
......
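
The change replaces the legacy-dygraph path with a purely static-graph one: instead of switching into dygraph via _enable_legacy_dygraph() just to exchange local_expert_count between ranks with paddle.distributed.alltoall, the runner now builds a small throwaway static Program that performs the alltoall and runs it through an Executor before the main test program is constructed. The following is a minimal, self-contained sketch of that pattern, not part of the commit; it assumes a 2-rank job launched with paddle.distributed.launch, an NCCL-capable GPU per rank, and FLAGS_selected_gpus set by the launcher, as in the test runner.

    # Minimal sketch of the new expert-count exchange, mirroring the test runner.
    # Assumptions: 2-rank job launched via `python -m paddle.distributed.launch`,
    # NCCL backend available, FLAGS_selected_gpus set per rank by the launcher.
    import os

    import numpy as np

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    paddle.distributed.init_parallel_env()

    n_expert = 2
    world_size = 2
    tot_expert = n_expert * world_size

    # Throwaway static Program whose only job is to alltoall the per-rank
    # expert counts; no legacy dygraph involved.
    tmp_main_prog = fluid.Program()
    with fluid.program_guard(tmp_main_prog, fluid.Program()):
        local_expert_count = paddle.static.data(
            name="local_expert_count", shape=[tot_expert], dtype="int64"
        )
        global_expert_count = []
        paddle.distributed.alltoall(
            paddle.split(local_expert_count, 2, axis=0), global_expert_count
        )
        global_expert_count = paddle.concat(global_expert_count, axis=0)

    device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
    exe = fluid.Executor(fluid.CUDAPlace(device_id))

    # Seeding with the pid gives every rank a different random count vector.
    np.random.seed(os.getpid())
    local_expert_count = np.random.randint(1, 4, size=tot_expert).astype("int64")

    # Feed the numpy counts, run the temporary program, and fetch the
    # exchanged counts as a numpy array for later use.
    (global_expert_count,) = exe.run(
        tmp_main_prog,
        feed={"local_expert_count": local_expert_count},
        fetch_list=[global_expert_count.name],
    )
    fwd_expert_count = sum(global_expert_count)

The apparent motivation is that the global_gather test needs global_expert_count as a concrete numpy array before the main program is built and fed; with the legacy dygraph entry point gone, doing the exchange through a temporary Program plus one Executor run keeps the whole runner in static mode.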