提交 865c9a9d 编写于 作者: C chengmo

fix startup

上级 6b8c051f
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
...@@ -20,13 +20,13 @@ import os ...@@ -20,13 +20,13 @@ import os
import copy import copy
from fleetrec.core.engine.engine import Engine from fleetrec.core.engine.engine import Engine
from fleetrec.core.utils import envs
class LocalClusterEngine(Engine): class LocalClusterEngine(Engine):
def start_procs(self): def start_procs(self):
worker_num = self.envs["worker_num"] worker_num = self.envs["worker_num"]
server_num = self.envs["server_num"] server_num = self.envs["server_num"]
start_port = self.envs["start_port"] ports = [self.envs["start_port"]]
logs_dir = self.envs["log_dir"] logs_dir = self.envs["log_dir"]
default_env = os.environ.copy() default_env = os.environ.copy()
...@@ -36,7 +36,13 @@ class LocalClusterEngine(Engine): ...@@ -36,7 +36,13 @@ class LocalClusterEngine(Engine):
current_env.pop("https_proxy", None) current_env.pop("https_proxy", None)
procs = [] procs = []
log_fns = [] log_fns = []
ports = range(start_port, start_port + server_num, 1)
for i in range(server_num - 1):
while True:
new_port = envs.find_free_port()
if new_port not in ports:
ports.append(new_port)
break
user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports]) user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")] user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")] user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
...@@ -40,7 +40,7 @@ class ClusterTrainer(TranspileTrainer): ...@@ -40,7 +40,7 @@ class ClusterTrainer(TranspileTrainer):
else: else:
self.regist_context_processor('uninit', self.instance) self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init) self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('startup_pass', self.startup)
if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader": if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader":
self.regist_context_processor('train_pass', self.dataset_train) self.regist_context_processor('train_pass', self.dataset_train)
else: else:
......
...@@ -33,7 +33,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -33,7 +33,7 @@ class SingleTrainer(TranspileTrainer):
def processor_register(self): def processor_register(self):
self.regist_context_processor('uninit', self.instance) self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init) self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('startup_pass', self.startup)
if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader": if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader":
self.regist_context_processor('train_pass', self.dataset_train) self.regist_context_processor('train_pass', self.dataset_train)
else: else:
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
...@@ -23,7 +23,6 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import f ...@@ -23,7 +23,6 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import f
from fleetrec.core.trainer import Trainer from fleetrec.core.trainer import Trainer
from fleetrec.core.utils import envs from fleetrec.core.utils import envs
from fleetrec.core.utils import dataloader_instance from fleetrec.core.utils import dataloader_instance
import fleetrec.core.din_reader as din_reader
class TranspileTrainer(Trainer): class TranspileTrainer(Trainer):
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
import os import os
import copy import copy
import sys import sys
import socket
from contextlib import closing
global_envs = {} global_envs = {}
...@@ -170,3 +171,12 @@ def get_platform(): ...@@ -170,3 +171,12 @@ def get_platform():
return "DARWIN" return "DARWIN"
if 'Windows' in plats: if 'Windows' in plats:
return "WINDOWS" return "WINDOWS"
def find_free_port():
    """Return a TCP port number that was unused at the moment of the call.

    A temporary IPv4 TCP socket is bound to port 0, the port number the
    operating system assigned to it is read back, and the socket is closed.

    Returns:
        int: the port number reported by ``getsockname()`` for the
        temporary socket.

    Note:
        The socket is closed before this function returns, so the port is
        not reserved; another process may claim it before the caller binds
        to it.
    """
    # closing() releases the socket even if bind() or getsockname() raises.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        # Binding to port 0 lets the OS choose a currently unused port.
        sock.bind(('', 0))
        return sock.getsockname()[1]
...@@ -139,7 +139,7 @@ def local_cluster_engine(args): ...@@ -139,7 +139,7 @@ def local_cluster_engine(args):
cluster_envs = {} cluster_envs = {}
cluster_envs["server_num"] = 1 cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1 cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001 cluster_envs["start_port"] = envs.find_free_port()
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer.trainer"] = trainer cluster_envs["train.trainer.trainer"] = trainer
cluster_envs["train.trainer.strategy"] = "async" cluster_envs["train.trainer.strategy"] = "async"
......
文件模式从 100644 更改为 100755
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册