提交 36941255 编写于 作者: C chengmo

Merge branch 'fix_startup' into 'develop'

fix startup & port

See merge request !6
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
......@@ -20,13 +20,13 @@ import os
import copy
from fleetrec.core.engine.engine import Engine
from fleetrec.core.utils import envs
class LocalClusterEngine(Engine):
def start_procs(self):
worker_num = self.envs["worker_num"]
server_num = self.envs["server_num"]
start_port = self.envs["start_port"]
ports = [self.envs["start_port"]]
logs_dir = self.envs["log_dir"]
default_env = os.environ.copy()
......@@ -36,7 +36,13 @@ class LocalClusterEngine(Engine):
current_env.pop("https_proxy", None)
procs = []
log_fns = []
ports = range(start_port, start_port + server_num, 1)
for i in range(server_num - 1):
while True:
new_port = envs.find_free_port()
if new_port not in ports:
ports.append(new_port)
break
user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
......@@ -40,7 +40,7 @@ class ClusterTrainer(TranspileTrainer):
else:
self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('startup_pass', self.startup)
if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader":
self.regist_context_processor('train_pass', self.dataset_train)
else:
......
......@@ -33,7 +33,7 @@ class SingleTrainer(TranspileTrainer):
def processor_register(self):
self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('startup_pass', self.startup)
if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader":
self.regist_context_processor('train_pass', self.dataset_train)
else:
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
......@@ -15,7 +15,8 @@
import os
import copy
import sys
import socket
from contextlib import closing
global_envs = {}
......@@ -170,3 +171,12 @@ def get_platform():
return "DARWIN"
if 'Windows' in plats:
return "WINDOWS"
def find_free_port():
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
new_port = __free_port()
return new_port
......@@ -139,7 +139,7 @@ def local_cluster_engine(args):
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["start_port"] = envs.find_free_port()
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer.trainer"] = trainer
cluster_envs["train.trainer.strategy"] = "async"
......
文件模式从 100644 更改为 100755
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册