未验证 提交 9894fbb1 编写于 作者: H HawChang 提交者: GitHub

fix multilabel predict res label name wrong bug on v1.7 (#736)

Co-authored-by: zhanghao55 <zhanghao55@baidu.com>
上级 7faf09eb
......@@ -335,7 +335,6 @@ class MultiLabelClassifierTask(ClassifierTask):
config=config,
hidden_units=hidden_units,
metrics_choices=metrics_choices)
self.class_name = list(data_reader.label_map.keys())
def _build_net(self):
cls_feats = fluid.layers.dropout(
......@@ -415,7 +414,8 @@ class MultiLabelClassifierTask(ClassifierTask):
# NOTE: for MultiLabelClassifierTask, the metrics will be used to evaluate all the label
# and their mean value will also be reported.
for index, auc in enumerate(auc_list):
scores["auc_" + self.class_name[index]] = auc_list[index][0]
scores["auc_" + self._base_data_reader.dataset.
label_list[index]] = auc_list[index][0]
else:
raise ValueError("Not Support Metric: \"%s\"" % metric)
return scores, avg_loss, run_speed
......@@ -428,7 +428,6 @@ class MultiLabelClassifierTask(ClassifierTask):
def _postprocessing(self, run_states):
results = []
label_list = list(self._base_data_reader.label_map.keys())
for batch_state in run_states:
batch_result = batch_state.run_results
for sample_id in range(len(batch_result[0])):
......@@ -437,7 +436,9 @@ class MultiLabelClassifierTask(ClassifierTask):
self._base_data_reader.dataset.num_labels):
sample_category_prob = batch_result[category_id][sample_id]
sample_category_value = np.argmax(sample_category_prob)
sample_result.append(
{label_list[category_id]: sample_category_value})
sample_result.append({
self._base_data_reader.dataset.label_list[category_id]:
sample_category_value
})
results.append(sample_result)
return results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册