提交 8b73e3a0 编写于 作者: T tangwei

code clean

上级 7b3ab68c
from fleetrec.trainer.trainer import Trainer
class UserDefineTrainer(Trainer):
def __init__(self, config=None):
Trainer.__init__(self, config)
...@@ -34,15 +34,22 @@ class TrainerFactory(object): ...@@ -34,15 +34,22 @@ class TrainerFactory(object):
print(envs.pretty_print_envs(envs.get_global_envs())) print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_global_env("train.trainer") train_mode = envs.get_global_env("train.trainer")
if train_mode == "SingleTraining": if train_mode == "SingleTraining":
trainer = SingleTrainer(yaml_path) trainer = SingleTrainer(yaml_path)
elif train_mode == "ClusterTraining": elif train_mode == "ClusterTraining":
trainer = ClusterTrainer(yaml_path) trainer = ClusterTrainer(yaml_path)
elif train_mode == "CtrTrainer": elif train_mode == "CtrTrainer":
trainer = CtrPaddleTrainer(config) trainer = CtrPaddleTrainer(config)
elif train_mode == "UserDefineTrainer":
train_location = envs.get_global_env("train.trainer.location")
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer")
trainer = trainer_class(yaml_path)
else: else:
raise ValueError("trainer only support SingleTraining/ClusterTraining") raise ValueError("trainer only support SingleTraining/ClusterTraining")
return trainer return trainer
@staticmethod @staticmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册