From ebc7ffc300a4095c568502116f9808a9ff1ebc9b Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Fri, 28 Feb 2020 11:15:04 +0800 Subject: [PATCH] fix detection_map. test=develop (#22705) --- paddle/fluid/operators/detection_map_op.h | 7 +++-- .../tests/unittests/test_detection_map_op.py | 30 ++++++++++++++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h index dd5d138a1e9..10becb080ff 100644 --- a/paddle/fluid/operators/detection_map_op.h +++ b/paddle/fluid/operators/detection_map_op.h @@ -420,8 +420,11 @@ class DetectionMAPOpKernel : public framework::OpKernel { for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { int label = it->first; int label_num_pos = it->second; - if (label_num_pos == background_label || - true_pos.find(label) == true_pos.end()) { + if (label_num_pos == background_label) { + continue; + } + if (true_pos.find(label) == true_pos.end()) { + count++; continue; } auto label_true_pos = true_pos.find(label)->second; diff --git a/python/paddle/fluid/tests/unittests/test_detection_map_op.py b/python/paddle/fluid/tests/unittests/test_detection_map_op.py index 0c5343a97d5..93ab4a73906 100644 --- a/python/paddle/fluid/tests/unittests/test_detection_map_op.py +++ b/python/paddle/fluid/tests/unittests/test_detection_map_op.py @@ -181,7 +181,10 @@ class TestDetectionMAPOp(OpTest): false_pos[label].append([score, fp]) for (label, label_pos_num) in six.iteritems(label_count): - if label_pos_num == 0 or label not in true_pos: continue + if label_pos_num == 0: continue + if label not in true_pos: + count += 1 + continue label_true_pos = true_pos[label] label_false_pos = false_pos[label] @@ -281,5 +284,30 @@ class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp): self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]] +class TestDetectionMAPOp11PointWithClassNoTP(TestDetectionMAPOp): + def init_test_case(self): + self.overlap_threshold = 0.3 + self.evaluate_difficult = True + self.ap_type = "11point" + + self.label_lod = [[2]] + # label difficult xmin ymin xmax ymax + self.label = [[2, 0, 0.3, 0.3, 0.6, 0.5], [1, 0, 0.7, 0.1, 0.9, 0.3]] + + # label score xmin ymin xmax ymax difficult + self.detect_lod = [[1]] + self.detect = [[1, 0.2, 0.8, 0.1, 1.0, 0.3]] + + # label score true_pos false_pos + self.tf_pos_lod = [[3, 4]] + self.tf_pos = [[1, 0.2, 1, 0]] + + self.class_pos_count = [] + self.true_pos_lod = [[]] + self.true_pos = [[]] + self.false_pos_lod = [[]] + self.false_pos = [[]] + + if __name__ == '__main__': unittest.main() -- GitLab