未验证 提交 9894fbb1 编写于 作者: H HawChang 提交者: GitHub

fix multilabel predict res label name wrong bug on v1.7 (#736)

Co-authored-by: Nzhanghao55 <zhanghao55@baidu.com>
上级 7faf09eb
...@@ -335,7 +335,6 @@ class MultiLabelClassifierTask(ClassifierTask): ...@@ -335,7 +335,6 @@ class MultiLabelClassifierTask(ClassifierTask):
config=config, config=config,
hidden_units=hidden_units, hidden_units=hidden_units,
metrics_choices=metrics_choices) metrics_choices=metrics_choices)
self.class_name = list(data_reader.label_map.keys())
def _build_net(self): def _build_net(self):
cls_feats = fluid.layers.dropout( cls_feats = fluid.layers.dropout(
...@@ -415,7 +414,8 @@ class MultiLabelClassifierTask(ClassifierTask): ...@@ -415,7 +414,8 @@ class MultiLabelClassifierTask(ClassifierTask):
# NOTE: for MultiLabelClassifierTask, the metrics will be used to evaluate all the label # NOTE: for MultiLabelClassifierTask, the metrics will be used to evaluate all the label
# and their mean value will also be reported. # and their mean value will also be reported.
for index, auc in enumerate(auc_list): for index, auc in enumerate(auc_list):
scores["auc_" + self.class_name[index]] = auc_list[index][0] scores["auc_" + self._base_data_reader.dataset.
label_list[index]] = auc_list[index][0]
else: else:
raise ValueError("Not Support Metric: \"%s\"" % metric) raise ValueError("Not Support Metric: \"%s\"" % metric)
return scores, avg_loss, run_speed return scores, avg_loss, run_speed
...@@ -428,7 +428,6 @@ class MultiLabelClassifierTask(ClassifierTask): ...@@ -428,7 +428,6 @@ class MultiLabelClassifierTask(ClassifierTask):
def _postprocessing(self, run_states): def _postprocessing(self, run_states):
results = [] results = []
label_list = list(self._base_data_reader.label_map.keys())
for batch_state in run_states: for batch_state in run_states:
batch_result = batch_state.run_results batch_result = batch_state.run_results
for sample_id in range(len(batch_result[0])): for sample_id in range(len(batch_result[0])):
...@@ -437,7 +436,9 @@ class MultiLabelClassifierTask(ClassifierTask): ...@@ -437,7 +436,9 @@ class MultiLabelClassifierTask(ClassifierTask):
self._base_data_reader.dataset.num_labels): self._base_data_reader.dataset.num_labels):
sample_category_prob = batch_result[category_id][sample_id] sample_category_prob = batch_result[category_id][sample_id]
sample_category_value = np.argmax(sample_category_prob) sample_category_value = np.argmax(sample_category_prob)
sample_result.append( sample_result.append({
{label_list[category_id]: sample_category_value}) self._base_data_reader.dataset.label_list[category_id]:
sample_category_value
})
results.append(sample_result) results.append(sample_result)
return results return results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册