Commit 9368aec8 authored by tangwei12

fix bug, cluster training

Parent 2ebef2b7
@@ -27,7 +27,8 @@
 train:
   threads: 12
   epochs: 10
-  trainer: "LocalClusterTraining"
+  trainer: "ClusterTraining"
+  container: "local"
   pserver_num: 2
   trainer_num: 2
...
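Note: a minimal sketch of consuming the two keys this hunk touches, assuming the file above is a plain PyYAML document (the loader call mirrors the one factory.py uses below):

    import yaml  # PyYAML

    with open("ctr-dnn_train_cluster.yaml") as f:
        conf = yaml.load(f.read(), Loader=yaml.FullLoader)

    assert conf["train"]["trainer"] == "ClusterTraining"
    assert conf["train"]["container"] == "local"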
@@ -26,15 +26,11 @@
 import os
-import yaml
 from eleps.trainer.factory import TrainerFactory
 if __name__ == "__main__":
     abs_dir = os.path.dirname(os.path.abspath(__file__))
-    with open(os.path.join(abs_dir, 'ctr-dnn_train_single.yaml'), 'r') as rb:
-        global_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
-    trainer = TrainerFactory.create(global_config)
+    yaml = os.path.join(abs_dir, 'ctr-dnn_train_cluster.yaml')
+    trainer = TrainerFactory.create(yaml)
     trainer.run()
@@ -42,10 +42,9 @@ class ClusterTrainerWithDataloader(TranspileTrainer):
 class ClusterTrainerWithDataset(TranspileTrainer):
     def processor_register(self):
         role = PaddleCloudRoleMaker()
         fleet.init(role)
-        if role.is_server():
+        if fleet.is_server():
             self.regist_context_processor('uninit', self.instance)
             self.regist_context_processor('init_pass', self.init)
             self.regist_context_processor('server_pass', self.server)
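Note: after fleet.init(role), role queries are expected to go through the fleet singleton rather than the raw RoleMaker; the fix above applies that. A minimal sketch of the dispatch pattern, assuming Paddle 1.x Fleet import paths and the PADDLE_*/TRAINING_ROLE environment variables that local_engine.py exports:

    from paddle.fluid.incubate.fleet.base.role_maker import PaddleCloudRoleMaker
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

    role = PaddleCloudRoleMaker()
    fleet.init(role)
    if fleet.is_server():
        pass  # register server passes: instance -> init -> server
    elif fleet.is_worker():
        pass  # register trainer passes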
@@ -74,6 +73,9 @@ class ClusterTrainerWithDataset(TranspileTrainer):
         return strategy
     def init(self, context):
+        print("init pass")
         self.model.input()
         self.model.net()
         self.metrics = self.model.metrics()
...
@@ -35,7 +35,7 @@ from eleps.trainer.single_trainer import SingleTrainerWithDataset
 from eleps.trainer.cluster_trainer import ClusterTrainerWithDataloader
 from eleps.trainer.cluster_trainer import ClusterTrainerWithDataset
-from eleps.trainer.local_engine import local_launch
+from eleps.trainer.local_engine import Launch
 from eleps.trainer.ctr_trainer import CtrPaddleTrainer
 from eleps.utils import envs
@@ -91,24 +91,26 @@ class TrainerFactory(object):
         cluster_envs["start_port"] = envs.get_global_env("train.start_port")
         cluster_envs["log_dir"] = envs.get_global_env("train.log_dirname")
-        envs.pretty_print_envs(cluster_envs, ("Cluster Global Envs", "Value"))
-        local_launch(cluster_envs, yaml_config)
+        print(envs.pretty_print_envs(cluster_envs, ("Cluster Global Envs", "Value")))
+        launch = Launch(cluster_envs, yaml_config)
+        return launch
     @staticmethod
     def create(config):
         _config = None
         if os.path.exists(config) and os.path.isfile(config):
             with open(config, 'r') as rb:
-                _config = yaml.load(rb.read())
+                _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
         else:
             raise ValueError("eleps's config only support yaml")
         envs.set_global_envs(_config)
-        train_mode = envs.get_global_env("train.trainer")
+        mode = envs.get_global_env("train.trainer")
+        container = envs.get_global_env("train.container")
         instance = str2bool(os.getenv("CLUSTER_INSTANCE", "0"))
-        if train_mode == "LocalClusterTraining" and not instance:
+        if mode == "ClusterTraining" and container == "local" and not instance:
             trainer = TrainerFactory._build_engine(config)
         else:
             trainer = TrainerFactory._build_trainer(_config)
@@ -120,4 +122,6 @@ class TrainerFactory(object):
 if __name__ == "__main__":
     if len(sys.argv) != 2:
         raise ValueError("need a yaml file path argv")
-    TrainerFactory.create(sys.argv[1])
+    trainer = TrainerFactory.create(sys.argv[1])
+    trainer.run()
@@ -38,7 +38,8 @@ def start_procs(args, yaml):
     user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
     user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
-    factory = os.path.join(os.path.abspath(os.path.dirname(__file__)), "factory.py")
+    factory = "eleps.trainer.factory"
+    cmd = [sys.executable, "-u", "-m", factory, yaml]
     for i in range(server_num):
         current_env.update({
@@ -48,15 +49,14 @@ def start_procs(args, yaml):
             "PADDLE_TRAINERS_NUM": str(worker_num),
             "POD_IP": user_endpoints_ips[i]
         })
-        cmd = [sys.executable, "-u", factory, yaml]
-        if args.log_dir is not None:
+        if logs_dir is not None:
             os.system("mkdir -p {}".format(logs_dir))
             fn = open("%s/server.%d" % (logs_dir, i), "w")
             log_fns.append(fn)
-            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
+            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
         else:
-            proc = subprocess.Popen(cmd, env=current_env)
+            proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
         procs.append(proc)
     for i in range(worker_num):
@@ -66,16 +66,14 @@ def start_procs(args, yaml):
             "TRAINING_ROLE": "TRAINER",
             "PADDLE_TRAINER_ID": str(i)
         })
-        cmd = [sys.executable, "-u", args.training_script
-              ] + args.training_script_args
-        if args.log_dir is not None:
+        if logs_dir is not None:
             os.system("mkdir -p {}".format(logs_dir))
             fn = open("%s/worker.%d" % (logs_dir, i), "w")
             log_fns.append(fn)
-            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
+            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
         else:
-            proc = subprocess.Popen(cmd, env=current_env)
+            proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
         procs.append(proc)
 # only wait worker to finish here
@@ -93,6 +91,11 @@ def start_procs(args, yaml):
         procs[i].terminate()
     print("all parameter server are killed", file=sys.stderr)
-def local_launch(envs, trainer):
-    start_procs(envs, trainer)
+class Launch():
+    def __init__(self, envs, trainer):
+        self.envs = envs
+        self.trainer = trainer
+
+    def run(self):
+        start_procs(self.envs, self.trainer)
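Note: two design points in this file. Running the children as python -u -m eleps.trainer.factory with cwd=os.getcwd() resolves the package from the working directory instead of a hard-coded factory.py file path, and building cmd once before both loops replaces the per-loop cmd assignments removed above. Wrapping start_procs in Launch gives the engine the same run() interface as the trainers, so TrainerFactory.create() can return either. A minimal, self-contained sketch of that shared interface (toy bodies; the names mirror the diff):

    class Launch():
        def __init__(self, envs, trainer):
            self.envs = envs
            self.trainer = trainer

        def run(self):
            print("would spawn pserver/trainer processes with", self.envs)

    class ClusterTrainerWithDataset():
        def run(self):
            print("would run one training instance")

    for job in (Launch({"start_port": 36001}, "ctr-dnn_train_cluster.yaml"),
                ClusterTrainerWithDataset()):
        job.run()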
@@ -48,7 +48,7 @@ def get_global_envs():
     return global_envs
-def pretty_print_envs(envs, header):
+def pretty_print_envs(envs, header=None):
     spacing = 5
     max_k = 45
     max_v = 20
...
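Note: factory.py now prints the return value of pretty_print_envs, so the function is expected to build and return a string, with header=None making the banner optional. The real body is elided above, so this is only an assumed sketch of that contract using the widths shown:

    def pretty_print_envs(envs, header=None):
        spacing, max_k, max_v = 5, 45, 20
        fmt = "{{:<{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
        out = ""
        if header:
            out += fmt.format(header[0], header[1])
        for k, v in envs.items():
            out += fmt.format(str(k)[:max_k], str(v)[:max_v])
        return out

    print(pretty_print_envs({"start_port": "36001"}, ("Cluster Global Envs", "Value")))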