未验证 提交 5384206a 编写于 作者: Y Yan Xu 提交者: GitHub

Merge pull request #14869 from Yancey1989/fix_dist_unittest

fix dist unit test
...@@ -227,6 +227,7 @@ class TestDistBase(unittest.TestCase): ...@@ -227,6 +227,7 @@ class TestDistBase(unittest.TestCase):
def setUp(self): def setUp(self):
self._trainers = 2 self._trainers = 2
self._pservers = 2 self._pservers = 2
self._port_set = set()
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port()) self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable self._python_interp = sys.executable
...@@ -242,10 +243,18 @@ class TestDistBase(unittest.TestCase): ...@@ -242,10 +243,18 @@ class TestDistBase(unittest.TestCase):
self._after_setup_config() self._after_setup_config()
def _find_free_port(self): def _find_free_port(self):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0)) s.bind(('', 0))
return s.getsockname()[1] return s.getsockname()[1]
while True:
port = __free_port()
if port not in self._port_set:
self._port_set.add(port)
return port
def start_pserver(self, model_file, check_error_log, required_envs): def start_pserver(self, model_file, check_error_log, required_envs):
ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps0_ep, ps1_ep = self._ps_endpoints.split(",")
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver" ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册