Unverified commit 0c7f3575, authored by LiYuRio, committed by GitHub

fix global gather (#48736)

Parent c5a45cc6
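In short, reading from the diff below: the test previously switched out of static mode and enabled legacy dygraph solely to exchange per-rank expert counts via paddle.distributed.alltoall. The fix drops that dependency by expressing the exchange as a small temporary static Program (tmp_main_prog) and executing it once with a fluid.Executor, feeding the randomly generated local_expert_count and fetching the concatenated global_expert_count back as a numpy array.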
@@ -22,7 +22,6 @@ from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
 import paddle
 import paddle.distributed.utils.moe_utils as moe_utils
 import paddle.fluid as fluid
-from paddle.fluid.framework import _enable_legacy_dygraph
 
 paddle.enable_static()
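With the legacy-dygraph call path gone (see the third hunk below), the _enable_legacy_dygraph import becomes unused and is removed here.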
@@ -61,8 +60,8 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase):
         endpoints = args["endpoints"].split(",")
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
-        nranks = 2
         paddle.distributed.init_parallel_env()
+        nranks = 2
         if args['backend'] == 'nccl':
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
             place = fluid.CUDAPlace(
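Since nranks = 2 is a constant that nothing in between reads, moving it after paddle.distributed.init_parallel_env() is behavior-preserving; presumably it just keeps the post-initialization setup together.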
@@ -78,33 +77,37 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase):
         n_expert = 2
         world_size = 2
         tot_expert = n_expert * world_size
-        paddle.disable_static()
-        # Call paddle.distributed.alltoall() under legacy dygraph
-        _enable_legacy_dygraph()
+
+        tmp_main_prog = fluid.Program()
+        with fluid.program_guard(tmp_main_prog, fluid.Program()):
+            local_expert_count = paddle.static.data(
+                name="local_expert_count", shape=[tot_expert], dtype="int64"
+            )
+            global_expert_count = []
+            paddle.distributed.alltoall(
+                paddle.split(local_expert_count, 2, axis=0), global_expert_count
+            )
+            global_expert_count = paddle.concat(global_expert_count, axis=0)
+        exe = fluid.Executor(place)
+        exe.run(startup_prog)
         np.random.seed(os.getpid())
         local_expert_count = np.random.randint(1, 4, size=tot_expert).astype(
             "int64"
         )
-        local_expert_count = paddle.to_tensor(local_expert_count)
-        global_expert_count = []
-        paddle.distributed.alltoall(
-            paddle.split(local_expert_count, 2, axis=0), global_expert_count
-        )
-        global_expert_count = paddle.concat(global_expert_count, axis=0)
-        global_expert_count = global_expert_count.numpy()
-        local_expert_count = local_expert_count.numpy()
+        (global_expert_count,) = exe.run(
+            tmp_main_prog,
+            feed={"local_expert_count": local_expert_count},
+            fetch_list=[global_expert_count.name],
+        )
         fwd_expert_count = sum(global_expert_count)
         np.random.seed(os.getpid())
         local_input_buf = np.random.rand(fwd_expert_count, in_feat).astype(
             "float32"
         )
-        paddle.enable_static()
         if args['static_mode']:
             result = self.get_model(train_prog, startup_prog, rank)
-            exe = fluid.Executor(place)
-            exe.run(startup_prog)
             fetch_list = []
             for elem in result:
                 fetch_list.append(elem.name)
...
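For readers unfamiliar with the idiom the fix adopts, here is a minimal single-process sketch (not part of the commit; the names x and y are illustrative, and a plain split/concat stands in for paddle.distributed.alltoall, which would need one process per rank): build a throwaway static Program, run it once with an Executor, and receive numpy output directly.

import numpy as np

import paddle
import paddle.fluid as fluid

paddle.enable_static()

# Throwaway Program pair; the second fluid.Program() is a startup program
# that never needs to run because this graph has no parameters.
tmp_main_prog = fluid.Program()
with fluid.program_guard(tmp_main_prog, fluid.Program()):
    # Placeholder fed at run time, same role as local_expert_count above.
    x = paddle.static.data(name="x", shape=[4], dtype="int64")
    # Single-process stand-in for the collective; the test instead calls
    # paddle.distributed.alltoall(paddle.split(x, 2, axis=0), out_list).
    y = paddle.concat(paddle.split(x, 2, axis=0), axis=0)

exe = fluid.Executor(paddle.CPUPlace())
(out,) = exe.run(
    tmp_main_prog,
    feed={"x": np.arange(4, dtype="int64")},
    fetch_list=[y.name],
)
print(out)  # [0 1 2 3], already a numpy array

The payoff, visible in the hunk above, is that exe.run hands back numpy arrays, so global_expert_count feeds straight into sum() with no to_tensor()/numpy() round trip through dygraph.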