提交 865c9a9d 编写于 作者: C chengmo

fix startup

上级 6b8c051f
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
......@@ -20,13 +20,13 @@ import os
import copy
from fleetrec.core.engine.engine import Engine
from fleetrec.core.utils import envs
class LocalClusterEngine(Engine):
def start_procs(self):
worker_num = self.envs["worker_num"]
server_num = self.envs["server_num"]
start_port = self.envs["start_port"]
ports = [self.envs["start_port"]]
logs_dir = self.envs["log_dir"]
default_env = os.environ.copy()
......@@ -36,7 +36,13 @@ class LocalClusterEngine(Engine):
current_env.pop("https_proxy", None)
procs = []
log_fns = []
ports = range(start_port, start_port + server_num, 1)
for i in range(server_num - 1):
while True:
new_port = envs.find_free_port()
if new_port not in ports:
ports.append(new_port)
break
user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
......@@ -40,7 +40,7 @@ class ClusterTrainer(TranspileTrainer):
else:
self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('startup_pass', self.startup)
if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader":
self.regist_context_processor('train_pass', self.dataset_train)
else:
......
......@@ -33,7 +33,7 @@ class SingleTrainer(TranspileTrainer):
def processor_register(self):
self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('startup_pass', self.startup)
if envs.get_platform() == "LINUX" and envs.get_global_env("dataset_class", None, "train.reader") != "DataLoader":
self.regist_context_processor('train_pass', self.dataset_train)
else:
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
......@@ -23,7 +23,6 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import f
from fleetrec.core.trainer import Trainer
from fleetrec.core.utils import envs
from fleetrec.core.utils import dataloader_instance
import fleetrec.core.din_reader as din_reader
class TranspileTrainer(Trainer):
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
......@@ -15,7 +15,8 @@
import os
import copy
import sys
import socket
from contextlib import closing
global_envs = {}
......@@ -170,3 +171,12 @@ def get_platform():
return "DARWIN"
if 'Windows' in plats:
return "WINDOWS"
def find_free_port():
    """Return a TCP port number that was unused at the time of the call.

    A temporary IPv4 stream socket is bound to port 0, which asks the
    operating system to pick an available port. The chosen port number is
    read back and the socket is closed before returning.

    Returns:
        int: the port number the OS assigned to the temporary socket.

    Note:
        The probe socket is closed before this function returns, so the
        port is only known to have been free at probe time; another
        process may claim it before the caller binds to it.
    """
    # closing() releases the probe socket even if bind() raises.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Port 0 lets the OS choose any currently available port.
        s.bind(('', 0))
        return s.getsockname()[1]
......@@ -139,7 +139,7 @@ def local_cluster_engine(args):
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["start_port"] = envs.find_free_port()
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer.trainer"] = trainer
cluster_envs["train.trainer.strategy"] = "async"
......
文件模式从 100644 更改为 100755
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册