diff --git a/python/paddle/fluid/tests/unittests/test_communicator_geo.py b/python/paddle/fluid/tests/unittests/test_communicator_geo.py index ea59e070cbd51da440d81a3eb2236edb38385f2b..d9c6406422277c72f18bde341855f66dff7f3555 100644 --- a/python/paddle/fluid/tests/unittests/test_communicator_geo.py +++ b/python/paddle/fluid/tests/unittests/test_communicator_geo.py @@ -28,6 +28,8 @@ import paddle.fluid as fluid import paddle.distributed.fleet.base.role_maker as role_maker import paddle.distributed.fleet as fleet +from paddle.distributed.utils import find_free_ports + paddle.enable_static() @@ -101,12 +103,9 @@ class TestCommunicatorGeoEnd2End(unittest.TestCase): os.environ["PADDLE_PSERVER_NUMS"] = "1" os.environ["PADDLE_TRAINERS_NUM"] = "1" - os.environ["POD_IP"] = "127.0.0.1" - os.environ["PADDLE_PORT"] = "36001" os.environ["PADDLE_TRAINER_ID"] = "0" os.environ["PADDLE_TRAINERS_NUM"] = "1" - os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ - "127.0.0.1:36001" + os.environ["POD_IP"] = "127.0.0.1" role = role_maker.PaddleCloudRoleMaker() @@ -150,8 +149,6 @@ class RunServer(TestCommunicatorGeoEnd2End): pass os.environ["TRAINING_ROLE"] = "PSERVER" -os.environ["http_proxy"] = "" -os.environ["https_proxy"] = "" half_run_server = RunServer() half_run_server.run_ut() @@ -160,9 +157,12 @@ half_run_server.run_ut() server_file = "run_server_for_communicator_geo.py" with open(server_file, "w") as wb: wb.write(run_server_cmd) + + port = find_free_ports(1).pop() + os.environ["TRAINING_ROLE"] = "PSERVER" - os.environ["http_proxy"] = "" - os.environ["https_proxy"] = "" + os.environ["PADDLE_PORT"] = str(port) + os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:{}".format(port) _python = sys.executable @@ -173,17 +173,14 @@ half_run_server.run_ut() stdout=subprocess.PIPE, stderr=subprocess.PIPE) - outs, errs = ps_proc.communicate(timeout=15) - - time.sleep(1) + time.sleep(5) os.environ["TRAINING_ROLE"] = "TRAINER" - os.environ["http_proxy"] = "" - os.environ["https_proxy"] = "" self.run_ut() ps_proc.kill() ps_proc.wait() + outs, errs = ps_proc.communicate() if os.path.exists(server_file): os.remove(server_file)