From 9368aec87fb31d8a86c48be00ed29a1faab7a5f5 Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Thu, 9 Apr 2020 07:36:02 +0000
Subject: [PATCH] fix bug, cluster training

---
 examples/ctr-dnn_train_cluster.yaml |  3 ++-
 examples/train.py                   |  8 ++------
 trainer/cluster_trainer.py          |  6 ++++--
 trainer/factory.py                  | 18 +++++++++++-------
 trainer/local_engine.py             | 27 +++++++++++++++------------
 utils/envs.py                       |  2 +-
 6 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/examples/ctr-dnn_train_cluster.yaml b/examples/ctr-dnn_train_cluster.yaml
index 281a0a77..7239e582 100644
--- a/examples/ctr-dnn_train_cluster.yaml
+++ b/examples/ctr-dnn_train_cluster.yaml
@@ -27,7 +27,8 @@ train:
   threads: 12
   epochs: 10
-  trainer: "LocalClusterTraining"
+  trainer: "ClusterTraining"
+  container: "local"

   pserver_num: 2
   trainer_num: 2

diff --git a/examples/train.py b/examples/train.py
index 48d65d4c..ac7fdd30 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -26,15 +26,11 @@

 import os
-import yaml

 from eleps.trainer.factory import TrainerFactory


 if __name__ == "__main__":
     abs_dir = os.path.dirname(os.path.abspath(__file__))
-
-    with open(os.path.join(abs_dir, 'ctr-dnn_train_single.yaml'), 'r') as rb:
-        global_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
-
-    trainer = TrainerFactory.create(global_config)
+    yaml = os.path.join(abs_dir, 'ctr-dnn_train_cluster.yaml')
+    trainer = TrainerFactory.create(yaml)
     trainer.run()

diff --git a/trainer/cluster_trainer.py b/trainer/cluster_trainer.py
index b8f8ded3..65f9efa3 100644
--- a/trainer/cluster_trainer.py
+++ b/trainer/cluster_trainer.py
@@ -42,10 +42,9 @@ class ClusterTrainerWithDataloader(TranspileTrainer):

 class ClusterTrainerWithDataset(TranspileTrainer):
     def processor_register(self):
         role = PaddleCloudRoleMaker()
-
         fleet.init(role)
-        if role.is_server():
+        if fleet.is_server():
             self.regist_context_processor('uninit', self.instance)
             self.regist_context_processor('init_pass', self.init)
             self.regist_context_processor('server_pass', self.server)
@@ -74,6 +73,9 @@ class ClusterTrainerWithDataset(TranspileTrainer):
         return strategy

     def init(self, context):
+
+        print("init pass")
+
         self.model.input()
         self.model.net()
         self.metrics = self.model.metrics()

diff --git a/trainer/factory.py b/trainer/factory.py
index b7ba81cb..dc590ce5 100644
--- a/trainer/factory.py
+++ b/trainer/factory.py
@@ -35,7 +35,7 @@ from eleps.trainer.single_trainer import SingleTrainerWithDataset
 from eleps.trainer.cluster_trainer import ClusterTrainerWithDataloader
 from eleps.trainer.cluster_trainer import ClusterTrainerWithDataset

-from eleps.trainer.local_engine import local_launch
+from eleps.trainer.local_engine import Launch
 from eleps.trainer.ctr_trainer import CtrPaddleTrainer

 from eleps.utils import envs
@@ -91,24 +91,26 @@ class TrainerFactory(object):
         cluster_envs["start_port"] = envs.get_global_env("train.start_port")
         cluster_envs["log_dir"] = envs.get_global_env("train.log_dirname")

-        envs.pretty_print_envs(cluster_envs, ("Cluster Global Envs", "Value"))
+        print(envs.pretty_print_envs(cluster_envs, ("Cluster Global Envs", "Value")))

-        local_launch(cluster_envs, yaml_config)
+        launch = Launch(cluster_envs, yaml_config)
+        return launch

     @staticmethod
     def create(config):
         _config = None
         if os.path.exists(config) and os.path.isfile(config):
             with open(config, 'r') as rb:
-                _config = yaml.load(rb.read())
+                _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
         else:
             raise ValueError("eleps's config only support yaml")

         envs.set_global_envs(_config)

-        train_mode = envs.get_global_env("train.trainer")
+        mode = envs.get_global_env("train.trainer")
+        container = envs.get_global_env("train.container")
         instance = str2bool(os.getenv("CLUSTER_INSTANCE", "0"))
-        if train_mode == "LocalClusterTraining" and not instance:
+        if mode == "ClusterTraining" and container == "local" and not instance:
             trainer = TrainerFactory._build_engine(config)
         else:
             trainer = TrainerFactory._build_trainer(_config)
@@ -120,4 +122,6 @@ class TrainerFactory(object):
 if __name__ == "__main__":
     if len(sys.argv) != 2:
         raise ValueError("need a yaml file path argv")
-    TrainerFactory.create(sys.argv[1])
+    trainer = TrainerFactory.create(sys.argv[1])
+    trainer.run()
+

diff --git a/trainer/local_engine.py b/trainer/local_engine.py
index 0e171a5b..07be9ac4 100644
--- a/trainer/local_engine.py
+++ b/trainer/local_engine.py
@@ -38,7 +38,8 @@ def start_procs(args, yaml):
     user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
     user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]

-    factory = os.path.join(os.path.abspath(os.path.dirname(__file__)), "factory.py")
+    factory = "eleps.trainer.factory"
+    cmd = [sys.executable, "-u", "-m", factory, yaml]

     for i in range(server_num):
         current_env.update({
@@ -48,15 +49,14 @@
             "PADDLE_TRAINERS_NUM": str(worker_num),
             "POD_IP": user_endpoints_ips[i]
         })
-        cmd = [sys.executable, "-u", factory, yaml]
-        if args.log_dir is not None:
+        if logs_dir is not None:
             os.system("mkdir -p {}".format(logs_dir))
             fn = open("%s/server.%d" % (logs_dir, i), "w")
             log_fns.append(fn)
-            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
+            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
         else:
-            proc = subprocess.Popen(cmd, env=current_env)
+            proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
         procs.append(proc)

     for i in range(worker_num):
         current_env.update({
@@ -66,16 +66,14 @@
             "TRAINING_ROLE": "TRAINER",
             "PADDLE_TRAINER_ID": str(i)
         })
-        cmd = [sys.executable, "-u", args.training_script
-               ] + args.training_script_args

-        if args.log_dir is not None:
+        if logs_dir is not None:
             os.system("mkdir -p {}".format(logs_dir))
             fn = open("%s/worker.%d" % (logs_dir, i), "w")
             log_fns.append(fn)
-            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
+            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
         else:
-            proc = subprocess.Popen(cmd, env=current_env)
+            proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
         procs.append(proc)

     # only wait worker to finish here
...
         procs[i].terminate()
     print("all parameter server are killed", file=sys.stderr)

+class Launch():
+    def __init__(self, envs, trainer):
+        self.envs = envs
+        self.trainer = trainer
+
+    def run(self):
+        start_procs(self.envs, self.trainer)

-def local_launch(envs, trainer):
-    start_procs(envs, trainer)

diff --git a/utils/envs.py b/utils/envs.py
index 66fc9820..dce663b1 100644
--- a/utils/envs.py
+++ b/utils/envs.py
@@ -48,7 +48,7 @@ def get_global_envs():
     return global_envs


-def pretty_print_envs(envs, header):
+def pretty_print_envs(envs, header=None):
     spacing = 5
     max_k = 45
     max_v = 20
--
GitLab
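
For reference, a minimal sketch of the control flow this patch wires up.
It is illustrative only: it assumes the eleps package is importable and that
a ctr-dnn_train_cluster.yaml like the one above sits next to the script; the
yaml_path name is made up for the example.

    import os

    from eleps.trainer.factory import TrainerFactory

    if __name__ == "__main__":
        abs_dir = os.path.dirname(os.path.abspath(__file__))
        yaml_path = os.path.join(abs_dir, "ctr-dnn_train_cluster.yaml")

        # With trainer: "ClusterTraining" and container: "local" in the YAML,
        # and CLUSTER_INSTANCE unset in the environment, create() takes the
        # _build_engine() branch and returns a local_engine.Launch object
        # rather than a trainer.
        runner = TrainerFactory.create(yaml_path)

        # Launch.run() calls start_procs(), which spawns one subprocess per
        # pserver and per trainer via "python -u -m eleps.trainer.factory
        # <yaml>". Those subprocesses are expected to see CLUSTER_INSTANCE
        # set, so the same create() call builds a ClusterTrainerWithDataset
        # in each of them and the fleet roles take over.
        runner.run()

Because create() now returns either a trainer or a Launch object, and both
expose run(), the two-line entry point in examples/train.py drives both the
single-process and the local-cluster paths.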