提交 22a5088f 编写于 作者: C chengmo

fix

上级 a96118ec
......@@ -42,15 +42,15 @@ class TDMSingleTrainer(SingleTrainer):
"single.persistables_model_path", "", namespace)
load_tree = envs.get_global_env(
"single.load_tree", False, namespace)
"tree.load_tree", False, namespace)
self.tree_layer_path = envs.get_global_env(
"single.tree_layer_path", "", namespace)
"tree.tree_layer_path", "", namespace)
self.tree_travel_path = envs.get_global_env(
"single.tree_travel_path", "", namespace)
"tree.tree_travel_path", "", namespace)
self.tree_info_path = envs.get_global_env(
"single.tree_info_path", "", namespace)
"tree.tree_info_path", "", namespace)
self.tree_emb_path = envs.get_global_env(
"single.tree_emb_path", "", namespace)
"tree.tree_emb_path", "", namespace)
save_init_model = envs.get_global_env(
"single.save_init_model", False, namespace)
......
......@@ -11,6 +11,7 @@ engines = {}
device = ["CPU", "GPU"]
clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
custom_model = ['tdm']
model_name = ""
def engine_registry():
......@@ -34,7 +35,7 @@ def get_engine(args):
d_engine = engines[device]
transpiler = get_transpiler()
engine = get_custom_model_engine(args)
engine = args.engine
run_engine = d_engine[transpiler].get(engine, None)
if run_engine is None:
......@@ -43,16 +44,6 @@ def get_engine(args):
return run_engine
def get_custom_model_engine(args):
model = args.model
model_name = model.split('.')[1]
if model_name in custom_model:
engine = "_".join((model_name.upper(), args.engine))
else:
engine = args.engine
return engine
def get_transpiler():
FNULL = open(os.devnull, 'w')
cmd = ["python", "-c",
......@@ -93,8 +84,6 @@ def set_runtime_envs(cluster_envs, engine_yaml):
def get_trainer_prefix(args):
model = args.model
model_name = model.split('.')[1]
if model_name in custom_model:
return model_name.upper()
return ""
......@@ -218,6 +207,7 @@ if __name__ == "__main__":
args = parser.parse_args()
args.engine = args.engine.upper()
args.device = args.device.upper()
model_name = args.model.split('.')[-1]
args.model = get_abs_model(args.model)
engine_registry()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册