From dd378956f5de992f1566b627ed98e9db90847341 Mon Sep 17 00:00:00 2001 From: malin10 Date: Fri, 8 May 2020 11:46:08 +0800 Subject: [PATCH] add infer --- fleet_rec/core/model.py | 9 ++++ fleet_rec/core/trainers/single_trainer.py | 50 +++++++++++++++++-- fleet_rec/core/trainers/transpiler_trainer.py | 38 ++++++++------ fleet_rec/core/utils/dataloader_instance.py | 4 +- 4 files changed, 82 insertions(+), 19 deletions(-) diff --git a/fleet_rec/core/model.py b/fleet_rec/core/model.py index 528be0bf..9d1815bf 100644 --- a/fleet_rec/core/model.py +++ b/fleet_rec/core/model.py @@ -16,7 +16,10 @@ class Model(object): self._cost = None self._metrics = {} self._data_var = [] + self._infer_data_var = [] + self._infer_results = {} self._data_loader = None + self._infer_data_loader = None self._fetch_interval = 20 self._namespace = "train.model" self._platform = envs.get_platform() @@ -24,6 +27,12 @@ class Model(object): def get_inputs(self): return self._data_var + def get_infer_inputs(self): + return self._infer_data_var + + def get_infer_results(self): + return self._infer_results + def get_cost_op(self): """R """ diff --git a/fleet_rec/core/trainers/single_trainer.py b/fleet_rec/core/trainers/single_trainer.py index 989297a0..9cf97082 100644 --- a/fleet_rec/core/trainers/single_trainer.py +++ b/fleet_rec/core/trainers/single_trainer.py @@ -59,7 +59,7 @@ class SingleTrainer(TranspileTrainer): def dataloader_train(self, context): self._exe.run(fluid.default_startup_program()) - reader = self._get_dataloader() + reader = self._get_dataloader("TRAIN") epochs = envs.get_global_env("train.epochs") program = fluid.compiler.CompiledProgram( @@ -95,13 +95,14 @@ class SingleTrainer(TranspileTrainer): batch_id += 1 except fluid.core.EOFException: reader.reset() + self.save(epoch, "train", is_fleet=False) context['status'] = 'infer_pass' def dataset_train(self, context): # run startup program at once self._exe.run(fluid.default_startup_program()) - dataset = self._get_dataset() + dataset = self._get_dataset("TRAIN") epochs = envs.get_global_env("train.epochs") for i in range(epochs): @@ -109,11 +110,54 @@ class SingleTrainer(TranspileTrainer): dataset=dataset, fetch_list=self.fetch_vars, fetch_info=self.fetch_alias, - print_period=self.fetch_period) + print_period=1, + debug=True) self.save(i, "train", is_fleet=False) context['status'] = 'infer_pass' def infer(self, context): + infer_program = fluid.Program() + startup_program = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(infer_program, startup_program): + self.model.infer_net() + + reader = self._get_dataloader("Evaluate") + + metrics_varnames = [] + metrics_format = [] + + metrics_format.append("{}: {{}}".format("epoch")) + metrics_format.append("{}: {{}}".format("batch")) + + for name, var in self.model.get_infer_results().items(): + metrics_varnames.append(var.name) + metrics_format.append("{}: {{}}".format(name)) + + metrics_format = ", ".join(metrics_format) + self._exe.run(startup_program) + + for (epoch, model_dir) in self.increment_models: + print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir)) + program = infer_program.clone() + fluid.io.load_persistables(self._exe, model_dir, program) + reader.start() + batch_id = 0 + try: + while True: + metrics_rets = self._exe.run( + program=program, + fetch_list=metrics_varnames) + + metrics = [epoch, batch_id] + metrics.extend(metrics_rets) + + if batch_id % 2 == 0 and batch_id != 0: + print(metrics_format.format(*metrics)) + batch_id += 1 + except fluid.core.EOFException: + reader.reset() + context['status'] = 'terminal_pass' def terminal(self, context): diff --git a/fleet_rec/core/trainers/transpiler_trainer.py b/fleet_rec/core/trainers/transpiler_trainer.py index a4875044..eb7d8b0b 100644 --- a/fleet_rec/core/trainers/transpiler_trainer.py +++ b/fleet_rec/core/trainers/transpiler_trainer.py @@ -36,28 +36,37 @@ class TranspileTrainer(Trainer): def processor_register(self): print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first") - def _get_dataloader(self): - namespace = "train.reader" - dataloader = self.model._data_loader + def _get_dataloader(self, state): + if state == "TRAIN": + dataloader = self.model._data_loader + namespace = "train.reader" + else: + dataloader = self.model._infer_data_loader + namespace = "evaluate.reader" + batch_size = envs.get_global_env("batch_size", None, namespace) reader_class = envs.get_global_env("class", None, namespace) - reader = dataloader_instance.dataloader(reader_class, "TRAIN", self._config_yaml) + reader = dataloader_instance.dataloader(reader_class, state, self._config_yaml) dataloader.set_sample_generator(reader, batch_size) return dataloader - def _get_dataset(self): - namespace = "train.reader" + def _get_dataset(self, state): + if state == "TRAIN": + inputs = self.model.get_inputs() + namespace = "train.reader" + train_data_path = envs.get_global_env("train_data_path", None, namespace) + else: + inputs = self.model.get_infer_inputs() + namespace = "evaluate.reader" + train_data_path = envs.get_global_env("test_data_path", None, namespace) - inputs = self.model.get_inputs() threads = int(envs.get_runtime_environ("train.trainer.threads")) batch_size = envs.get_global_env("batch_size", None, namespace) reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') - pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) - - train_data_path = envs.get_global_env("train_data_path", None, namespace) + pipe_cmd = "python {} {} {} {}".format(reader, reader_class, state, self._config_yaml) if train_data_path.startswith("fleetrec::"): package_base = envs.get_runtime_environ("PACKAGE_BASE") @@ -92,13 +101,13 @@ class TranspileTrainer(Trainer): if not need_save(epoch_id, save_interval, False): return - - print("save inference model is not supported now.") - return + + # print("save inference model is not supported now.") + # return feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace) fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace) - fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames] + fetch_vars = [fluid.default_main_program().global_block().vars[varname] for varname in fetch_varnames] dirname = envs.get_global_env("save.inference.dirname", None, namespace) assert dirname is not None @@ -129,6 +138,7 @@ class TranspileTrainer(Trainer): save_persistables() save_inference_model() + def instance(self, context): models = envs.get_global_env("train.model.models") diff --git a/fleet_rec/core/utils/dataloader_instance.py b/fleet_rec/core/utils/dataloader_instance.py index eb7e5fd6..3f86f908 100644 --- a/fleet_rec/core/utils/dataloader_instance.py +++ b/fleet_rec/core/utils/dataloader_instance.py @@ -22,13 +22,13 @@ from fleetrec.core.utils.envs import get_runtime_environ def dataloader(readerclass, train, yaml_file): - namespace = "train.reader" - if train == "TRAIN": reader_name = "TrainReader" + namespace = "train.reader" data_path = get_global_env("train_data_path", None, namespace) else: reader_name = "EvaluateReader" + namespace = "evaluate.reader" data_path = get_global_env("test_data_path", None, namespace) if data_path.startswith("fleetrec::"): -- GitLab