From ab06f63e1c6703ee1e5285199f4db6f7075c028b Mon Sep 17 00:00:00 2001 From: Steffy-zxf <48793257+Steffy-zxf@users.noreply.github.com> Date: Sun, 19 Jan 2020 11:41:34 +0800 Subject: [PATCH] Add inference model (#327) * add-inference-model --- paddlehub/finetune/checkpoint.py | 4 ---- paddlehub/finetune/task/base_task.py | 10 +++++----- paddlehub/finetune/task/classifier_task.py | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/paddlehub/finetune/checkpoint.py b/paddlehub/finetune/checkpoint.py index b3ace5b6..04ddb775 100644 --- a/paddlehub/finetune/checkpoint.py +++ b/paddlehub/finetune/checkpoint.py @@ -74,10 +74,6 @@ def save_checkpoint(checkpoint_dir, ckpt = checkpoint_pb2.CheckPoint() model_saved_dir = os.path.join(checkpoint_dir, "step_%d" % global_step) - logger.info("Saving model checkpoint to {}".format(model_saved_dir)) - fluid.io.save_persistables( - exe, dirname=model_saved_dir, main_program=main_program) - ckpt.current_epoch = current_epoch ckpt.global_step = global_step ckpt.latest_model_dir = model_saved_dir diff --git a/paddlehub/finetune/task/base_task.py b/paddlehub/finetune/task/base_task.py index 55eb0666..49831cb9 100644 --- a/paddlehub/finetune/task/base_task.py +++ b/paddlehub/finetune/task/base_task.py @@ -662,11 +662,7 @@ class BaseTask(object): "best_model") logger.eval("best model saved to %s [best %s=%.5f]" % (model_saved_dir, main_metric, main_value)) - - save_result = fluid.io.save_persistables( - executor=self.exe, - dirname=model_saved_dir, - main_program=self.main_program) + self.save_inference_model(dirname=model_saved_dir) def _default_log_interval_event(self, run_states): scores, avg_loss, run_speed = self._calculate_metrics(run_states) @@ -717,6 +713,10 @@ class BaseTask(object): # NOTE: current saved checkpoint machanism is not completed, # it can't restore dataset training status def save_checkpoint(self): + model_saved_dir = os.path.join(self.config.checkpoint_dir, + "step_%d" % self.current_step) + logger.info("Saving model checkpoint to {}".format(model_saved_dir)) + self.save_inference_model(dirname=model_saved_dir) save_checkpoint( checkpoint_dir=self.config.checkpoint_dir, current_epoch=self.current_epoch, diff --git a/paddlehub/finetune/task/classifier_task.py b/paddlehub/finetune/task/classifier_task.py index df0999db..0ab62405 100644 --- a/paddlehub/finetune/task/classifier_task.py +++ b/paddlehub/finetune/task/classifier_task.py @@ -317,7 +317,7 @@ class MultiLabelClassifierTask(ClassifierTask): def fetch_list(self): if self.is_train_phase or self.is_test_phase: return [metric.name for metric in self.metrics] + [self.loss.name] - return self.outputs + return [output.name for output in self.outputs] def _postprocessing(self, run_states): results = [] -- GitLab