提交 36941255 编写于 作者: C chengmo

Merge branch 'fix_startup' into 'develop'

fix startup & port

See merge request !6
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
...@@ -20,13 +20,13 @@ import os ...@@ -20,13 +20,13 @@ import os
import copy import copy
from fleetrec.core.engine.engine import Engine from fleetrec.core.engine.engine import Engine
from fleetrec.core.utils import envs
class LocalClusterEngine(Engine): class LocalClusterEngine(Engine):
def start_procs(self): def start_procs(self):
worker_num = self.envs["worker_num"] worker_num = self.envs["worker_num"]
server_num = self.envs["server_num"] server_num = self.envs["server_num"]
start_port = self.envs["start_port"] ports = [self.envs["start_port"]]
logs_dir = self.envs["log_dir"] logs_dir = self.envs["log_dir"]
default_env = os.environ.copy() default_env = os.environ.copy()
...@@ -36,7 +36,13 @@ class LocalClusterEngine(Engine): ...@@ -36,7 +36,13 @@ class LocalClusterEngine(Engine):
current_env.pop("https_proxy", None) current_env.pop("https_proxy", None)
procs = [] procs = []
log_fns = [] log_fns = []
ports = range(start_port, start_port + server_num, 1)
for i in range(server_num - 1):
while True:
new_port = envs.find_free_port()
if new_port not in ports:
ports.append(new_port)
break
user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports]) user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")] user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")] user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
...@@ -40,7 +40,7 @@ class ClusterTrainer(TranspileTrainer): ...@@ -40,7 +40,7 @@ class ClusterTrainer(TranspileTrainer):
else: else:
self.regist_context_processor('uninit', self.instance) self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init) self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('startup_pass', self.startup)
if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader": if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader":
self.regist_context_processor('train_pass', self.dataset_train) self.regist_context_processor('train_pass', self.dataset_train)
else: else:
......
...@@ -33,7 +33,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -33,7 +33,7 @@ class SingleTrainer(TranspileTrainer):
def processor_register(self): def processor_register(self):
self.regist_context_processor('uninit', self.instance) self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init) self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('startup_pass', self.startup)
if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader": if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader":
self.regist_context_processor('train_pass', self.dataset_train) self.regist_context_processor('train_pass', self.dataset_train)
else: else:
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
import os import os
import copy import copy
import sys import sys
import socket
from contextlib import closing
global_envs = {} global_envs = {}
...@@ -170,3 +171,12 @@ def get_platform(): ...@@ -170,3 +171,12 @@ def get_platform():
return "DARWIN" return "DARWIN"
if 'Windows' in plats: if 'Windows' in plats:
return "WINDOWS" return "WINDOWS"
def find_free_port():
    """Ask the operating system for a currently unused TCP port.

    A throwaway IPv4 stream socket is bound to port 0, which lets the OS
    choose an available port. The chosen number is read back from the
    socket and the socket is closed before returning.

    Returns:
        int: the port number the OS assigned to the temporary socket.

    Note:
        The socket is closed by the time this function returns, so the
        port is only known to have been free at the moment of the call.
        Another process may claim it before the caller binds it, and two
        successive calls are not guaranteed to return different ports.
    """
    # closing() guarantees s.close() runs on exit from the with-block, even
    # where socket objects do not implement the context-manager protocol.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Port 0 means "any free port"; '' binds on all local interfaces.
        s.bind(('', 0))
        return s.getsockname()[1]
...@@ -139,7 +139,7 @@ def local_cluster_engine(args): ...@@ -139,7 +139,7 @@ def local_cluster_engine(args):
cluster_envs = {} cluster_envs = {}
cluster_envs["server_num"] = 1 cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1 cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001 cluster_envs["start_port"] = envs.find_free_port()
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer.trainer"] = trainer cluster_envs["train.trainer.trainer"] = trainer
cluster_envs["train.trainer.strategy"] = "async" cluster_envs["train.trainer.strategy"] = "async"
......
文件模式从 100644 更改为 100755
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册