提交 22a5088f 编写于 作者: C chengmo

fix

上级 a96118ec
...@@ -42,15 +42,15 @@ class TDMSingleTrainer(SingleTrainer): ...@@ -42,15 +42,15 @@ class TDMSingleTrainer(SingleTrainer):
"single.persistables_model_path", "", namespace) "single.persistables_model_path", "", namespace)
load_tree = envs.get_global_env( load_tree = envs.get_global_env(
"single.load_tree", False, namespace) "tree.load_tree", False, namespace)
self.tree_layer_path = envs.get_global_env( self.tree_layer_path = envs.get_global_env(
"single.tree_layer_path", "", namespace) "tree.tree_layer_path", "", namespace)
self.tree_travel_path = envs.get_global_env( self.tree_travel_path = envs.get_global_env(
"single.tree_travel_path", "", namespace) "tree.tree_travel_path", "", namespace)
self.tree_info_path = envs.get_global_env( self.tree_info_path = envs.get_global_env(
"single.tree_info_path", "", namespace) "tree.tree_info_path", "", namespace)
self.tree_emb_path = envs.get_global_env( self.tree_emb_path = envs.get_global_env(
"single.tree_emb_path", "", namespace) "tree.tree_emb_path", "", namespace)
save_init_model = envs.get_global_env( save_init_model = envs.get_global_env(
"single.save_init_model", False, namespace) "single.save_init_model", False, namespace)
......
...@@ -11,6 +11,7 @@ engines = {} ...@@ -11,6 +11,7 @@ engines = {}
device = ["CPU", "GPU"] device = ["CPU", "GPU"]
clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"] clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
custom_model = ['tdm'] custom_model = ['tdm']
model_name = ""
def engine_registry(): def engine_registry():
...@@ -34,7 +35,7 @@ def get_engine(args): ...@@ -34,7 +35,7 @@ def get_engine(args):
d_engine = engines[device] d_engine = engines[device]
transpiler = get_transpiler() transpiler = get_transpiler()
engine = get_custom_model_engine(args) engine = args.engine
run_engine = d_engine[transpiler].get(engine, None) run_engine = d_engine[transpiler].get(engine, None)
if run_engine is None: if run_engine is None:
...@@ -43,16 +44,6 @@ def get_engine(args): ...@@ -43,16 +44,6 @@ def get_engine(args):
return run_engine return run_engine
def get_custom_model_engine(args):
model = args.model
model_name = model.split('.')[1]
if model_name in custom_model:
engine = "_".join((model_name.upper(), args.engine))
else:
engine = args.engine
return engine
def get_transpiler(): def get_transpiler():
FNULL = open(os.devnull, 'w') FNULL = open(os.devnull, 'w')
cmd = ["python", "-c", cmd = ["python", "-c",
...@@ -93,8 +84,6 @@ def set_runtime_envs(cluster_envs, engine_yaml): ...@@ -93,8 +84,6 @@ def set_runtime_envs(cluster_envs, engine_yaml):
def get_trainer_prefix(args): def get_trainer_prefix(args):
model = args.model
model_name = model.split('.')[1]
if model_name in custom_model: if model_name in custom_model:
return model_name.upper() return model_name.upper()
return "" return ""
...@@ -218,6 +207,7 @@ if __name__ == "__main__": ...@@ -218,6 +207,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
args.engine = args.engine.upper() args.engine = args.engine.upper()
args.device = args.device.upper() args.device = args.device.upper()
model_name = args.model.split('.')[-1]
args.model = get_abs_model(args.model) args.model = get_abs_model(args.model)
engine_registry() engine_registry()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册