diff --git a/demo/multi-label-classification/predict.py b/demo/multi-label-classification/predict.py
index c5052b96897d1b83d072bcc38925836535cdf919..f6f4de5d16b2c2f1ae92394021b52508a5bb11a4 100644
--- a/demo/multi-label-classification/predict.py
+++ b/demo/multi-label-classification/predict.py
@@ -41,7 +41,7 @@ args = parser.parse_args()
 
 if __name__ == '__main__':
     # Load Paddlehub BERT pretrained model
-    module = hub.Module(name="ernie_eng_base.hub_module")
+    module = hub.Module(name="ernie_v2_eng_base")
     inputs, outputs, program = module.context(
         trainable=True, max_seq_len=args.max_seq_len)
 
@@ -97,12 +97,14 @@ if __name__ == '__main__':
 
     index = 0
     run_states = multi_label_cls_task.predict(data=data)
-    results = [run_state.run_results for run_state in run_states]
-    for result in results:
-        # get predict index
-        label_ids = []
-        for i in range(dataset.num_labels):
-            label_val = np.argmax(result[i])
-            label_ids.append(label_val)
-        print("%s\tpredict=%s" % (data[index][0], label_ids))
-        index += 1
+
+    all_result = []
+    for batch_state in run_states:
+        batch_result = batch_state.run_results
+        for sample_id in range(len(batch_result[0])):
+            sample_result = []
+            for category_id in range(dataset.num_labels):
+                sample_category_prob = batch_result[category_id][sample_id]
+                sample_result.append(np.argmax(sample_category_prob))
+            all_result.append(sample_result)
+    print(all_result)
diff --git a/paddlehub/finetune/task/base_task.py b/paddlehub/finetune/task/base_task.py
index cc98d1d40a6b84f4ad5e3f76c3481297b4d495f9..4035705291a6fe2955e88396985309c58c4e85eb 100644
--- a/paddlehub/finetune/task/base_task.py
+++ b/paddlehub/finetune/task/base_task.py
@@ -767,7 +767,7 @@ class BaseTask(object):
         self._eval_end_event(run_states)
         return run_states
 
-    def predict(self, data, load_best_model=True):
+    def predict(self, data, load_best_model=True, return_result=False):
         with self.phase_guard(phase="predict"):
             if load_best_model:
                 self.init_if_load_best_model()
@@ -778,8 +778,17 @@ class BaseTask(object):
             run_states = self._run()
             self._predict_end_event(run_states)
             self._predict_data = None
+            if return_result:
+                return self._postprocessing(run_states)
         return run_states
 
+    def _postprocessing(self, run_states):
+        results = []
+        for batch_state in run_states:
+            batch_result = batch_state.run_results[0]
+            results += [result[0] for result in batch_result]
+        return results
+
     def _run(self, do_eval=False):
         with fluid.program_guard(self.main_program, self.startup_program):
             if self.config.use_pyreader:
diff --git a/paddlehub/finetune/task/classifier_task.py b/paddlehub/finetune/task/classifier_task.py
index 9a3c00d6be3813c4a3db2885b603f67f63808ed5..0d747ad098adddd359bebede81f6edfc9c5f7f70 100644
--- a/paddlehub/finetune/task/classifier_task.py
+++ b/paddlehub/finetune/task/classifier_task.py
@@ -134,6 +134,22 @@ class ClassifierTask(BaseTask):
 
         return scores, avg_loss, run_speed
 
+    def _postprocessing(self, run_states):
+        try:
+            id2label = {
+                val: key
+                for key, val in self._base_data_reader.label_map.items()
+            }
+        except:
+            raise Exception(
+                "image-classification does not support return_result now")
+        results = []
+        for batch_state in run_states:
+            batch_result = batch_state.run_results
+            batch_infer = np.argmax(batch_result, axis=2)[0]
+            results += [id2label[sample_infer] for sample_infer in batch_infer]
+        return results
+
 
 ImageClassifierTask = ClassifierTask
 
@@ -301,3 +317,16 @@ class MultiLabelClassifierTask(ClassifierTask):
         if self.is_train_phase or self.is_test_phase:
             return [metric.name for metric in self.metrics] + [self.loss.name]
         return self.outputs
+
+    def _postprocessing(self, run_states):
+        results = []
+        for batch_state in run_states:
+            batch_result = batch_state.run_results
+            for sample_id in range(len(batch_result[0])):
+                sample_result = []
+                for category_id in range(
+                        self._base_data_reader.dataset.num_labels):
+                    sample_category_prob = batch_result[category_id][sample_id]
+                    sample_result.append(np.argmax(sample_category_prob))
+                results.append(sample_result)
+        return results
diff --git a/paddlehub/finetune/task/reading_comprehension_task.py b/paddlehub/finetune/task/reading_comprehension_task.py
index 13e4989540a5874639e23aabc3b1a10d2c379f77..ccc590ee36906611d33b52c5a4b03cfcfba2c3b3 100644
--- a/paddlehub/finetune/task/reading_comprehension_task.py
+++ b/paddlehub/finetune/task/reading_comprehension_task.py
@@ -171,18 +171,15 @@ def get_final_text(pred_text, orig_text, do_lower_case, is_english):
     return output_text
 
 
-def write_predictions(all_examples, all_features, all_results, n_best_size,
-                      max_answer_length, do_lower_case, output_prediction_file,
-                      output_nbest_file, output_null_log_odds_file,
-                      version_2_with_negative, null_score_diff_threshold,
-                      is_english):
+def get_predictions(all_examples, all_features, all_results, n_best_size,
+                    max_answer_length, do_lower_case, version_2_with_negative,
+                    null_score_diff_threshold, is_english):
+
     _PrelimPrediction = collections.namedtuple("PrelimPrediction", [
         "feature_index", "start_index", "end_index", "start_logit", "end_logit"
     ])
-
     _NbestPrediction = collections.namedtuple(
         "NbestPrediction", ["text", "start_logit", "end_logit"])
-
     example_index_to_features = collections.defaultdict(list)
     for feature in all_features:
         example_index_to_features[feature.example_index].append(feature)
@@ -363,25 +360,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
 
                 all_predictions[example.qas_id] = best_non_null_entry.text
         all_nbest_json[example.qas_id] = nbest_json
-    """Write final predictions to the json file and log-odds of null if needed."""
-    with open(output_prediction_file, "w") as writer:
-        logger.info("Writing predictions to: %s" % (output_prediction_file))
-        writer.write(
-            json.dumps(all_predictions, indent=4, ensure_ascii=is_english) +
-            "\n")
-
-    with open(output_nbest_file, "w") as writer:
-        logger.info("Writing nbest to: %s" % (output_nbest_file))
-        writer.write(
-            json.dumps(all_nbest_json, indent=4, ensure_ascii=is_english) +
-            "\n")
-
-    if version_2_with_negative:
-        logger.info("Writing null_log_odds to: %s" % (output_nbest_file))
-        with open(output_null_log_odds_file, "w") as writer:
-            writer.write(
-                json.dumps(scores_diff_json, indent=4, ensure_ascii=is_english)
-                + "\n")
+
+    return all_predictions, all_nbest_json, scores_diff_json
 
 
 class ReadingComprehensionTask(BaseTask):
@@ -419,6 +399,8 @@ class ReadingComprehensionTask(BaseTask):
         self.null_score_diff_threshold = null_score_diff_threshold
         self.n_best_size = n_best_size
         self.max_answer_length = max_answer_length
+        self.RawResult = collections.namedtuple(
+            "RawResult", ["unique_id", "start_logits", "end_logits"])
         self.RawResult = collections.namedtuple(
             "RawResult", ["unique_id", "start_logits", "end_logits"])
 
@@ -522,24 +504,15 @@ class ReadingComprehensionTask(BaseTask):
         scores = OrderedDict()
         # If none of metrics has been implemented, loss will be used to evaluate.
         if self.is_test_phase:
-            output_prediction_file = os.path.join(self.config.checkpoint_dir,
-                                                  "predictions.json")
-            output_nbest_file = os.path.join(self.config.checkpoint_dir,
-                                             "nbest_predictions.json")
-            output_null_log_odds_file = os.path.join(self.config.checkpoint_dir,
-                                                     "null_odds.json")
             all_examples = self.data_reader.all_examples[self.phase]
             all_features = self.data_reader.all_features[self.phase]
-            write_predictions(
+            all_predictions, all_nbest_json, scores_diff_json = get_predictions(
                 all_examples=all_examples,
                 all_features=all_features,
                 all_results=all_results,
                 n_best_size=self.n_best_size,
                 max_answer_length=self.max_answer_length,
                 do_lower_case=True,
-                output_prediction_file=output_prediction_file,
-                output_nbest_file=output_nbest_file,
-                output_null_log_odds_file=output_null_log_odds_file,
                 version_2_with_negative=self.version_2_with_negative,
                 null_score_diff_threshold=self.null_score_diff_threshold,
                 is_english=self.is_english)
@@ -558,25 +531,17 @@ class ReadingComprehensionTask(BaseTask):
             else:
                 raise Exception("Error phase: %s when runing _calculate_metrics" %
                                 self.phase)
-            with open(
-                    output_prediction_file, 'r',
-                    encoding="utf8") as prediction_file:
-                predictions = json.load(prediction_file)
             if self.sub_task == "squad":
-                scores = squad1_evaluate.evaluate(dataset, predictions)
+                scores = squad1_evaluate.evaluate(dataset, all_predictions)
             elif self.sub_task == "squad2.0":
-                with open(
-                        output_null_log_odds_file, 'r',
-                        encoding="utf8") as odds_file:
-                    na_probs = json.load(odds_file)
-                scores = squad2_evaluate.evaluate(dataset, predictions,
-                                                  na_probs)
+                scores = squad2_evaluate.evaluate(dataset, all_predictions,
+                                                  scores_diff_json)
             elif self.sub_task in ["cmrc2018", "drcd"]:
-                scores = cmrc2018_evaluate.get_eval(dataset, predictions)
+                scores = cmrc2018_evaluate.get_eval(dataset, all_predictions)
 
         return scores, avg_loss, run_speed
 
-    def _default_predict_end_event(self, run_states):
+    def _postprocessing(self, run_states):
         all_results = []
         for run_state in run_states:
             np_unique_ids = run_state.run_results[0]
@@ -591,29 +556,16 @@ class ReadingComprehensionTask(BaseTask):
                     unique_id=unique_id,
                     start_logits=start_logits,
                     end_logits=end_logits))
-        # If none of metrics has been implemented, loss will be used to evaluate.
-        output_prediction_file = os.path.join(self.config.checkpoint_dir,
-                                              "predict_predictions.json")
-        output_nbest_file = os.path.join(self.config.checkpoint_dir,
-                                         "predict_nbest_predictions.json")
-        output_null_log_odds_file = os.path.join(self.config.checkpoint_dir,
-                                                 "predict_null_odds.json")
         all_examples = self.data_reader.all_examples[self.phase]
         all_features = self.data_reader.all_features[self.phase]
-        write_predictions(
+        all_predictions, all_nbest_json, scores_diff_json = get_predictions(
            all_examples=all_examples,
            all_features=all_features,
            all_results=all_results,
            n_best_size=self.n_best_size,
            max_answer_length=self.max_answer_length,
            do_lower_case=True,
-           output_prediction_file=output_prediction_file,
-           output_nbest_file=output_nbest_file,
-           output_null_log_odds_file=output_null_log_odds_file,
            version_2_with_negative=self.version_2_with_negative,
            null_score_diff_threshold=self.null_score_diff_threshold,
            is_english=self.is_english)
-
-        logger.info("PaddleHub predict finished.")
-
-        logger.info("You can see the prediction in %s" % output_prediction_file)
+        return all_predictions
diff --git a/paddlehub/finetune/task/regression_task.py b/paddlehub/finetune/task/regression_task.py
index 3f723d9d410c4d428bcf07b48c281f9ae31079de..3ddfbaf2cc27c19ca525c6220464396b1cd194b0 100644
--- a/paddlehub/finetune/task/regression_task.py
+++ b/paddlehub/finetune/task/regression_task.py
@@ -120,3 +120,10 @@ class RegressionTask(BaseTask):
             else:
                 raise ValueError("Not Support Metric: \"%s\"" % metric)
         return scores, avg_loss, run_speed
+
+    def _postprocessing(self, run_states):
+        results = []
+        for batch_state in run_states:
+            batch_result = batch_state.run_results[0]
+            results += [result[0] for result in batch_result]
+        return results
diff --git a/paddlehub/finetune/task/sequence_task.py b/paddlehub/finetune/task/sequence_task.py
index 1f4b8da911df5cb3ed644005e65fc1cbbf537a73..cff27989b9cde0eb99aba510dfcca8a2e6a881aa 100644
--- a/paddlehub/finetune/task/sequence_task.py
+++ b/paddlehub/finetune/task/sequence_task.py
@@ -216,3 +216,22 @@ class SequenceLabelTask(BaseTask):
         elif self.is_predict_phase:
             return [self.ret_infers.name] + [self.seq_len.name]
         return [output.name for output in self.outputs]
+
+    def _postprocessing(self, run_states):
+        id2label = {
+            val: key
+            for key, val in self._base_data_reader.label_map.items()
+        }
+        results = []
+        for batch_states in run_states:
+            batch_results = batch_states.run_results
+            batch_infers = batch_results[0].reshape([-1]).astype(
+                np.int32).tolist()
+            seq_lens = batch_results[1].reshape([-1]).astype(np.int32).tolist()
+            current_id = 0
+            for length in seq_lens:
+                seq_infers = batch_infers[current_id:current_id + length]
+                seq_result = list(map(id2label.get, seq_infers[1:-1]))
+                current_id += int(length)
+                results.append(seq_result)
+        return results
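
Usage note: with the changes above, callers no longer have to unpack `run_states` by hand; `predict(..., return_result=True)` routes the raw fetch results through the task's `_postprocessing` hook and returns decoded values directly. Below is a minimal sketch of the multi-label demo using the new path; it assumes `multi_label_cls_task` and `data` are built exactly as in the existing demo script, and only the call shown here is new (it is not part of the patch itself):

    # Sketch, not part of the patch: return_result=True makes predict() invoke
    # MultiLabelClassifierTask._postprocessing, which returns one list of
    # per-category argmax label ids for each input sample.
    results = multi_label_cls_task.predict(data=data, return_result=True)
    for sample, label_ids in zip(data, results):
        print("%s\tpredict=%s" % (sample[0], label_ids))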