diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..48308a11b4b313ec19b578110b9e369f4bfc52bf
--- /dev/null
+++ b/paddle/fluid/operators/detection_map_op.cc
@@ -0,0 +1,194 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/detection_map_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+// Operator definition for detection_map: checks inputs/outputs and selects
+// the kernel data type. The computation itself lives in DetectionMAPOpKernel.
+class DetectionMAPOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  // Verifies that the required inputs/outputs are present, that DetectRes and
+  // Label are rank-2 with 6 columns, and that TruePos/FalsePos accompany
+  // PosCount. Only the shape of MAP is set here; the shapes of the Accum*
+  // outputs are set by the kernel.
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("DetectRes"),
+                   "Input(DetectRes) of DetectionMAPOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"),
+                   "Input(Label) of DetectionMAPOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("AccumPosCount"),
+        "Output(AccumPosCount) of DetectionMAPOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("AccumTruePos"),
+        "Output(AccumTruePos) of DetectionMAPOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("AccumFalsePos"),
+        "Output(AccumFalsePos) of DetectionMAPOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("MAP"),
+                   "Output(MAP) of DetectionMAPOp should not be null.");
+
+    auto det_dims = ctx->GetInputDim("DetectRes");
+    PADDLE_ENFORCE_EQ(det_dims.size(), 2UL,
+                      "The rank of Input(DetectRes) must be 2, "
+                      "the shape is [N, 6].");
+    PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
+                      "The shape of Input(DetectRes) is [N, 6].");
+    auto label_dims = ctx->GetInputDim("Label");
+    PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
+                      "The rank of Input(Label) must be 2, "
+                      "the shape is [N, 6].");
+    PADDLE_ENFORCE_EQ(label_dims[1], 6UL,
+                      "The shape of Input(Label) is [N, 6].");
+
+    // TruePos and FalsePos carry the state that belongs to PosCount, so all
+    // three must be supplied together.
+    if (ctx->HasInput("PosCount")) {
+      PADDLE_ENFORCE(ctx->HasInput("TruePos"),
+                     "Input(TruePos) of DetectionMAPOp should not be null when "
+                     "Input(PosCount) is not null.");
+      PADDLE_ENFORCE(
+          ctx->HasInput("FalsePos"),
+          "Input(FalsePos) of DetectionMAPOp should not be null when "
+          "Input(PosCount) is not null.");
+    }
+
+    ctx->SetOutputDim("MAP", framework::make_ddim({1}));
+  }
+
+ protected:
+  // The kernel data type follows the element type of DetectRes.
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(
+            ctx.Input<framework::Tensor>("DetectRes")->type()),
+        ctx.device_context());
+  }
+};
+
+// Declares the inputs, outputs and attributes of the detection_map operator.
+class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("DetectRes",
+             "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
+             "detections. Each row has 6 values: "
+             "[label, confidence, xmin, ymin, xmax, ymax], M is the total "
+             "number of detect results in this mini-batch. For each instance, "
+             "the offsets in first dimension are called LoD, the number of "
+             "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
+             "no detected data.");
+    AddInput("Label",
+             "(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the "
+             "Labeled ground-truth data. Each row has 6 values: "
+             "[label, is_difficult, xmin, ymin, xmax, ymax], N is the total "
+             "number of ground-truth data in this mini-batch. For each "
+             "instance, the offsets in first dimension are called LoD, "
+             "the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
+             "means there is no ground-truth data.");
+    AddInput("PosCount",
+             "(Tensor) A tensor with shape [Ncls, 1], store the "
+             "input positive example count of each class, Ncls is the count of "
+             "input classification. "
+             "This input is used to pass the AccumPosCount generated by the "
+             "previous mini-batch when the multi mini-batches cumulative "
+             "calculation carried out. "
+             "When the input(PosCount) is empty, the cumulative "
+             "calculation is not carried out, and only the results of the "
+             "current mini-batch are calculated.")
+        .AsDispensable();
+    AddInput("TruePos",
+             "(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the "
+             "input true positive example of each class. "
+             "This input is used to pass the AccumTruePos generated by the "
+             "previous mini-batch when the multi mini-batches cumulative "
+             "calculation carried out. ")
+        .AsDispensable();
+    AddInput("FalsePos",
+             "(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the "
+             "input false positive example of each class. "
+             "This input is used to pass the AccumFalsePos generated by the "
+             "previous mini-batch when the multi mini-batches cumulative "
+             "calculation carried out. ")
+        .AsDispensable();
+    AddOutput("AccumPosCount",
+              "(Tensor) A tensor with shape [Ncls, 1], store the "
+              "positive example count of each class. It combines the "
+              "input(PosCount) and the positive example count computed from "
+              "input(DetectRes) and input(Label).");
+    AddOutput("AccumTruePos",
+              "(LoDTensor) A LoDTensor with shape [Ntp', 2], store the "
+              "true positive example of each class. It combines the "
+              "input(TruePos) and the true positive examples computed from "
+              "input(DetectRes) and input(Label).");
+    AddOutput("AccumFalsePos",
+              "(LoDTensor) A LoDTensor with shape [Nfp', 2], store the "
+              "false positive example of each class. It combines the "
+              "input(FalsePos) and the false positive examples computed from "
+              "input(DetectRes) and input(Label).");
+    AddOutput("MAP",
+              "(Tensor) A tensor with shape [1], store the mAP evaluate "
+              "result of the detection.");
+
+    AddAttr<float>(
+        "overlap_threshold",
+        "(float) "
+        "The lower bound jaccard overlap threshold of detection output and "
+        "ground-truth data.")
+        .SetDefault(.3f);
+    AddAttr<bool>("evaluate_difficult",
+                  "(bool, default true) "
+                  "Switch to control whether the difficult data is evaluated.")
+        .SetDefault(true);
+    AddAttr<std::string>("ap_type",
+                         "(string, default 'integral') "
+                         "The AP algorithm type, 'integral' or '11point'.")
+        .SetDefault("integral")
+        .InEnum({"integral", "11point"})
+        .AddCustomChecker([](const std::string& ap_type) {
+          PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone,
+                            "The ap_type should be 'integral' or '11point'.");
+        });
+    AddComment(R"DOC(
+Detection mAP evaluate operator.
+The general steps are as follows. First, calculate the true positive and
+ false positive according to the input of detection and labels, then
+ calculate the mAP evaluate value.
+ Supporting '11 point' and 'integral' mAP algorithm. Please get more information
+ from the following articles:
+ https://sanchom.wordpress.com/tag/average-precision/
+ https://arxiv.org/abs/1512.02325
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp,
+                             ops::DetectionMAPOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    detection_map, ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, float>,
+    ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0f5f588e9c448a6d84d388848aa5701f2b4882dd
--- /dev/null
+++ b/paddle/fluid/operators/detection_map_op.h
@@ -0,0 +1,501 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+// Average-precision algorithm selected by the "ap_type" attribute.
+enum APType { kNone = 0, kIntegral, k11point };
+
+// Maps the "ap_type" attribute string to an APType; any other string yields
+// kNone. Marked inline because the function is defined in a header.
+inline APType GetAPType(const std::string& str) {
+  if (str == "integral") {
+    return APType::kIntegral;
+  } else if (str == "11point") {
+    return APType::k11point;
+  } else {
+    return APType::kNone;
+  }
+}
+
+// Comparator that orders (score, payload) pairs by descending score.
+template <typename T>
+inline bool SortScorePairDescend(const std::pair<float, T>& pair1,
+                                 const std::pair<float, T>& pair2) {
+  return pair1.first > pair2.first;
+}
+
+// Stable-sorts in_pairs by descending score and stores the running sum of the
+// flags (pair.second) in accu_vec, one entry per pair.
+template <typename T>
+inline void GetAccumulation(std::vector<std::pair<T, int>> in_pairs,
+                            std::vector<int>* accu_vec) {
+  std::stable_sort(in_pairs.begin(), in_pairs.end(), SortScorePairDescend<int>);
+  accu_vec->clear();
+  size_t sum = 0;
+  for (size_t i = 0; i < in_pairs.size(); ++i) {
+    auto count = in_pairs[i].second;
+    sum += count;
+    accu_vec->push_back(sum);
+  }
+}
+
+// CPU kernel of the detection_map operator: matches detections against
+// ground-truth boxes, accumulates per-class true/false positives and computes
+// the mean average precision.
+template <typename Place, typename T>
+class DetectionMAPOpKernel : public framework::OpKernel<T> {
+ public:
+  // Computes MAP for the current mini-batch, merged with the state passed in
+  // through PosCount/TruePos/FalsePos when PosCount is provided, and writes
+  // the merged state to the Accum* outputs.
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in_detect = ctx.Input<framework::LoDTensor>("DetectRes");
+    auto* in_label = ctx.Input<framework::LoDTensor>("Label");
+    auto* out_map = ctx.Output<framework::Tensor>("MAP");
+
+    auto* in_pos_count = ctx.Input<framework::Tensor>("PosCount");
+    auto* in_true_pos = ctx.Input<framework::LoDTensor>("TruePos");
+    auto* in_false_pos = ctx.Input<framework::LoDTensor>("FalsePos");
+
+    auto* out_pos_count = ctx.Output<framework::Tensor>("AccumPosCount");
+    auto* out_true_pos = ctx.Output<framework::LoDTensor>("AccumTruePos");
+    auto* out_false_pos = ctx.Output<framework::LoDTensor>("AccumFalsePos");
+
+    float overlap_threshold = ctx.Attr<float>("overlap_threshold");
+    bool evaluate_difficult = ctx.Attr<bool>("evaluate_difficult");
+    auto ap_type = GetAPType(ctx.Attr<std::string>("ap_type"));
+
+    auto label_lod = in_label->lod();
+    auto detect_lod = in_detect->lod();
+    PADDLE_ENFORCE_EQ(label_lod.size(), 1UL,
+                      "Only support one level sequence now.");
+    // detect_lod[0] is read below, so DetectRes needs a LoD level as well.
+    PADDLE_ENFORCE_GE(detect_lod.size(), 1UL,
+                      "Input(DetectRes) must carry LoD information.");
+    PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(),
+                      "The batch_size of input(Label) and input(DetectRes) "
+                      "must be the same.");
+
+    std::vector<std::map<int, std::vector<Box>>> gt_boxes;
+    std::vector<std::map<int, std::vector<std::pair<T, Box>>>> detect_boxes;
+
+    GetBoxes(*in_label, *in_detect, gt_boxes, detect_boxes);
+
+    std::map<int, int> label_pos_count;
+    std::map<int, std::vector<std::pair<T, int>>> true_pos;
+    std::map<int, std::vector<std::pair<T, int>>> false_pos;
+
+    if (in_pos_count != nullptr) {
+      GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
+                  true_pos, false_pos);
+    }
+
+    CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult,
+                             overlap_threshold, label_pos_count, true_pos,
+                             false_pos);
+
+    T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos);
+
+    GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count,
+                 *out_true_pos, *out_false_pos);
+
+    T* map_data = out_map->mutable_data<T>(ctx.GetPlace());
+    map_data[0] = map;
+  }
+
+ protected:
+  // Axis-aligned box with the ground-truth "difficult" flag.
+  struct Box {
+    Box(T xmin, T ymin, T xmax, T ymax)
+        : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax), is_difficult(false) {}
+
+    T xmin, ymin, xmax, ymax;
+    bool is_difficult;
+  };
+
+  // Returns intersection-over-union of two boxes, or 0 if they are disjoint.
+  inline T JaccardOverlap(const Box& box1, const Box& box2) const {
+    if (box2.xmin > box1.xmax || box2.xmax < box1.xmin ||
+        box2.ymin > box1.ymax || box2.ymax < box1.ymin) {
+      return 0.0;
+    } else {
+      T inter_xmin = std::max(box1.xmin, box2.xmin);
+      T inter_ymin = std::max(box1.ymin, box2.ymin);
+      T inter_xmax = std::min(box1.xmax, box2.xmax);
+      T inter_ymax = std::min(box1.ymax, box2.ymax);
+
+      T inter_width = inter_xmax - inter_xmin;
+      T inter_height = inter_ymax - inter_ymin;
+      T inter_area = inter_width * inter_height;
+
+      T bbox_area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
+      T bbox_area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
+
+      return inter_area / (bbox_area1 + bbox_area2 - inter_area);
+    }
+  }
+
+  // Splits the Label and DetectRes rows into per-image, per-class boxes using
+  // the first LoD level of each tensor.
+  void GetBoxes(const framework::LoDTensor& input_label,
+                const framework::LoDTensor& input_detect,
+                std::vector<std::map<int, std::vector<Box>>>& gt_boxes,
+                std::vector<std::map<int, std::vector<std::pair<T, Box>>>>&
+                    detect_boxes) const {
+    auto labels = framework::EigenTensor<T, 2>::From(input_label);
+    auto detect = framework::EigenTensor<T, 2>::From(input_detect);
+
+    auto label_lod = input_label.lod();
+    auto detect_lod = input_detect.lod();
+
+    int batch_size = label_lod[0].size() - 1;
+    auto label_index = label_lod[0];
+
+    for (int n = 0; n < batch_size; ++n) {
+      std::map<int, std::vector<Box>> boxes;
+      for (int i = label_index[n]; i < label_index[n + 1]; ++i) {
+        Box box(labels(i, 2), labels(i, 3), labels(i, 4), labels(i, 5));
+        int label = labels(i, 0);
+        auto is_difficult = labels(i, 1);
+        if (std::abs(is_difficult - 0.0) < 1e-6)
+          box.is_difficult = false;
+        else
+          box.is_difficult = true;
+        boxes[label].push_back(box);
+      }
+      gt_boxes.push_back(boxes);
+    }
+
+    auto detect_index = detect_lod[0];
+    for (int n = 0; n < batch_size; ++n) {
+      std::map<int, std::vector<std::pair<T, Box>>> boxes;
+      for (int i = detect_index[n]; i < detect_index[n + 1]; ++i) {
+        Box box(detect(i, 2), detect(i, 3), detect(i, 4), detect(i, 5));
+        int label = detect(i, 0);
+        auto score = detect(i, 1);
+        boxes[label].push_back(std::make_pair(score, box));
+      }
+      detect_boxes.push_back(boxes);
+    }
+  }
+
+  // Serializes the accumulated state into the Accum* outputs: a
+  // [max_class_id + 1, 1] count tensor and two [N, 2] (score, flag) tensors
+  // whose LoD marks the row range of each class id.
+  void GetOutputPos(
+      const framework::ExecutionContext& ctx,
+      const std::map<int, int>& label_pos_count,
+      const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
+      const std::map<int, std::vector<std::pair<T, int>>>& false_pos,
+      framework::Tensor& output_pos_count,
+      framework::LoDTensor& output_true_pos,
+      framework::LoDTensor& output_false_pos) const {
+    int max_class_id = 0;
+    for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
+      if (it->first > max_class_id) max_class_id = it->first;
+    }
+
+    // Size the buffers with the same rule the fill loop below uses (every
+    // class id in [0, max_class_id] that has entries); counting fewer rows
+    // than are written would overflow the output buffers.
+    int true_pos_count = 0;
+    int false_pos_count = 0;
+    for (int i = 0; i <= max_class_id; ++i) {
+      auto it_true_pos = true_pos.find(i);
+      if (it_true_pos != true_pos.end()) {
+        true_pos_count += it_true_pos->second.size();
+      }
+      auto it_false_pos = false_pos.find(i);
+      if (it_false_pos != false_pos.end()) {
+        false_pos_count += it_false_pos->second.size();
+      }
+    }
+
+    int* pos_count_data = output_pos_count.mutable_data<int>(
+        framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace());
+    T* true_pos_data = output_true_pos.mutable_data<T>(
+        framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
+    T* false_pos_data = output_false_pos.mutable_data<T>(
+        framework::make_ddim({false_pos_count, 2}), ctx.GetPlace());
+    true_pos_count = 0;
+    false_pos_count = 0;
+    std::vector<size_t> true_pos_starts = {0};
+    std::vector<size_t> false_pos_starts = {0};
+    for (int i = 0; i <= max_class_id; ++i) {
+      auto it_count = label_pos_count.find(i);
+      pos_count_data[i] = 0;
+      if (it_count != label_pos_count.end()) {
+        pos_count_data[i] = it_count->second;
+      }
+      auto it_true_pos = true_pos.find(i);
+      if (it_true_pos != true_pos.end()) {
+        const std::vector<std::pair<T, int>>& true_pos_vec =
+            it_true_pos->second;
+        for (const std::pair<T, int>& tp : true_pos_vec) {
+          true_pos_data[true_pos_count * 2] = tp.first;
+          true_pos_data[true_pos_count * 2 + 1] = static_cast<T>(tp.second);
+          true_pos_count++;
+        }
+      }
+      true_pos_starts.push_back(true_pos_count);
+
+      auto it_false_pos = false_pos.find(i);
+      if (it_false_pos != false_pos.end()) {
+        const std::vector<std::pair<T, int>>& false_pos_vec =
+            it_false_pos->second;
+        for (const std::pair<T, int>& fp : false_pos_vec) {
+          false_pos_data[false_pos_count * 2] = fp.first;
+          false_pos_data[false_pos_count * 2 + 1] = static_cast<T>(fp.second);
+          false_pos_count++;
+        }
+      }
+      false_pos_starts.push_back(false_pos_count);
+    }
+
+    framework::LoD true_pos_lod;
+    true_pos_lod.emplace_back(true_pos_starts);
+    framework::LoD false_pos_lod;
+    false_pos_lod.emplace_back(false_pos_starts);
+
+    output_true_pos.set_lod(true_pos_lod);
+    output_false_pos.set_lod(false_pos_lod);
+    return;
+  }
+
+  // Restores the accumulated state produced by a previous mini-batch.
+  // input_pos_count holds one count per class id; the first LoD level of
+  // input_true_pos / input_false_pos gives the row range of each class id and
+  // every row is a (score, flag) pair.
+  void GetInputPos(
+      const framework::Tensor& input_pos_count,
+      const framework::LoDTensor& input_true_pos,
+      const framework::LoDTensor& input_false_pos,
+      std::map<int, int>& label_pos_count,
+      std::map<int, std::vector<std::pair<T, int>>>& true_pos,
+      std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
+    constexpr T kEPS = static_cast<T>(1e-6);
+    int class_number = input_pos_count.dims()[0];
+    const int* pos_count_data = input_pos_count.data<int>();
+    for (int i = 0; i < class_number; ++i) {
+      label_pos_count[i] = pos_count_data[i];
+    }
+
+    auto SetData = [](const framework::LoDTensor& pos_tensor,
+                      std::map<int, std::vector<std::pair<T, int>>>& pos) {
+      const T* pos_data = pos_tensor.data<T>();
+      auto pos_data_lod = pos_tensor.lod();
+      if (pos_data_lod.empty()) return;
+      // Walk the class offsets of the first LoD level. Looping over
+      // pos_data_lod.size() (the number of LoD levels) would only ever visit
+      // class 0 and silently drop the rest of the accumulated state.
+      const auto& offsets = pos_data_lod[0];
+      for (size_t i = 0; i + 1 < offsets.size(); ++i) {
+        for (size_t j = offsets[i]; j < offsets[i + 1]; ++j) {
+          T score = pos_data[j * 2];
+          int flag = 1;
+          if (pos_data[j * 2 + 1] < kEPS) flag = 0;
+          pos[i].push_back(std::make_pair(score, flag));
+        }
+      }
+    };
+
+    SetData(input_true_pos, true_pos);
+    SetData(input_false_pos, false_pos);
+    return;
+  }
+
+  // Adds the current mini-batch to the accumulated state: counts ground-truth
+  // boxes per class and appends one (score, flag) entry to true_pos/false_pos
+  // for each evaluated detection.
+  void CalcTrueAndFalsePositive(
+      const std::vector<std::map<int, std::vector<Box>>>& gt_boxes,
+      const std::vector<std::map<int, std::vector<std::pair<T, Box>>>>&
+          detect_boxes,
+      bool evaluate_difficult, float overlap_threshold,
+      std::map<int, int>& label_pos_count,
+      std::map<int, std::vector<std::pair<T, int>>>& true_pos,
+      std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
+    int batch_size = gt_boxes.size();
+    for (int n = 0; n < batch_size; ++n) {
+      const auto& image_gt_boxes = gt_boxes[n];
+      for (auto it = image_gt_boxes.begin(); it != image_gt_boxes.end(); ++it) {
+        size_t count = 0;
+        const auto& labeled_bboxes = it->second;
+        if (evaluate_difficult) {
+          count = labeled_bboxes.size();
+        } else {
+          for (size_t i = 0; i < labeled_bboxes.size(); ++i)
+            if (!(labeled_bboxes[i].is_difficult)) ++count;
+        }
+        if (count == 0) {
+          continue;
+        }
+        int label = it->first;
+        if (label_pos_count.find(label) == label_pos_count.end()) {
+          label_pos_count[label] = count;
+        } else {
+          label_pos_count[label] += count;
+        }
+      }
+    }
+
+    for (size_t n = 0; n < detect_boxes.size(); ++n) {
+      const auto& image_gt_boxes = gt_boxes[n];
+      const auto& detections = detect_boxes[n];
+
+      if (image_gt_boxes.size() == 0) {
+        for (auto it = detections.begin(); it != detections.end(); ++it) {
+          const auto& pred_boxes = it->second;
+          int label = it->first;
+          for (size_t i = 0; i < pred_boxes.size(); ++i) {
+            auto score = pred_boxes[i].first;
+            true_pos[label].push_back(std::make_pair(score, 0));
+            false_pos[label].push_back(std::make_pair(score, 1));
+          }
+        }
+        continue;
+      }
+
+      for (auto it = detections.begin(); it != detections.end(); ++it) {
+        int label = it->first;
+        // Copied on purpose: the detections are sorted below and the input is
+        // const.
+        auto pred_boxes = it->second;
+        if (image_gt_boxes.find(label) == image_gt_boxes.end()) {
+          for (size_t i = 0; i < pred_boxes.size(); ++i) {
+            auto score = pred_boxes[i].first;
+            true_pos[label].push_back(std::make_pair(score, 0));
+            false_pos[label].push_back(std::make_pair(score, 1));
+          }
+          continue;
+        }
+
+        const auto& matched_bboxes = image_gt_boxes.find(label)->second;
+        std::vector<bool> visited(matched_bboxes.size(), false);
+        // Sort detections in descend order based on scores
+        std::sort(pred_boxes.begin(), pred_boxes.end(),
+                  SortScorePairDescend<Box>);
+        for (size_t i = 0; i < pred_boxes.size(); ++i) {
+          T max_overlap = -1.0;
+          size_t max_idx = 0;
+          auto score = pred_boxes[i].first;
+          for (size_t j = 0; j < matched_bboxes.size(); ++j) {
+            T overlap = JaccardOverlap(pred_boxes[i].second, matched_bboxes[j]);
+            if (overlap > max_overlap) {
+              max_overlap = overlap;
+              max_idx = j;
+            }
+          }
+          if (max_overlap > overlap_threshold) {
+            bool match_evaluate_difficult =
+                evaluate_difficult ||
+                (!evaluate_difficult && !matched_bboxes[max_idx].is_difficult);
+            if (match_evaluate_difficult) {
+              if (!visited[max_idx]) {
+                true_pos[label].push_back(std::make_pair(score, 1));
+                false_pos[label].push_back(std::make_pair(score, 0));
+                visited[max_idx] = true;
+              } else {
+                true_pos[label].push_back(std::make_pair(score, 0));
+                false_pos[label].push_back(std::make_pair(score, 1));
+              }
+            }
+          } else {
+            true_pos[label].push_back(std::make_pair(score, 0));
+            false_pos[label].push_back(std::make_pair(score, 1));
+          }
+        }
+      }
+    }
+  }
+
+  // Averages the per-class average precision (scaled by 100) over every class
+  // that has ground-truth boxes and at least one recorded detection.
+  T CalcMAP(
+      APType ap_type, const std::map<int, int>& label_pos_count,
+      const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
+      const std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
+    T mAP = 0.0;
+    int count = 0;
+    for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
+      int label = it->first;
+      int label_num_pos = it->second;
+      if (label_num_pos == 0 || true_pos.find(label) == true_pos.end())
+        continue;
+      const auto& label_true_pos = true_pos.find(label)->second;
+      const auto& label_false_pos = false_pos.find(label)->second;
+      // Compute average precision.
+      std::vector<int> tp_sum;
+      GetAccumulation<T>(label_true_pos, &tp_sum);
+      std::vector<int> fp_sum;
+      GetAccumulation<T>(label_false_pos, &fp_sum);
+      std::vector<T> precision, recall;
+      size_t num = tp_sum.size();
+      // Compute Precision.
+      for (size_t i = 0; i < num; ++i) {
+        precision.push_back(static_cast<T>(tp_sum[i]) /
+                            static_cast<T>(tp_sum[i] + fp_sum[i]));
+        recall.push_back(static_cast<T>(tp_sum[i]) / label_num_pos);
+      }
+      // VOC2007 style
+      if (ap_type == APType::k11point) {
+        std::vector<T> max_precisions(11, 0.0);
+        int start_idx = num - 1;
+        for (int j = 10; j >= 0; --j)
+          for (int i = start_idx; i >= 0; --i) {
+            if (recall[i] < j / 10.) {
+              start_idx = i;
+              if (j > 0) max_precisions[j - 1] = max_precisions[j];
+              break;
+            } else {
+              if (max_precisions[j] < precision[i])
+                max_precisions[j] = precision[i];
+            }
+          }
+        for (int j = 10; j >= 0; --j) mAP += max_precisions[j] / 11;
+        ++count;
+      } else if (ap_type == APType::kIntegral) {
+        // Nature integral
+        float average_precisions = 0.;
+        float prev_recall = 0.;
+        for (size_t i = 0; i < num; ++i) {
+          if (fabs(recall[i] - prev_recall) > 1e-6)
+            average_precisions += precision[i] * fabs(recall[i] - prev_recall);
+          prev_recall = recall[i];
+        }
+        mAP += average_precisions;
+        ++count;
+      } else {
+        LOG(FATAL) << "Unknown ap version: " << ap_type;
+      }
+    }
+    if (count != 0) mAP /= count;
+    return mAP * 100;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/fluid/tests/test_detection_map_op.py b/python/paddle/v2/fluid/tests/test_detection_map_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..70ccd885d89f245df492bad0fbcecc093dc1928c
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_detection_map_op.py
@@ -0,0 +1,274 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+import collections
+import math
+from op_test import OpTest
+
+
+class TestDetectionMAPOp(OpTest):
+    """Checks detection_map against the pure-Python reference in calc_map."""
+    def set_data(self):
+        self.init_test_case()
+
+        self.mAP = [self.calc_map(self.tf_pos, self.tf_pos_lod)]
+        self.label = np.array(self.label).astype('float32')
+        self.detect = np.array(self.detect).astype('float32')
+        self.mAP = np.array(self.mAP).astype('float32')
+
+        if (len(self.class_pos_count) > 0):
+            self.class_pos_count = np.array(self.class_pos_count).astype(
+                'int32')
+            self.true_pos = np.array(self.true_pos).astype('float32')
+            self.false_pos = np.array(self.false_pos).astype('float32')
+
+            self.inputs = {
+                'Label': (self.label, self.label_lod),
+                'DetectRes': (self.detect, self.detect_lod),
+                'PosCount': self.class_pos_count,
+                'TruePos': (self.true_pos, self.true_pos_lod),
+                'FalsePos': (self.false_pos, self.false_pos_lod)
+            }
+        else:
+            self.inputs = {
+                'Label': (self.label, self.label_lod),
+                'DetectRes': (self.detect, self.detect_lod),
+            }
+
+        self.attrs = {
+            'overlap_threshold': self.overlap_threshold,
+            'evaluate_difficult': self.evaluate_difficult,
+            'ap_type': self.ap_type
+        }
+
+        self.out_class_pos_count = np.array(self.out_class_pos_count).astype(
+            'int32')
+        self.out_true_pos = np.array(self.out_true_pos).astype('float32')
+        self.out_false_pos = np.array(self.out_false_pos).astype('float32')
+
+        self.outputs = {
+            'MAP': self.mAP,
+            'AccumPosCount': self.out_class_pos_count,
+            'AccumTruePos': (self.out_true_pos, self.out_true_pos_lod),
+            'AccumFalsePos': (self.out_false_pos, self.out_false_pos_lod)
+        }
+
+    def init_test_case(self):
+        self.overlap_threshold = 0.3
+        self.evaluate_difficult = True
+        self.ap_type = "integral"
+
+        self.label_lod = [[0, 2, 4]]
+        # label difficult xmin ymin xmax ymax
+        self.label = [[1, 0, 0.1, 0.1, 0.3, 0.3], [1, 1, 0.6, 0.6, 0.8, 0.8],
+                      [2, 0, 0.3, 0.3, 0.6, 0.5], [1, 0, 0.7, 0.1, 0.9, 0.3]]
+
+        # label score xmin ymin xmax ymax
+        self.detect_lod = [[0, 3, 7]]
+        self.detect = [
+            [1, 0.3, 0.1, 0.0, 0.4, 0.3], [1, 0.7, 0.0, 0.1, 0.2, 0.3],
+            [1, 0.9, 0.7, 0.6, 0.8, 0.8], [2, 0.8, 0.2, 0.1, 0.4, 0.4],
+            [2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 0.2, 0.8, 0.1, 1.0, 0.3],
+            [3, 0.2, 0.8, 0.1, 1.0, 0.3]
+        ]
+
+        # label score true_pos false_pos
+        self.tf_pos_lod = [[0, 3, 7]]
+        self.tf_pos = [[1, 0.9, 1, 0], [1, 0.7, 1, 0], [1, 0.3, 0, 1],
+                       [1, 0.2, 1, 0], [2, 0.8, 0, 1], [2, 0.1, 1, 0],
+                       [3, 0.2, 0, 1]]
+
+        self.class_pos_count = []
+        self.true_pos_lod = [[]]
+        self.true_pos = [[]]
+        self.false_pos_lod = [[]]
+        self.false_pos = [[]]
+
+    def calc_map(self, tf_pos, tf_pos_lod):
+        """Reference mAP; also fills the expected Accum* outputs (self.out_*).
+
+        tf_pos rows are [label, score, true_pos, false_pos] for the current
+        mini-batch. tf_pos_lod is accepted but not used.
+        """
+        mAP = 0.0
+        count = 0
+
+        def get_input_pos(class_pos_count, true_pos, true_pos_lod, false_pos,
+                          false_pos_lod):
+            class_pos_count_dict = collections.Counter()
+            true_pos_dict = collections.defaultdict(list)
+            false_pos_dict = collections.defaultdict(list)
+            for i, count in enumerate(class_pos_count):
+                class_pos_count_dict[i] = count
+
+            for i in range(len(true_pos_lod[0]) - 1):
+                start = true_pos_lod[0][i]
+                end = true_pos_lod[0][i + 1]
+                for j in range(start, end):
+                    true_pos_dict[i].append(true_pos[j])
+
+            for i in range(len(false_pos_lod[0]) - 1):
+                start = false_pos_lod[0][i]
+                end = false_pos_lod[0][i + 1]
+                for j in range(start, end):
+                    false_pos_dict[i].append(false_pos[j])
+
+            return class_pos_count_dict, true_pos_dict, false_pos_dict
+
+        def get_output_pos(label_count, true_pos, false_pos):
+            max_label = 0
+            for (label, label_pos_num) in label_count.items():
+                if max_label < label:
+                    max_label = label
+
+            label_number = max_label + 1
+
+            out_class_pos_count = []
+            out_true_pos_lod = [0]
+            out_true_pos = []
+            out_false_pos_lod = [0]
+            out_false_pos = []
+
+            for i in range(label_number):
+                out_class_pos_count.append([label_count[i]])
+                true_pos_list = true_pos[i]
+                out_true_pos += true_pos_list
+                out_true_pos_lod.append(len(out_true_pos))
+                false_pos_list = false_pos[i]
+                out_false_pos += false_pos_list
+                out_false_pos_lod.append(len(out_false_pos))
+
+            return out_class_pos_count, out_true_pos, [
+                out_true_pos_lod
+            ], out_false_pos, [out_false_pos_lod]
+
+        def get_accumulation(pos_list):
+            sorted_list = sorted(pos_list, key=lambda pos: pos[0], reverse=True)
+            sum = 0
+            accu_list = []
+            for (score, count) in sorted_list:
+                sum += count
+                accu_list.append(sum)
+            return accu_list
+
+        label_count, true_pos, false_pos = get_input_pos(
+            self.class_pos_count, self.true_pos, self.true_pos_lod,
+            self.false_pos, self.false_pos_lod)
+        for (label, difficult, xmin, ymin, xmax, ymax) in self.label:
+            if self.evaluate_difficult:
+                label_count[label] += 1
+            elif not difficult:
+                label_count[label] += 1
+
+        # Append this mini-batch's results to the state restored from the
+        # inputs; starting from empty dicts here would discard that state.
+        for (label, score, tp, fp) in tf_pos:
+            true_pos[label].append([score, tp])
+            false_pos[label].append([score, fp])
+
+        for (label, label_pos_num) in label_count.items():
+            if label_pos_num == 0 or label not in true_pos: continue
+            label_true_pos = true_pos[label]
+            label_false_pos = false_pos[label]
+
+            accu_tp_sum = get_accumulation(label_true_pos)
+            accu_fp_sum = get_accumulation(label_false_pos)
+
+            precision = []
+            recall = []
+
+            for i in range(len(accu_tp_sum)):
+                precision.append(
+                    float(accu_tp_sum[i]) /
+                    float(accu_tp_sum[i] + accu_fp_sum[i]))
+                recall.append(float(accu_tp_sum[i]) / label_pos_num)
+
+            if self.ap_type == "11point":
+                max_precisions = [0.0] * 11
+                start_idx = len(accu_tp_sum) - 1
+                for j in range(10, -1, -1):
+                    for i in range(start_idx, -1, -1):
+                        if recall[i] < float(j) / 10.0:
+                            start_idx = i
+                            if j > 0:
+                                max_precisions[j - 1] = max_precisions[j]
+                            break
+                        else:
+                            if max_precisions[j] < precision[i]:
+                                max_precisions[j] = precision[i]
+                for j in range(10, -1, -1):
+                    mAP += max_precisions[j] / 11
+                count += 1
+            elif self.ap_type == "integral":
+                average_precisions = 0.0
+                prev_recall = 0.0
+                for i in range(len(accu_tp_sum)):
+                    if math.fabs(recall[i] - prev_recall) > 1e-6:
+                        average_precisions += precision[i] * \
+                            math.fabs(recall[i] - prev_recall)
+                    prev_recall = recall[i]
+
+                mAP += average_precisions
+                count += 1
+        self.out_class_pos_count, self.out_true_pos, self.out_true_pos_lod, self.out_false_pos, self.out_false_pos_lod = get_output_pos(
+            label_count, true_pos, false_pos)
+        if count != 0:
+            mAP /= count
+        return mAP * 100.0
+
+    def setUp(self):
+        self.op_type = "detection_map"
+        self.set_data()
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestDetectionMAPOpSkipDiff(TestDetectionMAPOp):
+    """Same data with evaluate_difficult disabled."""
+    def init_test_case(self):
+        super(TestDetectionMAPOpSkipDiff, self).init_test_case()
+
+        self.evaluate_difficult = False
+
+        self.tf_pos_lod = [[0, 2, 6]]
+        # label score true_pos false_pos
+        self.tf_pos = [[1, 0.7, 1, 0], [1, 0.3, 0, 1], [1, 0.2, 1, 0],
+                       [2, 0.8, 0, 1], [2, 0.1, 1, 0], [3, 0.2, 0, 1]]
+
+
+class TestDetectionMAPOp11Point(TestDetectionMAPOp):
+    """Same data scored with the '11point' AP type."""
+    def init_test_case(self):
+        super(TestDetectionMAPOp11Point, self).init_test_case()
+
+        self.ap_type = "11point"
+
+
+class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp):
+    """Same data merged with state passed in via PosCount/TruePos/FalsePos."""
+    def init_test_case(self):
+        super(TestDetectionMAPOpMultiBatch, self).init_test_case()
+        self.class_pos_count = [0, 2, 1]
+        self.true_pos_lod = [[0, 0, 3, 5]]
+        self.true_pos = [[0.7, 1.], [0.3, 0.], [0.2, 1.], [0.8, 0.], [0.1, 1.]]
+        self.false_pos_lod = [[0, 0, 3, 5]]
+        self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]]
+
+
+if __name__ == '__main__':
+    unittest.main()