From 1096f044a9d72e1523daa0cc7d8d5295e0e97bf5 Mon Sep 17 00:00:00 2001 From: tangwei Date: Mon, 20 Apr 2020 16:09:39 +0800 Subject: [PATCH] fix readme --- fleetrec/core/factory.py | 3 +-- fleetrec/core/trainers/transpiler_trainer.py | 2 +- fleetrec/examples/user_define/__init__.py | 0 fleetrec/examples/{ => user_define}/user_define_trainer.py | 0 fleetrec/examples/{ => user_define}/user_define_trainer.yaml | 0 fleetrec/run.py | 4 +--- 6 files changed, 3 insertions(+), 6 deletions(-) create mode 100644 fleetrec/examples/user_define/__init__.py rename fleetrec/examples/{ => user_define}/user_define_trainer.py (100%) rename fleetrec/examples/{ => user_define}/user_define_trainer.yaml (100%) diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 332381e4..3cfd0e5a 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -50,9 +50,8 @@ class TrainerFactory(object): trainer_abs = train_mode train_mode = "UserDefineTrainer" - train_location = envs.get_global_env("train.location") train_dirname = os.path.dirname(trainer_abs) - base_name = os.path.splitext(os.path.basename(train_location))[0] + base_name = os.path.splitext(os.path.basename(trainer_abs))[0] sys.path.append(train_dirname) trainer_class = envs.lazy_instance(base_name, train_mode) trainer = trainer_class(yaml_path) diff --git a/fleetrec/core/trainers/transpiler_trainer.py b/fleetrec/core/trainers/transpiler_trainer.py index 02253573..ee8ea3f1 100644 --- a/fleetrec/core/trainers/transpiler_trainer.py +++ b/fleetrec/core/trainers/transpiler_trainer.py @@ -44,7 +44,7 @@ class TranspileTrainer(Trainer): reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '../utils', 'reader_instance.py') - pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) + pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) train_data_path = envs.get_global_env("train_data_path", None, namespace) dataset = fluid.DatasetFactory().create_dataset() diff --git a/fleetrec/examples/user_define/__init__.py b/fleetrec/examples/user_define/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fleetrec/examples/user_define_trainer.py b/fleetrec/examples/user_define/user_define_trainer.py similarity index 100% rename from fleetrec/examples/user_define_trainer.py rename to fleetrec/examples/user_define/user_define_trainer.py diff --git a/fleetrec/examples/user_define_trainer.yaml b/fleetrec/examples/user_define/user_define_trainer.yaml similarity index 100% rename from fleetrec/examples/user_define_trainer.yaml rename to fleetrec/examples/user_define/user_define_trainer.yaml diff --git a/fleetrec/run.py b/fleetrec/run.py index aad2ece3..aaae2808 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -40,10 +40,8 @@ def get_engine(engine): def single_engine(args): print("use single engine to run model: {}".format(args.model)) - single_envs = {"trainer.trainer": "SingleTrainer"} - + single_envs = {"trainer.trainer": "SingleTrainer", "trainer.threads": "2"} set_runtime_envs(single_envs, args.engine_extras) - trainer = TrainerFactory.create(args.model) return trainer -- GitLab