提交 fcbb1cc0 编写于 作者: C chengmo

for merge

上级 cd7cb08a
...@@ -20,13 +20,14 @@ import os ...@@ -20,13 +20,14 @@ import os
import copy import copy
from fleetrec.core.engine.engine import Engine from fleetrec.core.engine.engine import Engine
from fleetrec.core.utils import envs
class LocalClusterEngine(Engine): class LocalClusterEngine(Engine):
def start_procs(self): def start_procs(self):
worker_num = self.envs["worker_num"] worker_num = self.envs["worker_num"]
server_num = self.envs["server_num"] server_num = self.envs["server_num"]
start_port = self.envs["start_port"] ports = [self.envs["start_port"]]
logs_dir = self.envs["log_dir"] logs_dir = self.envs["log_dir"]
default_env = os.environ.copy() default_env = os.environ.copy()
...@@ -36,10 +37,19 @@ class LocalClusterEngine(Engine): ...@@ -36,10 +37,19 @@ class LocalClusterEngine(Engine):
current_env.pop("https_proxy", None) current_env.pop("https_proxy", None)
procs = [] procs = []
log_fns = [] log_fns = []
ports = range(start_port, start_port + server_num, 1)
for i in range(server_num - 1):
while True:
new_port = envs.find_free_port()
if new_port not in ports:
ports.append(new_port)
break
user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports]) user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")] user_endpoints_ips = [x.split(":")[0]
user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")] for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1]
for x in user_endpoints.split(",")]
factory = "fleetrec.core.factory" factory = "fleetrec.core.factory"
cmd = [sys.executable, "-u", "-m", factory, self.trainer] cmd = [sys.executable, "-u", "-m", factory, self.trainer]
...@@ -56,7 +66,8 @@ class LocalClusterEngine(Engine): ...@@ -56,7 +66,8 @@ class LocalClusterEngine(Engine):
os.system("mkdir -p {}".format(logs_dir)) os.system("mkdir -p {}".format(logs_dir))
fn = open("%s/server.%d" % (logs_dir, i), "w") fn = open("%s/server.%d" % (logs_dir, i), "w")
log_fns.append(fn) log_fns.append(fn)
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd()) proc = subprocess.Popen(
cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
procs.append(proc) procs.append(proc)
for i in range(worker_num): for i in range(worker_num):
...@@ -70,7 +81,8 @@ class LocalClusterEngine(Engine): ...@@ -70,7 +81,8 @@ class LocalClusterEngine(Engine):
os.system("mkdir -p {}".format(logs_dir)) os.system("mkdir -p {}".format(logs_dir))
fn = open("%s/worker.%d" % (logs_dir, i), "w") fn = open("%s/worker.%d" % (logs_dir, i), "w")
log_fns.append(fn) log_fns.append(fn)
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd()) proc = subprocess.Popen(
cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
procs.append(proc) procs.append(proc)
# only wait worker to finish here # only wait worker to finish here
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import os import os
import copy import copy
import sys import sys
import socket
global_envs = {} global_envs = {}
...@@ -78,7 +78,8 @@ def get_global_env(env_name, default_value=None, namespace=None): ...@@ -78,7 +78,8 @@ def get_global_env(env_name, default_value=None, namespace=None):
""" """
get os environment value get os environment value
""" """
_env_name = env_name if namespace is None else ".".join([namespace, env_name]) _env_name = env_name if namespace is None else ".".join(
[namespace, env_name])
return global_envs.get(_env_name, default_value) return global_envs.get(_env_name, default_value)
...@@ -146,7 +147,8 @@ def pretty_print_envs(envs, header=None): ...@@ -146,7 +147,8 @@ def pretty_print_envs(envs, header=None):
def lazy_instance_by_package(package, class_name): def lazy_instance_by_package(package, class_name):
models = get_global_env("train.model.models") models = get_global_env("train.model.models")
model_package = __import__(package, globals(), locals(), package.split(".")) model_package = __import__(
package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name) instance = getattr(model_package, class_name)
return instance return instance
...@@ -156,7 +158,8 @@ def lazy_instance_by_fliename(abs, class_name): ...@@ -156,7 +158,8 @@ def lazy_instance_by_fliename(abs, class_name):
sys.path.append(dirname) sys.path.append(dirname)
package = os.path.splitext(os.path.basename(abs))[0] package = os.path.splitext(os.path.basename(abs))[0]
model_package = __import__(package, globals(), locals(), package.split(".")) model_package = __import__(
package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name) instance = getattr(model_package, class_name)
return instance return instance
...@@ -170,3 +173,13 @@ def get_platform(): ...@@ -170,3 +173,13 @@ def get_platform():
return "DARWIN" return "DARWIN"
if 'Windows' in plats: if 'Windows' in plats:
return "WINDOWS" return "WINDOWS"
def find_free_port():
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
new_port = __free_port()
return new_port
...@@ -97,6 +97,28 @@ python -m fleetrec.run -m fleetrec.models.rank.dnn -d cpu -e local_cluster ...@@ -97,6 +97,28 @@ python -m fleetrec.run -m fleetrec.models.rank.dnn -d cpu -e local_cluster
python -m fleetrec.run -m fleetrec.models.rank.dnn -d cpu -e cluster python -m fleetrec.run -m fleetrec.models.rank.dnn -d cpu -e cluster
``` ```
<h2 align="center">支持模型列表</h2>
| 方向 | 模型 | 单机CPU训练 | 单机GPU训练 | 分布式CPU训练 | 大规模稀疏 | 分布式GPU训练 | 自定义数据集 |
| :------: | :--------------------: | :---------: | :---------: | :-----------: | :--------: | :-----------: | :----------: |
| 内容理解 | [Text-Classifcation]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 内容理解 | [TagSpace]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [Word2Vec]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [TDM]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [SSR]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 召回 | [Gru4Rec]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 排序 | [CTR-Dnn]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [DeepFm]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [ListWise]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [DSSM]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [Multiview-Simnet]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 融合 | [MMOE]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 融合 | [ESMM]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 融合 | [ESMM]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
<h2 align="center">文档</h2> <h2 align="center">文档</h2>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册