未验证 提交 a824da91 编写于 作者: W Wang Hao 提交者: GitHub

Merge pull request #6588 from wanghaox/detection_map

detection map evaluator for SSD
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection_map_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class DetectionMAPOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("DetectRes"),
"Input(DetectRes) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumPosCount"),
"Output(AccumPosCount) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumTruePos"),
"Output(AccumTruePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumFalsePos"),
"Output(AccumFalsePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MAP"),
"Output(MAP) of DetectionMAPOp should not be null.");
auto det_dims = ctx->GetInputDim("DetectRes");
PADDLE_ENFORCE_EQ(det_dims.size(), 2UL,
"The rank of Input(DetectRes) must be 2, "
"the shape is [N, 6].");
PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
"The shape is of Input(DetectRes) [N, 6].");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"The rank of Input(Label) must be 2, "
"the shape is [N, 6].");
PADDLE_ENFORCE_EQ(label_dims[1], 6UL,
"The shape is of Input(Label) [N, 6].");
if (ctx->HasInput("PosCount")) {
PADDLE_ENFORCE(ctx->HasInput("TruePos"),
"Input(TruePos) of DetectionMAPOp should not be null when "
"Input(TruePos) is not null.");
PADDLE_ENFORCE(
ctx->HasInput("FalsePos"),
"Input(FalsePos) of DetectionMAPOp should not be null when "
"Input(FalsePos) is not null.");
}
ctx->SetOutputDim("MAP", framework::make_ddim({1}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::Tensor>("DetectRes")->type()),
ctx.device_context());
}
};
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax], M is the total "
"number of detect results in this mini-batch. For each instance, "
"the offsets in first dimension are called LoD, the number of "
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected data.");
AddInput("Label",
"(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the"
"Labeled ground-truth data. Each row has 6 values: "
"[label, is_difficult, xmin, ymin, xmax, ymax], N is the total "
"number of ground-truth data in this mini-batch. For each "
"instance, the offsets in first dimension are called LoD, "
"the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
"means there is no ground-truth data.");
AddInput("PosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the "
"input positive example count of each class, Ncls is the count of "
"input classification. "
"This input is used to pass the AccumPosCount generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. "
"When the input(PosCount) is empty, the cumulative "
"calculation is not carried out, and only the results of the "
"current mini-batch are calculated.")
.AsDispensable();
AddInput("TruePos",
"(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the "
"input true positive example of each class."
"This input is used to pass the AccumTruePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable();
AddInput("FalsePos",
"(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the "
"input false positive example of each class."
"This input is used to pass the AccumFalsePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable();
AddOutput("AccumPosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the "
"positive example count of each class. It combines the input "
"input(PosCount) and the positive example count computed from "
"input(Detection) and input(Label).");
AddOutput("AccumTruePos",
"(LoDTensor) A LoDTensor with shape [Ntp', 2], store the "
"true positive example of each class. It combines the "
"input(TruePos) and the true positive examples computed from "
"input(Detection) and input(Label).");
AddOutput("AccumFalsePos",
"(LoDTensor) A LoDTensor with shape [Nfp', 2], store the "
"false positive example of each class. It combines the "
"input(FalsePos) and the false positive examples computed from "
"input(Detection) and input(Label).");
AddOutput("MAP",
"(Tensor) A tensor with shape [1], store the mAP evaluate "
"result of the detection.");
AddAttr<float>(
"overlap_threshold",
"(float) "
"The lower bound jaccard overlap threshold of detection output and "
"ground-truth data.")
.SetDefault(.3f);
AddAttr<bool>("evaluate_difficult",
"(bool, default true) "
"Switch to control whether the difficult data is evaluated.")
.SetDefault(true);
AddAttr<std::string>("ap_type",
"(string, default 'integral') "
"The AP algorithm type, 'integral' or '11point'.")
.SetDefault("integral")
.InEnum({"integral", "11point"})
.AddCustomChecker([](const std::string& ap_type) {
PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone,
"The ap_type should be 'integral' or '11point.");
});
AddComment(R"DOC(
Detection mAP evaluate operator.
The general steps are as follows. First, calculate the true positive and
false positive according to the input of detection and labels, then
calculate the mAP evaluate value.
Supporting '11 point' and 'integral' mAP algorithm. Please get more information
from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp,
ops::DetectionMAPOpMaker);
REGISTER_OP_CPU_KERNEL(
detection_map, ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, float>,
ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
enum APType { kNone = 0, kIntegral, k11point };
APType GetAPType(std::string str) {
if (str == "integral") {
return APType::kIntegral;
} else if (str == "11point") {
return APType::k11point;
} else {
return APType::kNone;
}
}
template <typename T>
inline bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <typename T>
inline void GetAccumulation(std::vector<std::pair<T, int>> in_pairs,
std::vector<int>* accu_vec) {
std::stable_sort(in_pairs.begin(), in_pairs.end(), SortScorePairDescend<int>);
accu_vec->clear();
size_t sum = 0;
for (size_t i = 0; i < in_pairs.size(); ++i) {
auto count = in_pairs[i].second;
sum += count;
accu_vec->push_back(sum);
}
}
template <typename Place, typename T>
class DetectionMAPOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_detect = ctx.Input<framework::LoDTensor>("DetectRes");
auto* in_label = ctx.Input<framework::LoDTensor>("Label");
auto* out_map = ctx.Output<framework::Tensor>("MAP");
auto* in_pos_count = ctx.Input<framework::Tensor>("PosCount");
auto* in_true_pos = ctx.Input<framework::LoDTensor>("TruePos");
auto* in_false_pos = ctx.Input<framework::LoDTensor>("FalsePos");
auto* out_pos_count = ctx.Output<framework::Tensor>("AccumPosCount");
auto* out_true_pos = ctx.Output<framework::LoDTensor>("AccumTruePos");
auto* out_false_pos = ctx.Output<framework::LoDTensor>("AccumFalsePos");
float overlap_threshold = ctx.Attr<float>("overlap_threshold");
float evaluate_difficult = ctx.Attr<bool>("evaluate_difficult");
auto ap_type = GetAPType(ctx.Attr<std::string>("ap_type"));
auto label_lod = in_label->lod();
auto detect_lod = in_detect->lod();
PADDLE_ENFORCE_EQ(label_lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(),
"The batch_size of input(Label) and input(Detection) "
"must be the same.");
std::vector<std::map<int, std::vector<Box>>> gt_boxes;
std::vector<std::map<int, std::vector<std::pair<T, Box>>>> detect_boxes;
GetBoxes(*in_label, *in_detect, gt_boxes, detect_boxes);
std::map<int, int> label_pos_count;
std::map<int, std::vector<std::pair<T, int>>> true_pos;
std::map<int, std::vector<std::pair<T, int>>> false_pos;
if (in_pos_count != nullptr) {
GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
true_pos, false_pos);
}
CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult,
overlap_threshold, label_pos_count, true_pos,
false_pos);
T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos);
GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count,
*out_true_pos, *out_false_pos);
T* map_data = out_map->mutable_data<T>(ctx.GetPlace());
map_data[0] = map;
}
protected:
struct Box {
Box(T xmin, T ymin, T xmax, T ymax)
: xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax), is_difficult(false) {}
T xmin, ymin, xmax, ymax;
bool is_difficult;
};
inline T JaccardOverlap(const Box& box1, const Box& box2) const {
if (box2.xmin > box1.xmax || box2.xmax < box1.xmin ||
box2.ymin > box1.ymax || box2.ymax < box1.ymin) {
return 0.0;
} else {
T inter_xmin = std::max(box1.xmin, box2.xmin);
T inter_ymin = std::max(box1.ymin, box2.ymin);
T inter_xmax = std::min(box1.xmax, box2.xmax);
T inter_ymax = std::min(box1.ymax, box2.ymax);
T inter_width = inter_xmax - inter_xmin;
T inter_height = inter_ymax - inter_ymin;
T inter_area = inter_width * inter_height;
T bbox_area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
T bbox_area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
return inter_area / (bbox_area1 + bbox_area2 - inter_area);
}
}
void GetBoxes(const framework::LoDTensor& input_label,
const framework::LoDTensor& input_detect,
std::vector<std::map<int, std::vector<Box>>>& gt_boxes,
std::vector<std::map<int, std::vector<std::pair<T, Box>>>>&
detect_boxes) const {
auto labels = framework::EigenTensor<T, 2>::From(input_label);
auto detect = framework::EigenTensor<T, 2>::From(input_detect);
auto label_lod = input_label.lod();
auto detect_lod = input_detect.lod();
int batch_size = label_lod[0].size() - 1;
auto label_index = label_lod[0];
for (int n = 0; n < batch_size; ++n) {
std::map<int, std::vector<Box>> boxes;
for (int i = label_index[n]; i < label_index[n + 1]; ++i) {
Box box(labels(i, 2), labels(i, 3), labels(i, 4), labels(i, 5));
int label = labels(i, 0);
auto is_difficult = labels(i, 1);
if (std::abs(is_difficult - 0.0) < 1e-6)
box.is_difficult = false;
else
box.is_difficult = true;
boxes[label].push_back(box);
}
gt_boxes.push_back(boxes);
}
auto detect_index = detect_lod[0];
for (int n = 0; n < batch_size; ++n) {
std::map<int, std::vector<std::pair<T, Box>>> boxes;
for (int i = detect_index[n]; i < detect_index[n + 1]; ++i) {
Box box(detect(i, 2), detect(i, 3), detect(i, 4), detect(i, 5));
int label = detect(i, 0);
auto score = detect(i, 1);
boxes[label].push_back(std::make_pair(score, box));
}
detect_boxes.push_back(boxes);
}
}
void GetOutputPos(
const framework::ExecutionContext& ctx,
const std::map<int, int>& label_pos_count,
const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
const std::map<int, std::vector<std::pair<T, int>>>& false_pos,
framework::Tensor& output_pos_count,
framework::LoDTensor& output_true_pos,
framework::LoDTensor& output_false_pos) const {
int max_class_id = 0;
int true_pos_count = 0;
int false_pos_count = 0;
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
int label = it->first;
if (label > max_class_id) max_class_id = label;
int label_num_pos = it->second;
if (label_num_pos == 0 || true_pos.find(label) == true_pos.end())
continue;
auto label_true_pos = true_pos.find(label)->second;
auto label_false_pos = false_pos.find(label)->second;
true_pos_count += label_true_pos.size();
false_pos_count += label_false_pos.size();
}
int* pos_count_data = output_pos_count.mutable_data<int>(
framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace());
T* true_pos_data = output_true_pos.mutable_data<T>(
framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
T* false_pos_data = output_false_pos.mutable_data<T>(
framework::make_ddim({false_pos_count, 2}), ctx.GetPlace());
true_pos_count = 0;
false_pos_count = 0;
std::vector<size_t> true_pos_starts = {0};
std::vector<size_t> false_pos_starts = {0};
for (int i = 0; i <= max_class_id; ++i) {
auto it_count = label_pos_count.find(i);
pos_count_data[i] = 0;
if (it_count != label_pos_count.end()) {
pos_count_data[i] = it_count->second;
}
auto it_true_pos = true_pos.find(i);
if (it_true_pos != true_pos.end()) {
const std::vector<std::pair<T, int>>& true_pos_vec =
it_true_pos->second;
for (const std::pair<T, int>& tp : true_pos_vec) {
true_pos_data[true_pos_count * 2] = tp.first;
true_pos_data[true_pos_count * 2 + 1] = static_cast<T>(tp.second);
true_pos_count++;
}
}
true_pos_starts.push_back(true_pos_count);
auto it_false_pos = false_pos.find(i);
if (it_false_pos != false_pos.end()) {
const std::vector<std::pair<T, int>>& false_pos_vec =
it_false_pos->second;
for (const std::pair<T, int>& fp : false_pos_vec) {
false_pos_data[false_pos_count * 2] = fp.first;
false_pos_data[false_pos_count * 2 + 1] = static_cast<T>(fp.second);
false_pos_count++;
}
}
false_pos_starts.push_back(false_pos_count);
}
framework::LoD true_pos_lod;
true_pos_lod.emplace_back(true_pos_starts);
framework::LoD false_pos_lod;
false_pos_lod.emplace_back(false_pos_starts);
output_true_pos.set_lod(true_pos_lod);
output_false_pos.set_lod(false_pos_lod);
return;
}
void GetInputPos(
const framework::Tensor& input_pos_count,
const framework::LoDTensor& input_true_pos,
const framework::LoDTensor& input_false_pos,
std::map<int, int>& label_pos_count,
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
constexpr T kEPS = static_cast<T>(1e-6);
int class_number = input_pos_count.dims()[0];
const int* pos_count_data = input_pos_count.data<int>();
for (int i = 0; i < class_number; ++i) {
label_pos_count[i] = pos_count_data[i];
}
auto SetData = [](const framework::LoDTensor& pos_tensor,
std::map<int, std::vector<std::pair<T, int>>>& pos) {
const T* pos_data = pos_tensor.data<T>();
auto pos_data_lod = pos_tensor.lod();
for (int i = 0; i < pos_data_lod.size(); ++i) {
for (int j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) {
T score = pos_data[j * 2];
int flag = 1;
if (pos_data[j * 2 + 1] < kEPS) flag = 0;
pos[i].push_back(std::make_pair(score, flag));
}
}
};
SetData(input_true_pos, true_pos);
SetData(input_false_pos, false_pos);
return;
}
void CalcTrueAndFalsePositive(
const std::vector<std::map<int, std::vector<Box>>>& gt_boxes,
const std::vector<std::map<int, std::vector<std::pair<T, Box>>>>&
detect_boxes,
bool evaluate_difficult, float overlap_threshold,
std::map<int, int>& label_pos_count,
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
int batch_size = gt_boxes.size();
for (int n = 0; n < batch_size; ++n) {
auto image_gt_boxes = gt_boxes[n];
for (auto it = image_gt_boxes.begin(); it != image_gt_boxes.end(); ++it) {
size_t count = 0;
auto labeled_bboxes = it->second;
if (evaluate_difficult) {
count = labeled_bboxes.size();
} else {
for (size_t i = 0; i < labeled_bboxes.size(); ++i)
if (!(labeled_bboxes[i].is_difficult)) ++count;
}
if (count == 0) {
continue;
}
int label = it->first;
if (label_pos_count.find(label) == label_pos_count.end()) {
label_pos_count[label] = count;
} else {
label_pos_count[label] += count;
}
}
}
for (size_t n = 0; n < detect_boxes.size(); ++n) {
auto image_gt_boxes = gt_boxes[n];
auto detections = detect_boxes[n];
if (image_gt_boxes.size() == 0) {
for (auto it = detections.begin(); it != detections.end(); ++it) {
auto pred_boxes = it->second;
int label = it->first;
for (size_t i = 0; i < pred_boxes.size(); ++i) {
auto score = pred_boxes[i].first;
true_pos[label].push_back(std::make_pair(score, 0));
false_pos[label].push_back(std::make_pair(score, 1));
}
}
continue;
}
for (auto it = detections.begin(); it != detections.end(); ++it) {
int label = it->first;
auto pred_boxes = it->second;
if (image_gt_boxes.find(label) == image_gt_boxes.end()) {
for (size_t i = 0; i < pred_boxes.size(); ++i) {
auto score = pred_boxes[i].first;
true_pos[label].push_back(std::make_pair(score, 0));
false_pos[label].push_back(std::make_pair(score, 1));
}
continue;
}
auto matched_bboxes = image_gt_boxes.find(label)->second;
std::vector<bool> visited(matched_bboxes.size(), false);
// Sort detections in descend order based on scores
std::sort(pred_boxes.begin(), pred_boxes.end(),
SortScorePairDescend<Box>);
for (size_t i = 0; i < pred_boxes.size(); ++i) {
T max_overlap = -1.0;
size_t max_idx = 0;
auto score = pred_boxes[i].first;
for (size_t j = 0; j < matched_bboxes.size(); ++j) {
T overlap = JaccardOverlap(pred_boxes[i].second, matched_bboxes[j]);
if (overlap > max_overlap) {
max_overlap = overlap;
max_idx = j;
}
}
if (max_overlap > overlap_threshold) {
bool match_evaluate_difficult =
evaluate_difficult ||
(!evaluate_difficult && !matched_bboxes[max_idx].is_difficult);
if (match_evaluate_difficult) {
if (!visited[max_idx]) {
true_pos[label].push_back(std::make_pair(score, 1));
false_pos[label].push_back(std::make_pair(score, 0));
visited[max_idx] = true;
} else {
true_pos[label].push_back(std::make_pair(score, 0));
false_pos[label].push_back(std::make_pair(score, 1));
}
}
} else {
true_pos[label].push_back(std::make_pair(score, 0));
false_pos[label].push_back(std::make_pair(score, 1));
}
}
}
}
}
T CalcMAP(
APType ap_type, const std::map<int, int>& label_pos_count,
const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
const std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
T mAP = 0.0;
int count = 0;
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
int label = it->first;
int label_num_pos = it->second;
if (label_num_pos == 0 || true_pos.find(label) == true_pos.end())
continue;
auto label_true_pos = true_pos.find(label)->second;
auto label_false_pos = false_pos.find(label)->second;
// Compute average precision.
std::vector<int> tp_sum;
GetAccumulation<T>(label_true_pos, &tp_sum);
std::vector<int> fp_sum;
GetAccumulation<T>(label_false_pos, &fp_sum);
std::vector<T> precision, recall;
size_t num = tp_sum.size();
// Compute Precision.
for (size_t i = 0; i < num; ++i) {
precision.push_back(static_cast<T>(tp_sum[i]) /
static_cast<T>(tp_sum[i] + fp_sum[i]));
recall.push_back(static_cast<T>(tp_sum[i]) / label_num_pos);
}
// VOC2007 style
if (ap_type == APType::k11point) {
std::vector<T> max_precisions(11, 0.0);
int start_idx = num - 1;
for (int j = 10; j >= 0; --j)
for (int i = start_idx; i >= 0; --i) {
if (recall[i] < j / 10.) {
start_idx = i;
if (j > 0) max_precisions[j - 1] = max_precisions[j];
break;
} else {
if (max_precisions[j] < precision[i])
max_precisions[j] = precision[i];
}
}
for (int j = 10; j >= 0; --j) mAP += max_precisions[j] / 11;
++count;
} else if (ap_type == APType::kIntegral) {
// Nature integral
float average_precisions = 0.;
float prev_recall = 0.;
for (size_t i = 0; i < num; ++i) {
if (fabs(recall[i] - prev_recall) > 1e-6)
average_precisions += precision[i] * fabs(recall[i] - prev_recall);
prev_recall = recall[i];
}
mAP += average_precisions;
++count;
} else {
LOG(FATAL) << "Unkown ap version: " << ap_type;
}
}
if (count != 0) mAP /= count;
return mAP * 100;
}
}; // namespace operators
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
import collections
import math
from op_test import OpTest
class TestDetectionMAPOp(OpTest):
def set_data(self):
self.init_test_case()
self.mAP = [self.calc_map(self.tf_pos, self.tf_pos_lod)]
self.label = np.array(self.label).astype('float32')
self.detect = np.array(self.detect).astype('float32')
self.mAP = np.array(self.mAP).astype('float32')
if (len(self.class_pos_count) > 0):
self.class_pos_count = np.array(self.class_pos_count).astype(
'int32')
self.true_pos = np.array(self.true_pos).astype('float32')
self.false_pos = np.array(self.false_pos).astype('float32')
self.inputs = {
'Label': (self.label, self.label_lod),
'DetectRes': (self.detect, self.detect_lod),
'PosCount': self.class_pos_count,
'TruePos': (self.true_pos, self.true_pos_lod),
'FalsePos': (self.false_pos, self.false_pos_lod)
}
else:
self.inputs = {
'Label': (self.label, self.label_lod),
'DetectRes': (self.detect, self.detect_lod),
}
self.attrs = {
'overlap_threshold': self.overlap_threshold,
'evaluate_difficult': self.evaluate_difficult,
'ap_type': self.ap_type
}
self.out_class_pos_count = np.array(self.out_class_pos_count).astype(
'int')
self.out_true_pos = np.array(self.out_true_pos).astype('float32')
self.out_false_pos = np.array(self.out_false_pos).astype('float32')
self.outputs = {
'MAP': self.mAP,
'AccumPosCount': self.out_class_pos_count,
'AccumTruePos': (self.out_true_pos, self.out_true_pos_lod),
'AccumFalsePos': (self.out_false_pos, self.out_false_pos_lod)
}
def init_test_case(self):
self.overlap_threshold = 0.3
self.evaluate_difficult = True
self.ap_type = "integral"
self.label_lod = [[0, 2, 4]]
# label difficult xmin ymin xmax ymax
self.label = [[1, 0, 0.1, 0.1, 0.3, 0.3], [1, 1, 0.6, 0.6, 0.8, 0.8],
[2, 0, 0.3, 0.3, 0.6, 0.5], [1, 0, 0.7, 0.1, 0.9, 0.3]]
# label score xmin ymin xmax ymax difficult
self.detect_lod = [[0, 3, 7]]
self.detect = [
[1, 0.3, 0.1, 0.0, 0.4, 0.3], [1, 0.7, 0.0, 0.1, 0.2, 0.3],
[1, 0.9, 0.7, 0.6, 0.8, 0.8], [2, 0.8, 0.2, 0.1, 0.4, 0.4],
[2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 0.2, 0.8, 0.1, 1.0, 0.3],
[3, 0.2, 0.8, 0.1, 1.0, 0.3]
]
# label score true_pos false_pos
self.tf_pos_lod = [[0, 3, 7]]
self.tf_pos = [[1, 0.9, 1, 0], [1, 0.7, 1, 0], [1, 0.3, 0, 1],
[1, 0.2, 1, 0], [2, 0.8, 0, 1], [2, 0.1, 1, 0],
[3, 0.2, 0, 1]]
self.class_pos_count = []
self.true_pos_lod = [[]]
self.true_pos = [[]]
self.false_pos_lod = [[]]
self.false_pos = [[]]
def calc_map(self, tf_pos, tf_pos_lod):
mAP = 0.0
count = 0
def get_input_pos(class_pos_count, true_pos, true_pos_lod, false_pos,
false_pos_lod):
class_pos_count_dict = collections.Counter()
true_pos_dict = collections.defaultdict(list)
false_pos_dict = collections.defaultdict(list)
for i, count in enumerate(class_pos_count):
class_pos_count_dict[i] = count
for i in range(len(true_pos_lod[0]) - 1):
start = true_pos_lod[0][i]
end = true_pos_lod[0][i + 1]
for j in range(start, end):
true_pos_dict[i].append(true_pos[j])
for i in range(len(false_pos_lod[0]) - 1):
start = false_pos_lod[0][i]
end = false_pos_lod[0][i + 1]
for j in range(start, end):
false_pos_dict[i].append(false_pos[j])
return class_pos_count_dict, true_pos_dict, false_pos_dict
def get_output_pos(label_count, true_pos, false_pos):
max_label = 0
for (label, label_pos_num) in label_count.items():
if max_label < label:
max_label = label
label_number = max_label + 1
out_class_pos_count = []
out_true_pos_lod = [0]
out_true_pos = []
out_false_pos_lod = [0]
out_false_pos = []
for i in range(label_number):
out_class_pos_count.append([label_count[i]])
true_pos_list = true_pos[i]
out_true_pos += true_pos_list
out_true_pos_lod.append(len(out_true_pos))
false_pos_list = false_pos[i]
out_false_pos += false_pos_list
out_false_pos_lod.append(len(out_false_pos))
return out_class_pos_count, out_true_pos, [
out_true_pos_lod
], out_false_pos, [out_false_pos_lod]
def get_accumulation(pos_list):
sorted_list = sorted(pos_list, key=lambda pos: pos[0], reverse=True)
sum = 0
accu_list = []
for (score, count) in sorted_list:
sum += count
accu_list.append(sum)
return accu_list
label_count, true_pos, false_pos = get_input_pos(
self.class_pos_count, self.true_pos, self.true_pos_lod,
self.false_pos, self.false_pos_lod)
for (label, difficult, xmin, ymin, xmax, ymax) in self.label:
if self.evaluate_difficult:
label_count[label] += 1
elif not difficult:
label_count[label] += 1
true_pos = collections.defaultdict(list)
false_pos = collections.defaultdict(list)
for (label, score, tp, fp) in tf_pos:
true_pos[label].append([score, tp])
false_pos[label].append([score, fp])
for (label, label_pos_num) in label_count.items():
if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label]
label_false_pos = false_pos[label]
accu_tp_sum = get_accumulation(label_true_pos)
accu_fp_sum = get_accumulation(label_false_pos)
precision = []
recall = []
for i in range(len(accu_tp_sum)):
precision.append(
float(accu_tp_sum[i]) /
float(accu_tp_sum[i] + accu_fp_sum[i]))
recall.append(float(accu_tp_sum[i]) / label_pos_num)
if self.ap_type == "11point":
max_precisions = [0.0] * 11
start_idx = len(accu_tp_sum) - 1
for j in range(10, -1, -1):
for i in range(start_idx, -1, -1):
if recall[i] < float(j) / 10.0:
start_idx = i
if j > 0:
max_precisions[j - 1] = max_precisions[j]
break
else:
if max_precisions[j] < precision[i]:
max_precisions[j] = precision[i]
for j in range(10, -1, -1):
mAP += max_precisions[j] / 11
count += 1
elif self.ap_type == "integral":
average_precisions = 0.0
prev_recall = 0.0
for i in range(len(accu_tp_sum)):
if math.fabs(recall[i] - prev_recall) > 1e-6:
average_precisions += precision[i] * \
math.fabs(recall[i] - prev_recall)
prev_recall = recall[i]
mAP += average_precisions
count += 1
self.out_class_pos_count, self.out_true_pos, self.out_true_pos_lod, self.out_false_pos, self.out_false_pos_lod = get_output_pos(
label_count, true_pos, false_pos)
if count != 0:
mAP /= count
return mAP * 100.0
def setUp(self):
self.op_type = "detection_map"
self.set_data()
def test_check_output(self):
self.check_output()
class TestDetectionMAPOpSkipDiff(TestDetectionMAPOp):
def init_test_case(self):
super(TestDetectionMAPOpSkipDiff, self).init_test_case()
self.evaluate_difficult = False
self.tf_pos_lod = [[0, 2, 6]]
# label score true_pos false_pos
self.tf_pos = [[1, 0.7, 1, 0], [1, 0.3, 0, 1], [1, 0.2, 1, 0],
[2, 0.8, 0, 1], [2, 0.1, 1, 0], [3, 0.2, 0, 1]]
class TestDetectionMAPOp11Point(TestDetectionMAPOp):
def init_test_case(self):
super(TestDetectionMAPOp11Point, self).init_test_case()
self.ap_type = "11point"
class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp):
def init_test_case(self):
super(TestDetectionMAPOpMultiBatch, self).init_test_case()
self.class_pos_count = [0, 2, 1]
self.true_pos_lod = [[0, 0, 3, 5]]
self.true_pos = [[0.7, 1.], [0.3, 0.], [0.2, 1.], [0.8, 0.], [0.1, 1.]]
self.false_pos_lod = [[0, 0, 3, 5]]
self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册