From 0399449f489dcf30063db8b5c8d30c296c2e1c11 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Tue, 25 Jun 2019 20:01:50 +0800 Subject: [PATCH] Update task finetune api to return run states --- demo/elmo/predict.py | 3 ++- demo/image-classification/predict.py | 3 ++- demo/multi-label-classification/predict.py | 3 ++- demo/senta/predict.py | 3 ++- demo/sequence-labeling/predict.py | 3 ++- demo/text-classification/predict.py | 3 ++- paddlehub/finetune/task.py | 6 ++++-- 7 files changed, 16 insertions(+), 8 deletions(-) diff --git a/demo/elmo/predict.py b/demo/elmo/predict.py index be056b73..3fb9b714 100644 --- a/demo/elmo/predict.py +++ b/demo/elmo/predict.py @@ -168,7 +168,8 @@ if __name__ == '__main__': ] index = 0 - results = elmo_task.predict(data=data) + run_states = elmo_task.predict(data=data) + results = [run_state.run_results for run_state in run_states] for batch_result in results: # get predict index batch_result = np.argmax(batch_result, axis=2)[0] diff --git a/demo/image-classification/predict.py b/demo/image-classification/predict.py index 532d813f..ab7d34d9 100644 --- a/demo/image-classification/predict.py +++ b/demo/image-classification/predict.py @@ -76,7 +76,8 @@ def predict(args): label_map = dataset.label_dict() index = 0 # get classification result - results = task.predict(data=data) + run_states = task.predict(data=data) + results = [run_state.run_results for run_state in run_states] for batch_result in results: # get predict index batch_result = np.argmax(batch_result, axis=2)[0] diff --git a/demo/multi-label-classification/predict.py b/demo/multi-label-classification/predict.py index 5e7caa99..138c5ade 100644 --- a/demo/multi-label-classification/predict.py +++ b/demo/multi-label-classification/predict.py @@ -99,7 +99,8 @@ if __name__ == '__main__': ] index = 0 - results = multi_label_cls_task.predict(data=data) + run_states = multi_label_cls_task.predict(data=data) + results = [run_state.run_results for run_state in run_states] for result in results: # get predict index label_ids = [] diff --git a/demo/senta/predict.py b/demo/senta/predict.py index 4d22b7bb..1c258343 100644 --- a/demo/senta/predict.py +++ b/demo/senta/predict.py @@ -60,7 +60,8 @@ if __name__ == '__main__': data = ["这家餐厅很好吃", "这部电影真的很差劲"] - results = cls_task.predict(data=data) + run_states = cls_task.predict(data=data) + results = [run_state.run_results for run_state in run_states] index = 0 for batch_result in results: batch_result = np.argmax(batch_result, axis=2)[0] diff --git a/demo/sequence-labeling/predict.py b/demo/sequence-labeling/predict.py index fb26b86a..01fa4280 100644 --- a/demo/sequence-labeling/predict.py +++ b/demo/sequence-labeling/predict.py @@ -96,7 +96,8 @@ if __name__ == '__main__': ["不过重在晋趣,略增明人气息,妙在集古有道、不露痕迹罢了。"], ] - results = seq_label_task.predict(data=data) + run_states = seq_label_task.predict(data=data) + results = [run_state.run_results for run_state in run_states] for num_batch, batch_results in enumerate(results): infers = batch_results[0].reshape([-1]).astype(np.int32).tolist() diff --git a/demo/text-classification/predict.py b/demo/text-classification/predict.py index 1465d0d6..e0088dba 100644 --- a/demo/text-classification/predict.py +++ b/demo/text-classification/predict.py @@ -97,7 +97,8 @@ if __name__ == '__main__': ] index = 0 - results = cls_task.predict(data=data) + run_states = cls_task.predict(data=data) + results = [run_state.run_results for run_state in run_states] for batch_result in results: # get predict index batch_result = np.argmax(batch_result, axis=2)[0] diff --git a/paddlehub/finetune/task.py b/paddlehub/finetune/task.py index 13c776d4..d438eaf6 100644 --- a/paddlehub/finetune/task.py +++ b/paddlehub/finetune/task.py @@ -498,7 +498,7 @@ class BasicTask(object): self.exe, dirname=dirname, main_program=self.main_program) def finetune_and_eval(self): - self.finetune(do_eval=True) + return self.finetune(do_eval=True) def finetune(self, do_eval=False): # Start to finetune @@ -519,6 +519,7 @@ class BasicTask(object): self.eval(phase="test") self._finetune_end_event(run_states) + return run_states def eval(self, phase="dev"): with self.phase_guard(phase=phase): @@ -526,6 +527,7 @@ class BasicTask(object): self._eval_start_event() run_states = self._run() self._eval_end_event(run_states) + return run_states def predict(self, data, load_best_model=True): with self.phase_guard(phase="predict"): @@ -539,7 +541,7 @@ class BasicTask(object): run_states = self._run() self._predict_end_event(run_states) self._predict_data = None - return [run_state.run_results for run_state in run_states] + return run_states def _run(self, do_eval=False): with fluid.program_guard(self.main_program, self.startup_program): -- GitLab