From 22a5088ffb0441c9bbf6e00b8e0a443de772b393 Mon Sep 17 00:00:00 2001 From: chengmo Date: Wed, 6 May 2020 23:52:11 +0800 Subject: [PATCH] fix --- fleet_rec/core/trainers/tdm_single_trainer.py | 10 +++++----- fleet_rec/run.py | 16 +++------------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/fleet_rec/core/trainers/tdm_single_trainer.py b/fleet_rec/core/trainers/tdm_single_trainer.py index 8258db60..6d010999 100644 --- a/fleet_rec/core/trainers/tdm_single_trainer.py +++ b/fleet_rec/core/trainers/tdm_single_trainer.py @@ -42,15 +42,15 @@ class TDMSingleTrainer(SingleTrainer): "single.persistables_model_path", "", namespace) load_tree = envs.get_global_env( - "single.load_tree", False, namespace) + "tree.load_tree", False, namespace) self.tree_layer_path = envs.get_global_env( - "single.tree_layer_path", "", namespace) + "tree.tree_layer_path", "", namespace) self.tree_travel_path = envs.get_global_env( - "single.tree_travel_path", "", namespace) + "tree.tree_travel_path", "", namespace) self.tree_info_path = envs.get_global_env( - "single.tree_info_path", "", namespace) + "tree.tree_info_path", "", namespace) self.tree_emb_path = envs.get_global_env( - "single.tree_emb_path", "", namespace) + "tree.tree_emb_path", "", namespace) save_init_model = envs.get_global_env( "single.save_init_model", False, namespace) diff --git a/fleet_rec/run.py b/fleet_rec/run.py index b10428c8..e4727d59 100644 --- a/fleet_rec/run.py +++ b/fleet_rec/run.py @@ -11,6 +11,7 @@ engines = {} device = ["CPU", "GPU"] clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"] custom_model = ['tdm'] +model_name = "" def engine_registry(): @@ -34,7 +35,7 @@ def get_engine(args): d_engine = engines[device] transpiler = get_transpiler() - engine = get_custom_model_engine(args) + engine = args.engine run_engine = d_engine[transpiler].get(engine, None) if run_engine is None: @@ -43,16 +44,6 @@ def get_engine(args): return run_engine -def get_custom_model_engine(args): - model = args.model - model_name = model.split('.')[1] - if model_name in custom_model: - engine = "_".join((model_name.upper(), args.engine)) - else: - engine = args.engine - return engine - - def get_transpiler(): FNULL = open(os.devnull, 'w') cmd = ["python", "-c", @@ -93,8 +84,6 @@ def set_runtime_envs(cluster_envs, engine_yaml): def get_trainer_prefix(args): - model = args.model - model_name = model.split('.')[1] if model_name in custom_model: return model_name.upper() return "" @@ -218,6 +207,7 @@ if __name__ == "__main__": args = parser.parse_args() args.engine = args.engine.upper() args.device = args.device.upper() + model_name = args.model.split('.')[-1] args.model = get_abs_model(args.model) engine_registry() -- GitLab