提交 8b73e3a0 编写于 作者: T tangwei

code clean

上级 7b3ab68c
from fleetrec.trainer.trainer import Trainer
class UserDefineTrainer(Trainer):
def __init__(self, config=None):
Trainer.__init__(self, config)
......@@ -34,15 +34,22 @@ class TrainerFactory(object):
print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_global_env("train.trainer")
if train_mode == "SingleTraining":
trainer = SingleTrainer(yaml_path)
elif train_mode == "ClusterTraining":
trainer = ClusterTrainer(yaml_path)
elif train_mode == "CtrTrainer":
trainer = CtrPaddleTrainer(config)
elif train_mode == "UserDefineTrainer":
train_location = envs.get_global_env("train.trainer.location")
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer")
trainer = trainer_class(yaml_path)
else:
raise ValueError("trainer only support SingleTraining/ClusterTraining")
return trainer
@staticmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册