From 42347db7390bf3a14d80ffd01de5679c20513204 Mon Sep 17 00:00:00 2001 From: tangwei Date: Mon, 20 Apr 2020 16:42:23 +0800 Subject: [PATCH] fix bug --- fleetrec/core/factory.py | 2 +- fleetrec/core/utils/envs.py | 12 +++++++++++- fleetrec/examples/runtime.yaml | 3 ++- fleetrec/examples/user_define/__init__.py | 0 .../user_define/user_define_trainer.yaml | 2 -- .../{user_define => }/user_define_trainer.py | 0 fleetrec/run.py | 18 +++++++++++++----- 7 files changed, 27 insertions(+), 10 deletions(-) delete mode 100644 fleetrec/examples/user_define/__init__.py delete mode 100644 fleetrec/examples/user_define/user_define_trainer.yaml rename fleetrec/examples/{user_define => }/user_define_trainer.py (100%) diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 3cfd0e5a..0842a136 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -45,7 +45,7 @@ class TrainerFactory(object): trainer_abs = trainers.get(train_mode, None) if trainer_abs is None: - if not os.path.exists(train_mode) or os.path.isfile(train_mode): + if not os.path.exists(train_mode) or not os.path.isfile(train_mode): raise ValueError("trainer {} can not be recognized".format(train_mode)) trainer_abs = train_mode train_mode = "UserDefineTrainer" diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py index aed9e0a5..2aa6951e 100644 --- a/fleetrec/core/utils/envs.py +++ b/fleetrec/core/utils/envs.py @@ -21,8 +21,18 @@ global_envs = {} def set_runtime_envions(envs): assert isinstance(envs, dict) + def fatten_env_namespace(namespace_nests, local_envs): + for k, v in local_envs.items(): + if isinstance(v, dict): + nests = copy.deepcopy(namespace_nests) + nests.append(k) + fatten_env_namespace(nests, v) + else: + global_k = ".".join(namespace_nests + [k]) + os.environ[global_k] = str(v) + for k, v in envs.items(): - os.environ[k] = str(v) + fatten_env_namespace([k], v) def get_runtime_envion(key): diff --git a/fleetrec/examples/runtime.yaml b/fleetrec/examples/runtime.yaml index b639f14a..70fa27c3 100644 --- a/fleetrec/examples/runtime.yaml +++ b/fleetrec/examples/runtime.yaml @@ -3,10 +3,11 @@ trainer: threads: 4 # for cluster training + strategy: "async" communicator: - strategy: "async" send_queue_size: 4 min_send_grad_num_before_recv: 4 thread_pool_size: 5 max_merge_var_num: 4 + diff --git a/fleetrec/examples/user_define/__init__.py b/fleetrec/examples/user_define/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/fleetrec/examples/user_define/user_define_trainer.yaml b/fleetrec/examples/user_define/user_define_trainer.yaml deleted file mode 100644 index 973bada5..00000000 --- a/fleetrec/examples/user_define/user_define_trainer.yaml +++ /dev/null @@ -1,2 +0,0 @@ -trainer: "UserDefineTrainer" -location: "/root/FleetRec/fleetrec/examples/user_define_trainer.py" diff --git a/fleetrec/examples/user_define/user_define_trainer.py b/fleetrec/examples/user_define_trainer.py similarity index 100% rename from fleetrec/examples/user_define/user_define_trainer.py rename to fleetrec/examples/user_define_trainer.py diff --git a/fleetrec/run.py b/fleetrec/run.py index aaae2808..6027a678 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -21,8 +21,9 @@ def set_runtime_envs(cluster_envs, engine_yaml): if cluster_envs is None: cluster_envs = {} + cluster_envs.update(cluster_envs) cluster_envs.update(_envs) - envs.set_runtime_envions(cluster_envs) + # envs.set_runtime_envions(cluster_envs) print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value"))) @@ -40,7 +41,10 @@ def get_engine(engine): def single_engine(args): print("use single engine to run model: {}".format(args.model)) - single_envs = {"trainer.trainer": "SingleTrainer", "trainer.threads": "2"} + + single_envs = {} + single_envs["trainer.trainer"] = "SingleTrainer" + single_envs["trainer.threads"] = "2" set_runtime_envs(single_envs, args.engine_extras) trainer = TrainerFactory.create(args.model) return trainer @@ -49,7 +53,8 @@ def single_engine(args): def cluster_engine(args): print("launch cluster engine with cluster to run model: {}".format(args.model)) - cluster_envs = {"trainer.trainer": "ClusterTrainer"} + cluster_envs = {} + cluster_envs["trainer.trainer"] = "ClusterTrainer" set_runtime_envs(cluster_envs, args.engine_extras) envs.set_runtime_envions(cluster_envs) @@ -60,7 +65,8 @@ def cluster_engine(args): def cluster_mpi_engine(args): print("launch cluster engine with cluster to run model: {}".format(args.model)) - cluster_envs = {"trainer.trainer": "CtrCodingTrainer"} + cluster_envs = {} + cluster_envs["trainer.trainer"] = "CtrCodingTrainer" set_runtime_envs(cluster_envs, args.engine_extras) trainer = TrainerFactory.create(args.model) @@ -77,7 +83,9 @@ def local_cluster_engine(args): cluster_envs["start_port"] = 36001 cluster_envs["log_dir"] = "logs" cluster_envs["trainer.trainer"] = "ClusterTrainer" - cluster_envs["trainer.strategy.mode"] = "async" + cluster_envs["trainer.strategy"] = "async" + cluster_envs["trainer.threads"] = "2" + cluster_envs["CPU_NUM"] = "2" set_runtime_envs(cluster_envs, args.engine_extras) -- GitLab