提交 1096f044 编写于 作者: T tangwei

fix readme

上级 e7986cb4
...@@ -50,9 +50,8 @@ class TrainerFactory(object): ...@@ -50,9 +50,8 @@ class TrainerFactory(object):
trainer_abs = train_mode trainer_abs = train_mode
train_mode = "UserDefineTrainer" train_mode = "UserDefineTrainer"
train_location = envs.get_global_env("train.location")
train_dirname = os.path.dirname(trainer_abs) train_dirname = os.path.dirname(trainer_abs)
base_name = os.path.splitext(os.path.basename(train_location))[0] base_name = os.path.splitext(os.path.basename(trainer_abs))[0]
sys.path.append(train_dirname) sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, train_mode) trainer_class = envs.lazy_instance(base_name, train_mode)
trainer = trainer_class(yaml_path) trainer = trainer_class(yaml_path)
......
...@@ -44,7 +44,7 @@ class TranspileTrainer(Trainer): ...@@ -44,7 +44,7 @@ class TranspileTrainer(Trainer):
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'reader_instance.py') reader = os.path.join(abs_dir, '../utils', 'reader_instance.py')
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml)
train_data_path = envs.get_global_env("train_data_path", None, namespace) train_data_path = envs.get_global_env("train_data_path", None, namespace)
dataset = fluid.DatasetFactory().create_dataset() dataset = fluid.DatasetFactory().create_dataset()
......
...@@ -40,10 +40,8 @@ def get_engine(engine): ...@@ -40,10 +40,8 @@ def get_engine(engine):
def single_engine(args): def single_engine(args):
print("use single engine to run model: {}".format(args.model)) print("use single engine to run model: {}".format(args.model))
single_envs = {"trainer.trainer": "SingleTrainer"} single_envs = {"trainer.trainer": "SingleTrainer", "trainer.threads": "2"}
set_runtime_envs(single_envs, args.engine_extras) set_runtime_envs(single_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册