diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py
index 3cfd0e5a0674ff985102aaa38d6c7840502232a1..0842a1366b3af504ddc318a1721b323fd3628bfc 100644
--- a/fleetrec/core/factory.py
+++ b/fleetrec/core/factory.py
@@ -45,7 +45,7 @@ class TrainerFactory(object):
         trainer_abs = trainers.get(train_mode, None)
 
         if trainer_abs is None:
-            if not os.path.exists(train_mode) or os.path.isfile(train_mode):
+            if not os.path.exists(train_mode) or not os.path.isfile(train_mode):
                 raise ValueError("trainer {} can not be recognized".format(train_mode))
             trainer_abs = train_mode
             train_mode = "UserDefineTrainer"
diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py
--- a/fleetrec/core/utils/envs.py
+++ b/fleetrec/core/utils/envs.py
@@ -21,8 +21,20 @@ global_envs = {}
 def set_runtime_envions(envs):
     assert isinstance(envs, dict)
 
-    for k, v in envs.items():
-        os.environ[k] = str(v)
+    def flatten_env_namespace(namespace_nests, local_envs):
+        # Export every leaf of a (possibly nested) dict into os.environ,
+        # joining the nesting path with ".", e.g. {"a": {"b": 1}} sets
+        # os.environ["a.b"] = "1".
+        for k, v in local_envs.items():
+            if isinstance(v, dict):
+                flatten_env_namespace(namespace_nests + [k], v)
+            else:
+                global_k = ".".join(namespace_nests + [k])
+                os.environ[global_k] = str(v)
+
+    # Walk from the root so flat top-level entries such as
+    # {"trainer.trainer": "ClusterTrainer"} are exported as well as nested ones.
+    flatten_env_namespace([], envs)
 
 
 def get_runtime_envion(key):
diff --git a/fleetrec/examples/runtime.yaml b/fleetrec/examples/runtime.yaml
index b639f14a28bf48605a25ea803170931459f8bdf7..70fa27c3f2cea6db19d10ffdddc2edfd7fe72983 100644
--- a/fleetrec/examples/runtime.yaml
+++ b/fleetrec/examples/runtime.yaml
@@ -3,10 +3,11 @@ trainer:
   threads: 4
 
   # for cluster training
+  strategy: "async"
   communicator:
-    strategy: "async"
     send_queue_size: 4
     min_send_grad_num_before_recv: 4
     thread_pool_size: 5
     max_merge_var_num: 4
 
+
diff --git a/fleetrec/examples/user_define/__init__.py b/fleetrec/examples/user_define/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/fleetrec/examples/user_define/user_define_trainer.yaml b/fleetrec/examples/user_define/user_define_trainer.yaml
deleted file mode 100644
index 973bada54bea6c2c0e929d0dd7182a2003c84712..0000000000000000000000000000000000000000
--- a/fleetrec/examples/user_define/user_define_trainer.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-trainer: "UserDefineTrainer"
-location: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
diff --git a/fleetrec/examples/user_define/user_define_trainer.py b/fleetrec/examples/user_define_trainer.py
similarity index 100%
rename from fleetrec/examples/user_define/user_define_trainer.py
rename to fleetrec/examples/user_define_trainer.py
diff --git a/fleetrec/run.py b/fleetrec/run.py
--- a/fleetrec/run.py
+++ b/fleetrec/run.py
@@ -21,8 +21,8 @@ def set_runtime_envs(cluster_envs, engine_yaml):
     if cluster_envs is None:
         cluster_envs = {}
 
     cluster_envs.update(_envs)
-    envs.set_runtime_envions(cluster_envs)
+    # envs.set_runtime_envions(cluster_envs)
     print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
 
 
@@ -40,7 +40,10 @@ def get_engine(engine):
 
 def single_engine(args):
     print("use single engine to run model: {}".format(args.model))
-    single_envs = {"trainer.trainer": "SingleTrainer", "trainer.threads": "2"}
+
+    single_envs = {}
+    single_envs["trainer.trainer"] = "SingleTrainer"
+    single_envs["trainer.threads"] = "2"
     set_runtime_envs(single_envs, args.engine_extras)
     trainer = TrainerFactory.create(args.model)
     return trainer
@@ -49,7 +52,8 @@ def single_engine(args):
 def cluster_engine(args):
     print("launch cluster engine with cluster to run model: {}".format(args.model))
 
-    cluster_envs = {"trainer.trainer": "ClusterTrainer"}
+    cluster_envs = {}
+    cluster_envs["trainer.trainer"] = "ClusterTrainer"
     set_runtime_envs(cluster_envs, args.engine_extras)
     envs.set_runtime_envions(cluster_envs)
 
@@ -60,7 +64,8 @@ def cluster_engine(args):
 def cluster_mpi_engine(args):
     print("launch cluster engine with cluster to run model: {}".format(args.model))
 
-    cluster_envs = {"trainer.trainer": "CtrCodingTrainer"}
+    cluster_envs = {}
+    cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
     set_runtime_envs(cluster_envs, args.engine_extras)
 
     trainer = TrainerFactory.create(args.model)
@@ -77,7 +82,9 @@ def local_cluster_engine(args):
     cluster_envs["start_port"] = 36001
     cluster_envs["log_dir"] = "logs"
     cluster_envs["trainer.trainer"] = "ClusterTrainer"
-    cluster_envs["trainer.strategy.mode"] = "async"
+    cluster_envs["trainer.strategy"] = "async"
+    cluster_envs["trainer.threads"] = "2"
+    cluster_envs["CPU_NUM"] = "2"
 
     set_runtime_envs(cluster_envs, args.engine_extras)
 