未验证 提交 ebc7ffc3 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix detection_map. test=develop (#22705)

上级 ee8b22fb
...@@ -420,8 +420,11 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -420,8 +420,11 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
int label = it->first; int label = it->first;
int label_num_pos = it->second; int label_num_pos = it->second;
if (label_num_pos == background_label || if (label_num_pos == background_label) {
true_pos.find(label) == true_pos.end()) { continue;
}
if (true_pos.find(label) == true_pos.end()) {
count++;
continue; continue;
} }
auto label_true_pos = true_pos.find(label)->second; auto label_true_pos = true_pos.find(label)->second;
......
...@@ -181,7 +181,10 @@ class TestDetectionMAPOp(OpTest): ...@@ -181,7 +181,10 @@ class TestDetectionMAPOp(OpTest):
false_pos[label].append([score, fp]) false_pos[label].append([score, fp])
for (label, label_pos_num) in six.iteritems(label_count): for (label, label_pos_num) in six.iteritems(label_count):
if label_pos_num == 0 or label not in true_pos: continue if label_pos_num == 0: continue
if label not in true_pos:
count += 1
continue
label_true_pos = true_pos[label] label_true_pos = true_pos[label]
label_false_pos = false_pos[label] label_false_pos = false_pos[label]
...@@ -281,5 +284,30 @@ class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp): ...@@ -281,5 +284,30 @@ class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp):
self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]] self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]]
class TestDetectionMAPOp11PointWithClassNoTP(TestDetectionMAPOp):
def init_test_case(self):
self.overlap_threshold = 0.3
self.evaluate_difficult = True
self.ap_type = "11point"
self.label_lod = [[2]]
# label difficult xmin ymin xmax ymax
self.label = [[2, 0, 0.3, 0.3, 0.6, 0.5], [1, 0, 0.7, 0.1, 0.9, 0.3]]
# label score xmin ymin xmax ymax difficult
self.detect_lod = [[1]]
self.detect = [[1, 0.2, 0.8, 0.1, 1.0, 0.3]]
# label score true_pos false_pos
self.tf_pos_lod = [[3, 4]]
self.tf_pos = [[1, 0.2, 1, 0]]
self.class_pos_count = []
self.true_pos_lod = [[]]
self.true_pos = [[]]
self.false_pos_lod = [[]]
self.false_pos = [[]]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册