From 9894fbb1dc8575ae1fa74f32a23cc1363467461b Mon Sep 17 00:00:00 2001 From: HawChang Date: Tue, 7 Jul 2020 09:38:31 +0800 Subject: [PATCH] fix multilabel predict res label name wrong bug on v1.7 (#736) Co-authored-by: zhanghao55 --- paddlehub/finetune/task/classifier_task.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/paddlehub/finetune/task/classifier_task.py b/paddlehub/finetune/task/classifier_task.py index b137afdc..f3f700ed 100644 --- a/paddlehub/finetune/task/classifier_task.py +++ b/paddlehub/finetune/task/classifier_task.py @@ -335,7 +335,6 @@ class MultiLabelClassifierTask(ClassifierTask): config=config, hidden_units=hidden_units, metrics_choices=metrics_choices) - self.class_name = list(data_reader.label_map.keys()) def _build_net(self): cls_feats = fluid.layers.dropout( @@ -415,7 +414,8 @@ class MultiLabelClassifierTask(ClassifierTask): # NOTE: for MultiLabelClassifierTask, the metrics will be used to evaluate all the label # and their mean value will also be reported. for index, auc in enumerate(auc_list): - scores["auc_" + self.class_name[index]] = auc_list[index][0] + scores["auc_" + self._base_data_reader.dataset. + label_list[index]] = auc_list[index][0] else: raise ValueError("Not Support Metric: \"%s\"" % metric) return scores, avg_loss, run_speed @@ -428,7 +428,6 @@ class MultiLabelClassifierTask(ClassifierTask): def _postprocessing(self, run_states): results = [] - label_list = list(self._base_data_reader.label_map.keys()) for batch_state in run_states: batch_result = batch_state.run_results for sample_id in range(len(batch_result[0])): @@ -437,7 +436,9 @@ class MultiLabelClassifierTask(ClassifierTask): self._base_data_reader.dataset.num_labels): sample_category_prob = batch_result[category_id][sample_id] sample_category_value = np.argmax(sample_category_prob) - sample_result.append( - {label_list[category_id]: sample_category_value}) + sample_result.append({ + self._base_data_reader.dataset.label_list[category_id]: + sample_category_value + }) results.append(sample_result) return results -- GitLab