diff --git a/paddle/operators/detection_map_op.cc b/paddle/operators/detection_map_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b59d3bfad96c77e48bf443619978573d02c0e8d9 --- /dev/null +++ b/paddle/operators/detection_map_op.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detection_map_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class DetectionMAPOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto map_dim = framework::make_ddim({1}); + ctx->SetOutputDim("MAP", map_dim); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Label")->type()), + ctx.device_context()); + } +}; + +class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DetectionMAPOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Detect", "The detection output."); + AddInput("Label", "The label data."); + AddOutput("MAP", "The MAP evaluate result of the detection."); + + AddAttr("overlap_threshold", "The overlap threshold.") + .SetDefault(.3f); + AddAttr("evaluate_difficult", + "Switch to control whether the difficult data is evaluated.") + .SetDefault(true); + AddAttr("ap_type", + "The AP algorithm type, 'Integral' or '11point'.") + .SetDefault("Integral"); + + AddComment(R"DOC( +Detection MAP Operator. + +Detection MAP evaluator for SSD(Single Shot MultiBox Detector) algorithm. +Please get more information from the following papers: +https://arxiv.org/abs/1512.02325. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp, + ops::DetectionMAPOpMaker); +REGISTER_OP_CPU_KERNEL( + detection_map, ops::DetectionMAPOpKernel, + ops::DetectionMAPOpKernel); diff --git a/paddle/operators/detection_map_op.cu b/paddle/operators/detection_map_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ab9a992c363169588d587965e93dfb3d147678d8 --- /dev/null +++ b/paddle/operators/detection_map_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detection_map_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + detection_map, ops::DetectionMAPOpKernel, + ops::DetectionMAPOpKernel); diff --git a/paddle/operators/detection_map_op.h b/paddle/operators/detection_map_op.h new file mode 100644 index 0000000000000000000000000000000000000000..3e862abda64eb56361e59d33cf129de0bfc6118b --- /dev/null +++ b/paddle/operators/detection_map_op.h @@ -0,0 +1,316 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/detection_util.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +inline void GetAccumulation(std::vector> in_pairs, + std::vector* accu_vec) { + std::stable_sort(in_pairs.begin(), in_pairs.end(), + math::SortScorePairDescend); + accu_vec->clear(); + size_t sum = 0; + for (size_t i = 0; i < in_pairs.size(); ++i) { + // auto score = in_pairs[i].first; + auto count = in_pairs[i].second; + sum += count; + accu_vec->push_back(sum); + } +} + +template +class DetectionMAPOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_label = ctx.Input("Label"); + auto* input_detect = ctx.Input("Detect"); + auto* map_out = ctx.Output("MAP"); + + float overlap_threshold = ctx.Attr("overlap_threshold"); + float evaluate_difficult = ctx.Attr("evaluate_difficult"); + std::string ap_type = ctx.Attr("ap_type"); + + auto label_lod = input_label->lod(); + PADDLE_ENFORCE_EQ(label_lod.size(), 1UL, + "Only support one level sequence now."); + auto batch_size = label_lod[0].size() - 1; + + std::vector>>> gt_bboxes; + + std::vector< + std::map>>>> + detect_bboxes; + + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::LoDTensor input_label_cpu; + framework::Tensor input_detect_cpu; + input_label_cpu.set_lod(input_label->lod()); + input_label_cpu.Resize(input_label->dims()); + input_detect_cpu.Resize(input_detect->dims()); + input_label_cpu.mutable_data(platform::CPUPlace()); + input_detect_cpu.mutable_data(platform::CPUPlace()); + framework::CopyFrom(*input_label, platform::CPUPlace(), + ctx.device_context(), &input_label_cpu); + framework::CopyFrom(*input_detect, platform::CPUPlace(), + ctx.device_context(), &input_detect_cpu); + GetBBoxes(input_label_cpu, input_detect_cpu, gt_bboxes, detect_bboxes); + } else { + GetBBoxes(*input_label, *input_detect, gt_bboxes, detect_bboxes); + } + + std::map label_pos_count; + std::map>> true_pos; + std::map>> false_pos; + + CalcTrueAndFalsePositive(batch_size, evaluate_difficult, overlap_threshold, + gt_bboxes, detect_bboxes, label_pos_count, + true_pos, false_pos); + + T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos); + + T* map_data = nullptr; + framework::Tensor map_cpu; + map_out->mutable_data(ctx.GetPlace()); + if (platform::is_gpu_place(ctx.GetPlace())) { + map_data = map_cpu.mutable_data(map_out->dims(), platform::CPUPlace()); + map_data[0] = map; + framework::CopyFrom(map_cpu, platform::CPUPlace(), ctx.device_context(), + map_out); + } else { + map_data = map_out->mutable_data(ctx.GetPlace()); + map_data[0] = map; + } + } + + protected: + void GetBBoxes( + const framework::LoDTensor& input_label, + const framework::Tensor& input_detect, + std::vector>>>& + gt_bboxes, + std::vector< + std::map>>>>& + detect_bboxes) const { + const T* label_data = input_label.data(); + const T* detect_data = input_detect.data(); + + auto label_lod = input_label.lod(); + auto batch_size = label_lod[0].size() - 1; + auto label_index = label_lod[0]; + + for (size_t n = 0; n < batch_size; ++n) { + std::map>> bboxes; + for (int i = label_index[n]; i < label_index[n + 1]; ++i) { + std::vector> bbox; + math::GetBBoxFromLabelData(label_data + i * 6, 1, bbox); + int label = static_cast(label_data[i * 6]); + bboxes[label].push_back(bbox[0]); + } + gt_bboxes.push_back(bboxes); + } + + size_t n = 0; + size_t detect_box_count = input_detect.dims()[0]; + for (size_t img_id = 0; img_id < batch_size; ++img_id) { + std::map>>> bboxes; + size_t cur_img_id = static_cast((detect_data + n * 7)[0]); + while (cur_img_id == img_id && n < detect_box_count) { + std::vector label; + std::vector score; + std::vector> bbox; + math::GetBBoxFromDetectData(detect_data + n * 7, 1, label, score, + bbox); + bboxes[label[0]].push_back(std::make_pair(score[0], bbox[0])); + ++n; + cur_img_id = static_cast((detect_data + n * 7)[0]); + } + detect_bboxes.push_back(bboxes); + } + } + + void CalcTrueAndFalsePositive( + size_t batch_size, bool evaluate_difficult, float overlap_threshold, + const std::vector>>>& + gt_bboxes, + const std::vector< + std::map>>>>& + detect_bboxes, + std::map& label_pos_count, + std::map>>& true_pos, + std::map>>& false_pos) const { + for (size_t n = 0; n < batch_size; ++n) { + auto image_gt_bboxes = gt_bboxes[n]; + for (auto it = image_gt_bboxes.begin(); it != image_gt_bboxes.end(); + ++it) { + size_t count = 0; + auto labeled_bboxes = it->second; + if (evaluate_difficult) { + count = labeled_bboxes.size(); + } else { + for (size_t i = 0; i < labeled_bboxes.size(); ++i) + if (!(labeled_bboxes[i].is_difficult)) ++count; + } + if (count == 0) { + continue; + } + int label = it->first; + if (label_pos_count.find(label) == label_pos_count.end()) { + label_pos_count[label] = count; + } else { + label_pos_count[label] += count; + } + } + } + + for (size_t n = 0; n < detect_bboxes.size(); ++n) { + auto image_gt_bboxes = gt_bboxes[n]; + auto detections = detect_bboxes[n]; + + if (image_gt_bboxes.size() == 0) { + for (auto it = detections.begin(); it != detections.end(); ++it) { + auto pred_bboxes = it->second; + int label = it->first; + for (size_t i = 0; i < pred_bboxes.size(); ++i) { + auto score = pred_bboxes[i].first; + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + continue; + } + + for (auto it = detections.begin(); it != detections.end(); ++it) { + int label = it->first; + auto pred_bboxes = it->second; + if (image_gt_bboxes.find(label) == image_gt_bboxes.end()) { + for (size_t i = 0; i < pred_bboxes.size(); ++i) { + auto score = pred_bboxes[i].first; + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + continue; + } + + auto matched_bboxes = image_gt_bboxes.find(label)->second; + std::vector visited(matched_bboxes.size(), false); + // Sort detections in descend order based on scores + std::sort(pred_bboxes.begin(), pred_bboxes.end(), + math::SortScorePairDescend>); + for (size_t i = 0; i < pred_bboxes.size(); ++i) { + float max_overlap = -1.0; + size_t max_idx = 0; + auto score = pred_bboxes[i].first; + for (size_t j = 0; j < matched_bboxes.size(); ++j) { + float overlap = + JaccardOverlap(pred_bboxes[i].second, matched_bboxes[j]); + if (overlap > max_overlap) { + max_overlap = overlap; + max_idx = j; + } + } + if (max_overlap > overlap_threshold) { + bool match_evaluate_difficult = + evaluate_difficult || + (!evaluate_difficult && !matched_bboxes[max_idx].is_difficult); + if (match_evaluate_difficult) { + if (!visited[max_idx]) { + true_pos[label].push_back(std::make_pair(score, 1)); + false_pos[label].push_back(std::make_pair(score, 0)); + visited[max_idx] = true; + } else { + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + } else { + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + } + } + } + + T CalcMAP( + std::string ap_type, const std::map& label_pos_count, + const std::map>>& true_pos, + const std::map>>& false_pos) const { + T mAP = 0.0; + int count = 0; + for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { + int label = it->first; + int label_num_pos = it->second; + if (label_num_pos == 0 || true_pos.find(label) == true_pos.end()) + continue; + auto label_true_pos = true_pos.find(label)->second; + auto label_false_pos = false_pos.find(label)->second; + // Compute average precision. + std::vector tp_sum; + GetAccumulation(label_true_pos, &tp_sum); + std::vector fp_sum; + GetAccumulation(label_false_pos, &fp_sum); + std::vector precision, recall; + size_t num = tp_sum.size(); + // Compute Precision. + for (size_t i = 0; i < num; ++i) { + // CHECK_LE(tpCumSum[i], labelNumPos); + precision.push_back(static_cast(tp_sum[i]) / + static_cast(tp_sum[i] + fp_sum[i])); + recall.push_back(static_cast(tp_sum[i]) / label_num_pos); + } + // VOC2007 style + if (ap_type == "11point") { + std::vector max_precisions(11, 0.0); + int start_idx = num - 1; + for (int j = 10; j >= 0; --j) + for (int i = start_idx; i >= 0; --i) { + if (recall[i] < j / 10.) { + start_idx = i; + if (j > 0) max_precisions[j - 1] = max_precisions[j]; + break; + } else { + if (max_precisions[j] < precision[i]) + max_precisions[j] = precision[i]; + } + } + for (int j = 10; j >= 0; --j) mAP += max_precisions[j] / 11; + ++count; + } else if (ap_type == "Integral") { + // Nature integral + float average_precisions = 0.; + float prev_recall = 0.; + for (size_t i = 0; i < num; ++i) { + if (fabs(recall[i] - prev_recall) > 1e-6) + average_precisions += precision[i] * fabs(recall[i] - prev_recall); + prev_recall = recall[i]; + } + mAP += average_precisions; + ++count; + } else { + LOG(FATAL) << "Unkown ap version: " << ap_type; + } + } + if (count != 0) mAP /= count; + return mAP * 100; + } +}; // namespace operators + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detection_util.cc b/paddle/operators/math/detection_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..4131a0cb0ef086301a9938382ae5d837508ea07d --- /dev/null +++ b/paddle/operators/math/detection_util.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/detection_util.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math {} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detection_util.cu b/paddle/operators/math/detection_util.cu new file mode 100644 index 0000000000000000000000000000000000000000..d2bb992396197cd76c31cee11c804257adc51357 --- /dev/null +++ b/paddle/operators/math/detection_util.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/detection_util.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math {} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detection_util.h b/paddle/operators/math/detection_util.h new file mode 100644 index 0000000000000000000000000000000000000000..2a4dadc545eb2b67e5819e5fe65fc64fcb35e69c --- /dev/null +++ b/paddle/operators/math/detection_util.h @@ -0,0 +1,128 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#include "paddle/framework/selected_rows.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct BBox { + BBox(T x_min, T y_min, T x_max, T y_max) + : x_min(x_min), + y_min(y_min), + x_max(x_max), + y_max(y_max), + is_difficult(false) {} + + BBox() {} + + T get_width() const { return x_max - x_min; } + + T get_height() const { return y_max - y_min; } + + T get_center_x() const { return (x_min + x_max) / 2; } + + T get_center_y() const { return (y_min + y_max) / 2; } + + T get_area() const { return get_width() * get_height(); } + + // coordinate of bounding box + T x_min; + T y_min; + T x_max; + T y_max; + // whether difficult object (e.g. object with heavy occlusion is difficult) + bool is_difficult; +}; + +template +void GetBBoxFromDetectData(const T* detect_data, const size_t num_bboxes, + std::vector& labels, std::vector& scores, + std::vector>& bboxes) { + size_t out_offset = bboxes.size(); + labels.resize(out_offset + num_bboxes); + scores.resize(out_offset + num_bboxes); + bboxes.resize(out_offset + num_bboxes); + for (size_t i = 0; i < num_bboxes; ++i) { + labels[out_offset + i] = *(detect_data + i * 7 + 1); + scores[out_offset + i] = *(detect_data + i * 7 + 2); + BBox bbox; + bbox.x_min = *(detect_data + i * 7 + 3); + bbox.y_min = *(detect_data + i * 7 + 4); + bbox.x_max = *(detect_data + i * 7 + 5); + bbox.y_max = *(detect_data + i * 7 + 6); + bboxes[out_offset + i] = bbox; + }; +} + +template +void GetBBoxFromLabelData(const T* label_data, const size_t num_bboxes, + std::vector>& bboxes) { + size_t out_offset = bboxes.size(); + bboxes.resize(bboxes.size() + num_bboxes); + for (size_t i = 0; i < num_bboxes; ++i) { + BBox bbox; + bbox.x_min = *(label_data + i * 6 + 1); + bbox.y_min = *(label_data + i * 6 + 2); + bbox.x_max = *(label_data + i * 6 + 3); + bbox.y_max = *(label_data + i * 6 + 4); + T is_difficult = *(label_data + i * 6 + 5); + if (std::abs(is_difficult - 0.0) < 1e-6) + bbox.is_difficult = false; + else + bbox.is_difficult = true; + bboxes[out_offset + i] = bbox; + } +} + +template +inline float JaccardOverlap(const BBox& bbox1, const BBox& bbox2) { + if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min || + bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) { + return 0.0; + } else { + float inter_x_min = std::max(bbox1.x_min, bbox2.x_min); + float inter_y_min = std::max(bbox1.y_min, bbox2.y_min); + float inter_x_max = std::min(bbox1.x_max, bbox2.x_max); + float inter_y_max = std::min(bbox1.y_max, bbox2.y_max); + + float inter_width = inter_x_max - inter_x_min; + float inter_height = inter_y_max - inter_y_min; + float inter_area = inter_width * inter_height; + + float bbox_area1 = bbox1.get_area(); + float bbox_area2 = bbox2.get_area(); + + return inter_area / (bbox_area1 + bbox_area2 - inter_area); + } +} + +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +// template <> +// bool SortScorePairDescend(const std::pair& pair1, +// const std::pair& pair2) { +// return pair1.first > pair2.first; +// } + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_detection_map_op.py b/python/paddle/v2/fluid/tests/test_detection_map_op.py new file mode 100644 index 0000000000000000000000000000000000000000..50ce3afbb95064a59b63d9f1606e36d3001a7e08 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_detection_map_op.py @@ -0,0 +1,155 @@ +import unittest +import numpy as np +import sys +import collections +import math +from op_test import OpTest + + +class TestDetectionMAPOp(OpTest): + def set_data(self): + self.init_test_case() + + self.mAP = [self.calc_map(self.tf_pos)] + self.label = np.array(self.label).astype('float32') + self.detect = np.array(self.detect).astype('float32') + self.mAP = np.array(self.mAP).astype('float32') + + self.inputs = { + 'Label': (self.label, self.label_lod), + 'Detect': self.detect + } + + self.attrs = { + 'overlap_threshold': self.overlap_threshold, + 'evaluate_difficult': self.evaluate_difficult, + 'ap_type': self.ap_type + } + + self.outputs = {'MAP': self.mAP} + + def init_test_case(self): + self.overlap_threshold = 0.3 + self.evaluate_difficult = True + self.ap_type = "Integral" + + self.label_lod = [[0, 2, 4]] + # label xmin ymin xmax ymax difficult + self.label = [[1, 0.1, 0.1, 0.3, 0.3, 0], [1, 0.6, 0.6, 0.8, 0.8, 1], + [2, 0.3, 0.3, 0.6, 0.5, 0], [1, 0.7, 0.1, 0.9, 0.3, 0]] + + # image_id label score xmin ymin xmax ymax difficult + self.detect = [ + [0, 1, 0.3, 0.1, 0.0, 0.4, 0.3], [0, 1, 0.7, 0.0, 0.1, 0.2, 0.3], + [0, 1, 0.9, 0.7, 0.6, 0.8, 0.8], [1, 2, 0.8, 0.2, 0.1, 0.4, 0.4], + [1, 2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 1, 0.2, 0.8, 0.1, 1.0, 0.3], + [1, 3, 0.2, 0.8, 0.1, 1.0, 0.3] + ] + + # image_id label score false_pos false_pos + # [-1, 1, 3, -1, -1], + # [-1, 2, 1, -1, -1] + self.tf_pos = [[0, 1, 0.9, 1, 0], [0, 1, 0.7, 1, 0], [0, 1, 0.3, 0, 1], + [1, 1, 0.2, 1, 0], [1, 2, 0.8, 0, 1], [1, 2, 0.1, 1, 0], + [1, 3, 0.2, 0, 1]] + + def calc_map(self, tf_pos): + mAP = 0.0 + count = 0 + + class_pos_count = {} + true_pos = {} + false_pos = {} + + def get_accumulation(pos_list): + sorted_list = sorted(pos_list, key=lambda pos: pos[0], reverse=True) + sum = 0 + accu_list = [] + for (score, count) in sorted_list: + sum += count + accu_list.append(sum) + return accu_list + + label_count = collections.Counter() + for (label, xmin, ymin, xmax, ymax, difficult) in self.label: + if self.evaluate_difficult: + label_count[label] += 1 + elif not difficult: + label_count[label] += 1 + + true_pos = collections.defaultdict(list) + false_pos = collections.defaultdict(list) + for (image_id, label, score, tp, fp) in tf_pos: + true_pos[label].append([score, tp]) + false_pos[label].append([score, fp]) + + for (label, label_pos_num) in label_count.items(): + if label_pos_num == 0 or label not in true_pos: + continue + + label_true_pos = true_pos[label] + label_false_pos = false_pos[label] + + accu_tp_sum = get_accumulation(label_true_pos) + accu_fp_sum = get_accumulation(label_false_pos) + + precision = [] + recall = [] + + for i in range(len(accu_tp_sum)): + precision.append( + float(accu_tp_sum[i]) / + float(accu_tp_sum[i] + accu_fp_sum[i])) + recall.append(float(accu_tp_sum[i]) / label_pos_num) + + if self.ap_type == "11point": + max_precisions = [11.0, 0.0] + start_idx = len(accu_tp_sum) - 1 + for j in range(10, 0, -1): + for i in range(start_idx, 0, -1): + if recall[i] < j / 10.0: + start_idx = i + if j > 0: + max_precisions[j - 1] = max_precisions[j] + break + else: + if max_precisions[j] < accu_precision[i]: + max_precisions[j] = accu_precision[i] + for j in range(10, 0, -1): + mAP += max_precisions[j] / 11 + count += 1 + elif self.ap_type == "Integral": + average_precisions = 0.0 + prev_recall = 0.0 + for i in range(len(accu_tp_sum)): + if math.fabs(recall[i] - prev_recall) > 1e-6: + average_precisions += precision[i] * \ + math.fabs(recall[i] - prev_recall) + prev_recall = recall[i] + + mAP += average_precisions + count += 1 + + if count != 0: mAP /= count + return mAP * 100.0 + + def setUp(self): + self.op_type = "detection_map" + self.set_data() + + def test_check_output(self): + self.check_output() + + +class TestDetectionMAPOpSkipDiff(TestDetectionMAPOp): + def init_test_case(self): + super(TestDetectionMAPOpSkipDiff, self).init_test_case() + + self.evaluate_difficult = False + + self.tf_pos = [[0, 1, 0.7, 1, 0], [0, 1, 0.3, 0, 1], [1, 1, 0.2, 1, 0], + [1, 2, 0.8, 0, 1], [1, 2, 0.1, 1, 0], [1, 3, 0.2, 0, 1]] + + +if __name__ == '__main__': + unittest.main()