diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 332381e4d5e58c167cd939b554c9d608d31ad9f8..3cfd0e5a0674ff985102aaa38d6c7840502232a1 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -50,9 +50,8 @@ class TrainerFactory(object): trainer_abs = train_mode train_mode = "UserDefineTrainer" - train_location = envs.get_global_env("train.location") train_dirname = os.path.dirname(trainer_abs) - base_name = os.path.splitext(os.path.basename(train_location))[0] + base_name = os.path.splitext(os.path.basename(trainer_abs))[0] sys.path.append(train_dirname) trainer_class = envs.lazy_instance(base_name, train_mode) trainer = trainer_class(yaml_path) diff --git a/fleetrec/core/trainers/transpiler_trainer.py b/fleetrec/core/trainers/transpiler_trainer.py index 02253573f0980192683404cb579eae80cda7dea8..ee8ea3f199799852a2b1da2740d4866d87ff7ec2 100644 --- a/fleetrec/core/trainers/transpiler_trainer.py +++ b/fleetrec/core/trainers/transpiler_trainer.py @@ -44,7 +44,7 @@ class TranspileTrainer(Trainer): reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '../utils', 'reader_instance.py') - pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) + pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) train_data_path = envs.get_global_env("train_data_path", None, namespace) dataset = fluid.DatasetFactory().create_dataset() diff --git a/fleetrec/examples/user_define/__init__.py b/fleetrec/examples/user_define/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fleetrec/examples/user_define_trainer.py b/fleetrec/examples/user_define/user_define_trainer.py similarity index 100% rename from fleetrec/examples/user_define_trainer.py rename to fleetrec/examples/user_define/user_define_trainer.py diff --git a/fleetrec/examples/user_define_trainer.yaml b/fleetrec/examples/user_define/user_define_trainer.yaml similarity index 100% rename from fleetrec/examples/user_define_trainer.yaml rename to fleetrec/examples/user_define/user_define_trainer.yaml diff --git a/fleetrec/run.py b/fleetrec/run.py index aad2ece38ebe7fdf9ab6cf3d5eed899d2cebb48a..aaae280864597c20f77ec3c08e30eb32fc677b6f 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -40,10 +40,8 @@ def get_engine(engine): def single_engine(args): print("use single engine to run model: {}".format(args.model)) - single_envs = {"trainer.trainer": "SingleTrainer"} - + single_envs = {"trainer.trainer": "SingleTrainer", "trainer.threads": "2"} set_runtime_envs(single_envs, args.engine_extras) - trainer = TrainerFactory.create(args.model) return trainer