diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py
index be7ad257ccb99ccc1775c0a73cfbe8443b8454cf..c69b21538b61ad207db385afa99cbbc1448a2b71 100644
--- a/python/paddle/distributed/fleet/launch_utils.py
+++ b/python/paddle/distributed/fleet/launch_utils.py
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
 import logging
-import socket
 import time
 import os
 import signal
@@ -27,6 +25,7 @@ from contextlib import closing
 import socket
 import warnings
 import six
+import struct
 
 import paddle
 import paddle.fluid as fluid
@@ -362,6 +361,10 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
 def find_free_ports(num):
     def __free_port():
         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+            # Note(wangxi): Close the connection with a TCP RST instead
+            # of a TCP FIN, to avoid time_wait state.
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
+                         struct.pack('ii', 1, 0))
             s.bind(('', 0))
             return s.getsockname()[1]
 
@@ -376,7 +379,7 @@ def find_free_ports(num):
             return port_set
 
         step += 1
-        if step > 100:
+        if step > 400:
             print(
                 "can't find avilable port and use the specified static port now!"
             )
diff --git a/python/paddle/fluid/tests/unittests/test_launch_coverage.py b/python/paddle/fluid/tests/unittests/test_launch_coverage.py
index 43613928585e77035d7d405996b3b4940d953d08..9fbf27e3c1d063452d2bc75805c45fbd2a0959d2 100644
--- a/python/paddle/fluid/tests/unittests/test_launch_coverage.py
+++ b/python/paddle/fluid/tests/unittests/test_launch_coverage.py
@@ -24,6 +24,7 @@ import paddle.fluid as fluid
 
 from argparse import ArgumentParser, REMAINDER
 from paddle.distributed.utils import _print_arguments, get_gpus, get_cluster_from_args
+from paddle.distributed.fleet.launch_utils import find_free_ports
 
 
 def _parse_args():
@@ -115,6 +116,9 @@ class TestCoverage(unittest.TestCase):
         args.use_paddlecloud = True
         cluster, pod = get_cluster_from_args(args, "0")
 
+    def test_find_free_ports(self):
+        find_free_ports(2)
+
 
 if __name__ == '__main__':
     unittest.main()