diff --git a/core/model.py b/core/model.py index 2d34b5b068b71ac480575b08e95d67d0208f3841..cfe71f2a5db76b2585b6df23915724909ef68033 100755 --- a/core/model.py +++ b/core/model.py @@ -87,9 +87,13 @@ class Model(object): if dataset_class == "DataLoader": self._init_dataloader() - def _init_dataloader(self): + def _init_dataloader(self, is_infer=False): + if is_infer: + data = self._infer_data_var + else: + data = self._data_var self._data_loader = fluid.io.DataLoader.from_generator( - feed_list=self._data_var, + feed_list=data, capacity=64, use_double_buffer=False, iterable=False) diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index 501afd477a671d054e86011e34d8fc39640e6f91..db1bc8efb43e937b9e5446f55b3ec1a6e5bf5544 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -93,7 +93,10 @@ class SingleTrainer(TranspileTrainer): for model_dict in self._env["executor"]: if model_dict["dataset_name"] == dataset_name: model = self._model[model_dict["name"]][3] - inputs = model.get_inputs() + if model_dict["is_infer"]: + inputs = model._infer_data_var + else: + inputs = model._data_var dataset.set_use_var(inputs) break return dataset @@ -158,23 +161,32 @@ class SingleTrainer(TranspileTrainer): envs.path_adapter(self._env["workspace"])) model = envs.lazy_instance_by_fliename( model_path, "Model")(self._env) - model._data_var = model.input_data( - dataset_name=model_dict["dataset_name"]) + is_infer = model_dict.get("is_infer", False) + if is_infer: + model._infer_data_var = model.input_data( + dataset_name=model_dict["dataset_name"]) + else: + model._data_var = model.input_data( + dataset_name=model_dict["dataset_name"]) if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader": - model._init_dataloader() + model._init_dataloader(is_infer=is_infer) self._get_dataloader(dataset_name, model._data_loader) - model.net(model._data_var, - is_infer=model_dict.get("is_infer", False)) - optimizer = model._build_optimizer(opt_name, opt_lr, - opt_strategy) - optimizer.minimize(model._cost) + if is_infer: + model.net(model._infer_data_var, True) + else: + model.net(model._data_var, False) + optimizer = model._build_optimizer(opt_name, opt_lr, + opt_strategy) + optimizer.minimize(model._cost) + model_dict["is_infer"] = is_infer self._model[model_dict["name"]][0] = train_program self._model[model_dict["name"]][1] = startup_program self._model[model_dict["name"]][2] = scope self._model[model_dict["name"]][3] = model self._model[model_dict["name"]][4] = train_program.clone() + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": @@ -223,7 +235,10 @@ class SingleTrainer(TranspileTrainer): fetch_vars = [] fetch_alias = [] fetch_period = 20 - metrics = model_class.get_metrics() + if model_dict["is_infer"]: + metrics = model_class.get_infer_results() + else: + metrics = model_class.get_metrics() if metrics: fetch_vars = metrics.values() fetch_alias = metrics.keys() @@ -231,24 +246,36 @@ class SingleTrainer(TranspileTrainer): program = self._model[model_name][0] reader = self._dataset[reader_name] with fluid.scope_guard(scope): - self._exe.train_from_dataset( - program=program, - dataset=reader, - fetch_list=fetch_vars, - fetch_info=fetch_alias, - print_period=fetch_period) + if model_dict["is_infer"]: + self._exe.infer_from_dataset( + program=program, + dataset=reader, + fetch_list=fetch_vars, + fetch_info=fetch_alias, + print_period=fetch_period) + else: + self._exe.train_from_dataset( + program=program, + dataset=reader, + fetch_list=fetch_vars, + fetch_info=fetch_alias, + print_period=fetch_period) def _executor_dataloader_train(self, model_dict): reader_name = model_dict["dataset_name"] model_name = model_dict["name"] model_class = self._model[model_name][3] program = self._model[model_name][0].clone() - program = fluid.compiler.CompiledProgram( - program).with_data_parallel(loss_name=model_class.get_avg_cost().name) + if not model_dict["is_infer"]: + program = fluid.compiler.CompiledProgram( + program).with_data_parallel(loss_name=model_class.get_avg_cost().name) fetch_vars = [] fetch_alias = [] fetch_period = 20 - metrics = model_class.get_metrics() + if model_dict["is_infer"]: + metrics = model_class.get_infer_results() + else: + metrics = model_class.get_metrics() if metrics: fetch_vars = metrics.values() fetch_alias = metrics.keys()