diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py
index d39d48f31c268a65f6691663f45c6fe09308c6b7..a09a87b3bae91344d14a032dda0567bf1caf71e2 100755
--- a/core/trainers/single_trainer.py
+++ b/core/trainers/single_trainer.py
@@ -80,7 +80,7 @@ class SingleTrainer(TranspileTrainer):
             pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
                 reader, "slot", "slot", self._config_yaml, "fake", \
                 sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
-
+
         dataset = fluid.DatasetFactory().create_dataset()
         dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
         dataset.set_pipe_command(pipe_cmd)
@@ -110,11 +110,13 @@ class SingleTrainer(TranspileTrainer):
         reader_class = envs.get_global_env(name + "data_converter")
         abs_dir = os.path.dirname(os.path.abspath(__file__))
         if sparse_slots is None and dense_slots is None:
-            reader = dataloader_instance.dataloader_by_name(reader_class, dataset_name, self._config_yaml)
+            reader = dataloader_instance.dataloader_by_name(
+                reader_class, dataset_name, self._config_yaml)
             reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader")
             reader_ins = reader_class(self._config_yaml)
         else:
-            reader = dataloader_instance.slotdataloader_by_name("", dataset_name, self._config_yaml)
+            reader = dataloader_instance.slotdataloader_by_name(
+                "", dataset_name, self._config_yaml)
             reader_ins = SlotReader(self._config_yaml)
         if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
             dataloader.set_sample_list_generator(reader)
@@ -122,7 +124,6 @@ class SingleTrainer(TranspileTrainer):
             dataloader.set_sample_generator(reader, batch_size)
         return dataloader
 
-
     def _create_dataset(self, dataset_name):
         name = "dataset." + dataset_name + "."
         sparse_slots = envs.get_global_env(name + "sparse_slots")
@@ -131,7 +132,8 @@ class SingleTrainer(TranspileTrainer):
         batch_size = envs.get_global_env(name + "batch_size")
         type_name = envs.get_global_env(name + "type")
         if envs.get_platform() != "LINUX":
-            print("platform ", envs.get_platform(), " change reader to DataLoader")
+            print("platform ", envs.get_platform(),
+                  " change reader to DataLoader")
             type_name = "DataLoader"
         padding = 0
 
@@ -140,7 +142,6 @@ class SingleTrainer(TranspileTrainer):
         else:
             return self._get_dataset(dataset_name)
 
-
     def init(self, context):
         for model_dict in self._env["executor"]:
             self._model[model_dict["name"]] = [None] * 5
@@ -186,12 +187,11 @@ class SingleTrainer(TranspileTrainer):
                 self._model[model_dict["name"]][2] = scope
                 self._model[model_dict["name"]][3] = model
                 self._model[model_dict["name"]][4] = train_program.clone()
-
 
         for dataset in self._env["dataset"]:
             if dataset["type"] != "DataLoader":
                 self._dataset[dataset["name"]] = self._create_dataset(dataset[
-                                                                      "name"])
+                    "name"])
 
         context['status'] = 'startup_pass'
 
@@ -268,7 +268,8 @@ class SingleTrainer(TranspileTrainer):
         program = self._model[model_name][0].clone()
         if not model_dict["is_infer"]:
             program = fluid.compiler.CompiledProgram(
-                program).with_data_parallel(loss_name=model_class.get_avg_cost().name)
+                program).with_data_parallel(
+                    loss_name=model_class.get_avg_cost().name)
         fetch_vars = []
         fetch_alias = []
         fetch_period = 20
@@ -333,8 +334,10 @@ class SingleTrainer(TranspileTrainer):
                 envs.get_global_env("epoch.save_inference_interval", -1))
             if not need_save(epoch_id, save_interval, False):
                 return
-            feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None)
-            fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None)
+            feed_varnames = envs.get_global_env(
+                "epoch.save_inference_feed_varnames", None)
+            fetch_varnames = envs.get_global_env(
+                "epoch.save_inference_fetch_varnames", None)
             if feed_varnames is None or fetch_varnames is None or feed_varnames == "":
                 return
             fetch_vars = [
diff --git a/run.py b/run.py
index 9b668e8f2a3fb380b21ed328ca53cf59b268ca70..830ec905a5deab896400c755c46306d058c36885 100755
--- a/run.py
+++ b/run.py
@@ -68,7 +68,7 @@ def get_engine(args):
     if engine is None:
         engine = run_extras.get("epoch.trainer_class", None)
         if engine is None:
-            engine = "single"
+            engine = "single"
     engine = engine.upper()
     if engine not in engine_choices:
         raise ValueError("train.engin can not be chosen in {}".format(