From a8625aafd00195ca5151b34e8ff329e15107fff9 Mon Sep 17 00:00:00 2001 From: WangXi Date: Thu, 13 May 2021 21:24:56 +0800 Subject: [PATCH] fix wait server ready (#32889) --- .../meta_optimizers/graph_execution_optimizer.py | 6 +++--- .../tests/unittests/test_fleet_graph_executor.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py index 4194cf13d2b..22ed3f2ac41 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py @@ -63,9 +63,9 @@ class GraphExecutionOptimizer(MetaOptimizerBase): trainer_endpoints_env = ",".join(trainer_endpoints) trainers_num = self.role_maker._worker_num() - # FIXME(wangxi): approve this. - #if trainer_id == 0: - # wait_server_ready(other_trainers) + # NOTE(wangxi): npu don't need to wait server ready + if trainer_id == 0 and not paddle.is_compiled_with_npu(): + wait_server_ready(other_trainers) if core.is_compiled_with_cuda(): comm_id_var = startup_program.global_block().create_var( diff --git a/python/paddle/fluid/tests/unittests/test_fleet_graph_executor.py b/python/paddle/fluid/tests/unittests/test_fleet_graph_executor.py index 05da44cd061..628f1db80d2 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_graph_executor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_graph_executor.py @@ -80,15 +80,17 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): cost_val = exe.run(feed=gen_data(), fetch_list=[avg_cost.name]) print("cost of step[{}] = {}".format(i, cost_val)) - proc_a = launch_func(node_func, node_a) - proc_a.start() + # rank 1 + proc_b = launch_func(node_func, node_b) + proc_b.start() + # rank 0, for wait server ready coverage # just for coverage - for key in node_b: - os.environ[key] = node_b[key] + for key in node_a: + os.environ[key] = node_a[key] node_func() - proc_a.join() + proc_b.join() if __name__ == "__main__": -- GitLab