diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index efb06d60a2fd4bc663b6690325baaed1797d23b5..7f0fc170f4788634d4791a5b124c4f916e978785 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -362,6 +362,7 @@ paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gt_box', 'gt_label', 'ancho
 paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'f332fb8c5bb581bd1a6b5be450a99990'))
 paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '04384378ff00a42ade8fabd52e27cbc5'))
 paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
+paddle.fluid.layers.retinanet_detection_output (ArgSpec(args=['bboxes', 'scores', 'anchors', 'im_info', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0.05, 1000, 100, 0.3, 1.0)), ('document', '078d28607ce261a0cba2b965a79f6bb8'))
 paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
 paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5'))
 paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace'))
diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index 059106992439ba46af843533e2ce073b74cd7779..f1c504d6e4bd065e4221b1207a117ff0f6732459 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -36,6 +36,7 @@ detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
 detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
 detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
 detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
+detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
 
 if(WITH_GPU)
     detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
diff --git a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4a6dfec12e660431844682694632a3b18d91bf3e
--- /dev/null
+++ b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc
@@ -0,0 +1,566 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+class RetinanetDetectionOutputOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_GE(
+        ctx->Inputs("BBoxes").size(), 1UL,
+        "Input(BBoxes) of RetinanetDetectionOutput should not be null.");
+    PADDLE_ENFORCE_GE(
+        ctx->Inputs("Scores").size(), 1UL,
+        "Input(Scores) of RetinanetDetectionOutput should not be null.");
+    PADDLE_ENFORCE_GE(
+        ctx->Inputs("Anchors").size(), 1UL,
+        "Input(Anchors) of RetinanetDetectionOutput should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx->Inputs("BBoxes").size(), ctx->Inputs("Scores").size(),
+        "Input tensors(BBoxes and Scores) should have the same size.");
+    PADDLE_ENFORCE_EQ(
+        ctx->Inputs("BBoxes").size(), ctx->Inputs("Anchors").size(),
+        "Input tensors(BBoxes and Anchors) should have the same size.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("ImInfo"),
+        "Input(ImInfo) of RetinanetDetectionOutput should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(Out) of RetinanetDetectionOutput should not be null.");
+
+    auto bboxes_dims = ctx->GetInputsDim("BBoxes");
+    auto scores_dims = ctx->GetInputsDim("Scores");
+    auto anchors_dims = ctx->GetInputsDim("Anchors");
+    auto im_info_dims = ctx->GetInputDim("ImInfo");
+
+    const size_t b_n = bboxes_dims.size();
+    PADDLE_ENFORCE_GT(b_n, 0, "Input bbox tensors count should be greater than 0.");
+    const size_t s_n = scores_dims.size();
+    PADDLE_ENFORCE_GT(s_n, 0, "Input score tensors count should be greater than 0.");
+    const size_t a_n = anchors_dims.size();
+    PADDLE_ENFORCE_GT(a_n, 0, "Input anchor tensors count should be greater than 0.");
+
+    auto bbox_dims = bboxes_dims[0];
+    auto score_dims = scores_dims[0];
+    auto anchor_dims = anchors_dims[0];
+    if (ctx->IsRuntime()) {
+      PADDLE_ENFORCE_EQ(score_dims.size(), 3,
+                        "The rank of Input(Scores) must be 3");
+      PADDLE_ENFORCE_EQ(bbox_dims.size(), 3,
+                        "The rank of Input(BBoxes) must be 3");
+      PADDLE_ENFORCE_EQ(anchor_dims.size(), 2,
+                        "The rank of Input(Anchors) must be 2");
+      PADDLE_ENFORCE(bbox_dims[2] == 4,
+                     "The last dimension of Input(BBoxes) must be 4, "
+                     "represents the layout of coordinate "
+                     "[xmin, ymin, xmax, ymax]");
+      PADDLE_ENFORCE_EQ(bbox_dims[1], score_dims[1],
+                        "The 2nd dimension of Input(BBoxes) must be equal to "
+                        "2nd dimension of Input(Scores), which represents the "
+                        "number of the predicted boxes.");
+
+      PADDLE_ENFORCE_EQ(anchor_dims[0], bbox_dims[1],
+                        "The 1st dimension of Input(Anchors) must be equal to "
+                        "2nd dimension of Input(BBoxes), which represents the "
+                        "number of the predicted boxes.");
+      PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
+                        "The rank of Input(ImInfo) must be 2.");
+    }
+    // Here bbox_dims[1] is not the real number of output boxes.
+    // It will be rewritten in the computing kernel.
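+    // Each output row is [label, score, xmin, ymin, xmax, ymax], so the
+    // column count is box_dim + 2.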
+    ctx->SetOutputDim("Out", {bbox_dims[1], bbox_dims[2] + 2});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        framework::GetDataTypeOfVar(ctx.MultiInputVar("Scores")[0]);
+
+    // This kernel only runs on CPU.
+    return framework::OpKernelType(input_data_type, platform::CPUPlace());
+  }
+};
+
+template <class T>
+bool SortScorePairDescend(const std::pair<float, T>& pair1,
+                          const std::pair<float, T>& pair2) {
+  return pair1.first > pair2.first;
+}
+
+template <class T>
+bool SortScoreTwoPairDescend(const std::pair<float, std::pair<T, T>>& pair1,
+                             const std::pair<float, std::pair<T, T>>& pair2) {
+  return pair1.first > pair2.first;
+}
+
+template <class T>
+static inline void GetMaxScoreIndex(
+    const std::vector<T>& scores, const T threshold, int top_k,
+    std::vector<std::pair<T, int>>* sorted_indices) {
+  for (size_t i = 0; i < scores.size(); ++i) {
+    if (scores[i] > threshold) {
+      sorted_indices->push_back(std::make_pair(scores[i], i));
+    }
+  }
+  // Sort the score pairs according to the scores in descending order.
+  std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
+                   SortScorePairDescend<int>);
+  // Keep top_k scores if needed.
+  if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
+    sorted_indices->resize(top_k);
+  }
+}
+
+template <class T>
+static inline T BBoxArea(const std::vector<T>& box, const bool normalized) {
+  if (box[2] < box[0] || box[3] < box[1]) {
+    // If coordinate values are invalid
+    // (e.g. xmax < xmin or ymax < ymin), return 0.
+    return static_cast<T>(0.);
+  } else {
+    const T w = box[2] - box[0];
+    const T h = box[3] - box[1];
+    if (normalized) {
+      return w * h;
+    } else {
+      // If coordinate values are not within range [0, 1].
+      return (w + 1) * (h + 1);
+    }
+  }
+}
+
+template <class T>
+static inline T JaccardOverlap(const std::vector<T>& box1,
+                               const std::vector<T>& box2,
+                               const bool normalized) {
+  if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
+      box2[3] < box1[1]) {
+    return static_cast<T>(0.);
+  } else {
+    const T inter_xmin = std::max(box1[0], box2[0]);
+    const T inter_ymin = std::max(box1[1], box2[1]);
+    const T inter_xmax = std::min(box1[2], box2[2]);
+    const T inter_ymax = std::min(box1[3], box2[3]);
+    T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
+    T inter_w = inter_xmax - inter_xmin + norm;
+    T inter_h = inter_ymax - inter_ymin + norm;
+    const T inter_area = inter_w * inter_h;
+    const T bbox1_area = BBoxArea<T>(box1, normalized);
+    const T bbox2_area = BBoxArea<T>(box2, normalized);
+    return inter_area / (bbox1_area + bbox2_area - inter_area);
+  }
+}
+
+template <typename T>
+class RetinanetDetectionOutputKernel : public framework::OpKernel<T> {
+ public:
+  void NMSFast(const std::vector<std::vector<T>>& cls_dets,
+               const T nms_threshold, const T eta,
+               std::vector<int>* selected_indices) const {
+    int64_t num_boxes = cls_dets.size();
+    std::vector<std::pair<T, int>> sorted_indices;
+    for (int64_t i = 0; i < num_boxes; ++i) {
+      sorted_indices.push_back(std::make_pair(cls_dets[i][4], i));
+    }
+    // Sort the score pairs according to the scores in descending order.
+    std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
+                     SortScorePairDescend<int>);
+    selected_indices->clear();
+    T adaptive_threshold = nms_threshold;
+
+    while (sorted_indices.size() != 0) {
+      const int idx = sorted_indices.front().second;
+      bool keep = true;
+      for (size_t k = 0; k < selected_indices->size(); ++k) {
+        if (keep) {
+          const int kept_idx = (*selected_indices)[k];
+          T overlap =
+              JaccardOverlap<T>(cls_dets[idx], cls_dets[kept_idx], false);
+          keep = overlap <= adaptive_threshold;
+        } else {
+          break;
+        }
+      }
+      if (keep) {
+        selected_indices->push_back(idx);
+      }
+      sorted_indices.erase(sorted_indices.begin());
+      if (keep && eta < 1 && adaptive_threshold > 0.5) {
+        adaptive_threshold *= eta;
+      }
+    }
+  }
+
+  void DeltaScoreToPrediction(
+      const std::vector<T>& bboxes_data, const std::vector<T>& anchors_data,
+      T im_height, T im_width, T im_scale, int class_num,
+      const std::vector<std::pair<T, int>>& sorted_indices,
+      std::map<int, std::vector<std::vector<T>>>* preds) const {
+    im_height = static_cast<T>(round(im_height / im_scale));
+    im_width = static_cast<T>(round(im_width / im_scale));
+    T zero(0);
+    for (const auto& it : sorted_indices) {
+      T score = it.first;
+      int idx = it.second;
+      // Each anchor holds class_num scores; recover the anchor index `a`
+      // and the class index `c` from the flattened score index.
+      int a = idx / class_num;
+      int c = idx % class_num;
+
+      int box_offset = a * 4;
+      T anchor_box_width =
+          anchors_data[box_offset + 2] - anchors_data[box_offset] + 1;
+      T anchor_box_height =
+          anchors_data[box_offset + 3] - anchors_data[box_offset + 1] + 1;
+      T anchor_box_center_x = anchors_data[box_offset] + anchor_box_width / 2;
+      T anchor_box_center_y =
+          anchors_data[box_offset + 1] + anchor_box_height / 2;
+      T target_box_center_x = 0, target_box_center_y = 0;
+      T target_box_width = 0, target_box_height = 0;
+      target_box_center_x =
+          bboxes_data[box_offset] * anchor_box_width + anchor_box_center_x;
+      target_box_center_y =
+          bboxes_data[box_offset + 1] * anchor_box_height + anchor_box_center_y;
+      target_box_width =
+          std::exp(bboxes_data[box_offset + 2]) * anchor_box_width;
+      target_box_height =
+          std::exp(bboxes_data[box_offset + 3]) * anchor_box_height;
+      T pred_box_xmin = target_box_center_x - target_box_width / 2;
+      T pred_box_ymin = target_box_center_y - target_box_height / 2;
+      T pred_box_xmax = target_box_center_x + target_box_width / 2 - 1;
+      T pred_box_ymax = target_box_center_y + target_box_height / 2 - 1;
+      pred_box_xmin = pred_box_xmin / im_scale;
+      pred_box_ymin = pred_box_ymin / im_scale;
+      pred_box_xmax = pred_box_xmax / im_scale;
+      pred_box_ymax = pred_box_ymax / im_scale;
+
+      pred_box_xmin = std::max(std::min(pred_box_xmin, im_width - 1), zero);
+      pred_box_ymin = std::max(std::min(pred_box_ymin, im_height - 1), zero);
+      pred_box_xmax = std::max(std::min(pred_box_xmax, im_width - 1), zero);
+      pred_box_ymax = std::max(std::min(pred_box_ymax, im_height - 1), zero);
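+      // Group the decoded box with its score by class id; the per-class
+      // NMS below runs on these groups.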
+      std::vector<T> one_pred;
+      one_pred.push_back(pred_box_xmin);
+      one_pred.push_back(pred_box_ymin);
+      one_pred.push_back(pred_box_xmax);
+      one_pred.push_back(pred_box_ymax);
+      one_pred.push_back(score);
+      (*preds)[c].push_back(one_pred);
+    }
+  }
+
+  void MultiClassNMS(const std::map<int, std::vector<std::vector<T>>>& preds,
+                     int class_num, const int keep_top_k, const T nms_threshold,
+                     const T nms_eta, std::vector<std::vector<T>>* nmsed_out,
+                     int* num_nmsed_out) const {
+    std::map<int, std::vector<int>> indices;
+    int num_det = 0;
+    for (int c = 0; c < class_num; ++c) {
+      if (static_cast<bool>(preds.count(c))) {
+        const std::vector<std::vector<T>> cls_dets = preds.at(c);
+        NMSFast(cls_dets, nms_threshold, nms_eta, &(indices[c]));
+        num_det += indices[c].size();
+      }
+    }
+
+    std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
+    for (const auto& it : indices) {
+      int label = it.first;
+      const std::vector<int>& label_indices = it.second;
+      for (size_t j = 0; j < label_indices.size(); ++j) {
+        int idx = label_indices[j];
+        score_index_pairs.push_back(std::make_pair(preds.at(label)[idx][4],
+                                                   std::make_pair(label, idx)));
+      }
+    }
+    // Keep top k results per image.
+    std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
+                     SortScoreTwoPairDescend<int>);
+    if (num_det > keep_top_k) {
+      score_index_pairs.resize(keep_top_k);
+    }
+
+    // Store the kept detections as [label, score, xmin, ymin, xmax, ymax].
+    for (const auto& it : score_index_pairs) {
+      int label = it.second.first;
+      int idx = it.second.second;
+      std::vector<T> one_pred;
+      one_pred.push_back(label);
+      one_pred.push_back(preds.at(label)[idx][4]);
+      one_pred.push_back(preds.at(label)[idx][0]);
+      one_pred.push_back(preds.at(label)[idx][1]);
+      one_pred.push_back(preds.at(label)[idx][2]);
+      one_pred.push_back(preds.at(label)[idx][3]);
+      nmsed_out->push_back(one_pred);
+    }
+
+    *num_nmsed_out = (num_det > keep_top_k ? keep_top_k : num_det);
+  }
+
+  void RetinanetDetectionOutput(const framework::ExecutionContext& ctx,
+                                const std::vector<Tensor>& scores,
+                                const std::vector<Tensor>& bboxes,
+                                const std::vector<Tensor>& anchors,
+                                const Tensor& im_info,
+                                std::vector<std::vector<T>>* nmsed_out,
+                                int* num_nmsed_out) const {
+    int64_t nms_top_k = ctx.Attr<int>("nms_top_k");
+    int64_t keep_top_k = ctx.Attr<int>("keep_top_k");
+    T nms_threshold = static_cast<T>(ctx.Attr<float>("nms_threshold"));
+    T nms_eta = static_cast<T>(ctx.Attr<float>("nms_eta"));
+    T score_threshold = static_cast<T>(ctx.Attr<float>("score_threshold"));
+
+    int64_t class_num = scores[0].dims()[1];
+    std::map<int, std::vector<std::vector<T>>> preds;
+    for (size_t l = 0; l < scores.size(); ++l) {
+      // Fetch per level score
+      Tensor scores_per_level = scores[l];
+      // Fetch per level bbox
+      Tensor bboxes_per_level = bboxes[l];
+      // Fetch per level anchor
+      Tensor anchors_per_level = anchors[l];
+
+      int64_t scores_num = scores_per_level.numel();
+      int64_t bboxes_num = bboxes_per_level.numel();
+      std::vector<T> scores_data(scores_num);
+      std::vector<T> bboxes_data(bboxes_num);
+      std::vector<T> anchors_data(bboxes_num);
+      std::copy_n(scores_per_level.data<T>(), scores_num, scores_data.begin());
+      std::copy_n(bboxes_per_level.data<T>(), bboxes_num, bboxes_data.begin());
+      std::copy_n(anchors_per_level.data<T>(), bboxes_num,
+                  anchors_data.begin());
+      std::vector<std::pair<T, int>> sorted_indices;
+
+      // For the highest level, we take the threshold 0.0
+      T threshold = (l < (scores.size() - 1) ? score_threshold : 0.0);
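+      // Keep at most nms_top_k candidates per level whose scores exceed
+      // the threshold; only these are decoded.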
+      GetMaxScoreIndex(scores_data, threshold, nms_top_k, &sorted_indices);
+      auto* im_info_data = im_info.data<T>();
+      auto im_height = im_info_data[0];
+      auto im_width = im_info_data[1];
+      auto im_scale = im_info_data[2];
+      DeltaScoreToPrediction(bboxes_data, anchors_data, im_height, im_width,
+                             im_scale, class_num, sorted_indices, &preds);
+    }
+
+    MultiClassNMS(preds, class_num, keep_top_k, nms_threshold, nms_eta,
+                  nmsed_out, num_nmsed_out);
+  }
+
+  void MultiClassOutput(const platform::DeviceContext& ctx,
+                        const std::vector<std::vector<T>>& nmsed_out,
+                        Tensor* outs) const {
+    auto* odata = outs->data<T>();
+    int count = 0;
+    int64_t out_dim = 6;
+    for (size_t i = 0; i < nmsed_out.size(); ++i) {
+      odata[count * out_dim] = nmsed_out[i][0] + 1;  // label
+      odata[count * out_dim + 1] = nmsed_out[i][1];  // score
+      odata[count * out_dim + 2] = nmsed_out[i][2];  // xmin
+      odata[count * out_dim + 3] = nmsed_out[i][3];  // ymin
+      odata[count * out_dim + 4] = nmsed_out[i][4];  // xmax
+      odata[count * out_dim + 5] = nmsed_out[i][5];  // ymax
+      count++;
+    }
+  }
+
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto boxes = ctx.MultiInput<Tensor>("BBoxes");
+    auto scores = ctx.MultiInput<Tensor>("Scores");
+    auto anchors = ctx.MultiInput<Tensor>("Anchors");
+    auto* im_info = ctx.Input<LoDTensor>("ImInfo");
+    auto* outs = ctx.Output<LoDTensor>("Out");
+
+    std::vector<Tensor> boxes_list(boxes.size());
+    std::vector<Tensor> scores_list(scores.size());
+    std::vector<Tensor> anchors_list(anchors.size());
+    for (size_t j = 0; j < boxes_list.size(); ++j) {
+      boxes_list[j] = *boxes[j];
+      scores_list[j] = *scores[j];
+      anchors_list[j] = *anchors[j];
+    }
+    auto score_dims = scores_list[0].dims();
+    int64_t batch_size = score_dims[0];
+    auto box_dims = boxes_list[0].dims();
+    int64_t box_dim = box_dims[2];
+    int64_t out_dim = box_dim + 2;
+
+    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+
+    std::vector<std::vector<std::vector<T>>> all_nmsed_out;
+    std::vector<size_t> batch_starts = {0};
+    for (int i = 0; i < batch_size; ++i) {
+      int num_nmsed_out = 0;
+      std::vector<Tensor> box_per_batch_list(boxes_list.size());
+      std::vector<Tensor> score_per_batch_list(scores_list.size());
+      for (size_t j = 0; j < boxes_list.size(); ++j) {
+        auto score_dims = scores_list[j].dims();
+        score_per_batch_list[j] = scores_list[j].Slice(i, i + 1);
+        score_per_batch_list[j].Resize({score_dims[1], score_dims[2]});
+        box_per_batch_list[j] = boxes_list[j].Slice(i, i + 1);
+        box_per_batch_list[j].Resize({score_dims[1], box_dim});
+      }
+      Tensor im_info_slice = im_info->Slice(i, i + 1);
+
+      std::vector<std::vector<T>> nmsed_out;
+      RetinanetDetectionOutput(ctx, score_per_batch_list, box_per_batch_list,
+                               anchors_list, im_info_slice, &nmsed_out,
+                               &num_nmsed_out);
+      all_nmsed_out.push_back(nmsed_out);
+      batch_starts.push_back(batch_starts.back() + num_nmsed_out);
+    }
+
+    int num_kept = batch_starts.back();
+    if (num_kept == 0) {
+      outs->Resize({0, out_dim});
+    } else {
+      outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
+      for (int i = 0; i < batch_size; ++i) {
+        int64_t s = batch_starts[i];
+        int64_t e = batch_starts[i + 1];
+        if (e > s) {
+          Tensor out = outs->Slice(s, e);
+          MultiClassOutput(dev_ctx, all_nmsed_out[i], &out);
+        }
+      }
+    }
+
+    framework::LoD lod;
+    lod.emplace_back(batch_starts);
+
+    outs->set_lod(lod);
+  }
+};
+
+class RetinanetDetectionOutputOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("BBoxes",
+             "(List) A list of tensors from multiple FPN levels. Each "
Each " + "element is a 3-D Tensor with shape [N, Mi, 4] represents the " + "predicted locations of Mi bounding boxes, N is the batch size. " + "Mi is the number of bounding boxes from i-th FPN level. Each " + "bounding box has four coordinate values and the layout is " + "[xmin, ymin, xmax, ymax].") + .AsDuplicable(); + AddInput("Scores", + "(List) A list of tensors from multiple FPN levels. Each " + "element is a 3-D Tensor with shape [N, Mi, C] represents the " + "predicted confidence from its FPN level. N is the batch size, " + "C is the class number (excluding background), Mi is the number " + "of bounding boxes from i-th FPN level. For each bounding box, " + "there are total C scores.") + .AsDuplicable(); + AddInput("Anchors", + "(List) A list of tensors from multiple FPN levels. Each" + "element is a 2-D Tensor with shape [Mi, 4] represents the " + "locations of Mi anchor boxes from i-th FPN level. Each " + "bounding box has four coordinate values and the layout is " + "[xmin, ymin, xmax, ymax].") + .AsDuplicable(); + AddInput("ImInfo", + "(LoDTensor) A 2-D LoDTensor with shape [N, 3] represents the " + "image information. N is the batch size, each image information " + "includes height, width and scale."); + AddAttr("score_threshold", + "(float) " + "Threshold to filter out bounding boxes with a confidence " + "score."); + AddAttr("nms_top_k", + "(int64_t) " + "Maximum number of detections per FPN layer to be kept " + "according to the confidence before NMS."); + AddAttr("nms_threshold", + "(float) " + "The threshold to be used in NMS."); + AddAttr("nms_eta", + "(float) " + "The parameter for adaptive NMS."); + AddAttr( + "keep_top_k", + "(int64_t) " + "Number of total bounding boxes to be kept per image after NMS " + "step."); + AddOutput("Out", + "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " + "detections. Each row has 6 values: " + "[label, confidence, xmin, ymin, xmax, ymax]" + "No is the total number of detections in this mini-batch." + "For each instance, " + "the offsets in first dimension are called LoD, the number of " + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " + "no detected bbox."); + AddComment(R"DOC( +This operator is to decode boxes and scores from each FPN layer and do +multi-class non maximum suppression (NMS) on merged predictions. + +Top-scoring predictions per FPN layer are decoded with the anchor +information. This operator greedily selects a subset of detection bounding +boxes from each FPN layer that have high scores larger than score_threshold, +if providing this threshold, then selects the largest nms_top_k confidences +scores per FPN layer, if nms_top_k is larger than -1. +The decoding schema is described below: + +ox = (pw * pxv * tx * + px) - tw / 2 + +oy = (ph * pyv * ty * + py) - th / 2 + +ow = exp(pwv * tw) * pw + tw / 2 + +oh = exp(phv * th) * ph + th / 2 + +where `tx`, `ty`, `tw`, `th` denote the predicted box's center coordinates, width +and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the +anchor's center coordinates, width and height. `pxv`, `pyv`, `pwv`, +`phv` denote the variance of the anchor box and `ox`, `oy`, `ow`, `oh` denote the +decoded coordinates, width and height. + +Then the top decoded prediction from all levels are merged followed by NMS. +In the NMS step, this operator prunes away boxes that have high IOU +(intersection over union) overlap with already selected boxes by adaptive +threshold NMS based on parameters of nms_threshold and nms_eta. 
+After the NMS step, at most keep_top_k total bounding boxes are kept per
+image if keep_top_k is larger than -1.
+This operator supports multi-class and batched inputs. It applies NMS
+independently for each class. The output is a 2-D LoDTensor; for each
+image, the offsets in the first dimension of the LoDTensor are called LoD,
+and the number of offsets is N + 1, where N is the batch size. If
+LoD[i + 1] - LoD[i] == 0, there is no detected bounding box for this image.
+If there are no detected boxes for all images, all the elements in LoD are
+set to 0, and the output tensor is empty (None).
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(retinanet_detection_output, ops::RetinanetDetectionOutputOp,
+                  ops::RetinanetDetectionOutputOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(retinanet_detection_output,
+                       ops::RetinanetDetectionOutputKernel<float>,
+                       ops::RetinanetDetectionOutputKernel<double>);
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 7b3a7faaddef94fb40d160a16711d649b4bc472e..36877269faa0b636a672454b3d682b89a5b94a30 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -53,6 +53,7 @@ __all__ = [
     'yolo_box',
     'box_clip',
     'multiclass_nms',
+    'retinanet_detection_output',
     'distribute_fpn_proposals',
     'box_decoder_and_assign',
     'collect_fpn_proposals',
@@ -2548,6 +2549,113 @@ def box_clip(input, im_info, name=None):
     return output
 
 
+def retinanet_detection_output(bboxes,
+                               scores,
+                               anchors,
+                               im_info,
+                               score_threshold=0.05,
+                               nms_top_k=1000,
+                               keep_top_k=100,
+                               nms_threshold=0.3,
+                               nms_eta=1.):
+    """
+    **Detection Output Layer for RetinaNet.**
+
+    This operation gets the detection results by performing the following
+    steps:
+
+    1. Decode the top-scoring bounding box predictions per FPN level
+       according to the anchor boxes.
+    2. Merge the top predictions from all levels and apply multi-class
+       non-maximum suppression (NMS) on them to get the final detections.
+
+    Args:
+        bboxes(List): A list of tensors from multiple FPN levels. Each
+            element is a 3-D Tensor with shape [N, Mi, 4] representing the
+            predicted locations of Mi bounding boxes. N is the batch size,
+            Mi is the number of bounding boxes from the i-th FPN level and
+            each bounding box has four coordinate values and the layout is
+            [xmin, ymin, xmax, ymax].
+        scores(List): A list of tensors from multiple FPN levels. Each
+            element is a 3-D Tensor with shape [N, Mi, C] representing the
+            predicted confidences. N is the batch size, C is the class
+            number (excluding background) and Mi is the number of bounding
+            boxes from the i-th FPN level. For each bounding box, there are
+            C scores in total.
+        anchors(List): A list of tensors from multiple FPN levels. Each
+            element is a 2-D Tensor with shape [Mi, 4] representing the
+            locations of Mi anchor boxes from the i-th FPN level. Each
+            anchor box has four coordinate values and the layout is
+            [xmin, ymin, xmax, ymax].
+        im_info(Variable): A 2-D LoDTensor with shape [N, 3] representing the
+            image information. N is the batch size, and each image
+            information includes height, width and scale.
+        score_threshold(float): Threshold to filter out bounding boxes with
+            low confidence scores.
+        nms_top_k(int): Maximum number of detections per FPN layer to be
+            kept according to the confidences before NMS.
+        keep_top_k(int): Number of total bounding boxes to be kept per image
+            after the NMS step. -1 means keeping all bounding boxes after the
+            NMS step.
+        nms_threshold(float): The threshold to be used in NMS.
+        nms_eta(float): The parameter for adaptive NMS.
+
+    Returns:
+        Variable:
+            The detection output is a LoDTensor with shape [No, 6].
+            Each row has six values: [label, confidence, xmin, ymin, xmax, ymax].
+            `No` is the total number of detections in this mini-batch. For each
+            instance, the offsets in the first dimension are called LoD; the
+            number of offsets is N + 1, where N is the batch size. The i-th
+            image has `LoD[i + 1] - LoD[i]` detected results; if it is 0, the
+            i-th image has no detected results. If all images have no detected
+            results, all elements in LoD are set to 0, and the output tensor is
+            empty (None).
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            bboxes = fluid.layers.data(name='bboxes', shape=[1, 21, 4],
+                append_batch_size=False, dtype='float32')
+            scores = fluid.layers.data(name='scores', shape=[1, 21, 10],
+                append_batch_size=False, dtype='float32')
+            anchors = fluid.layers.data(name='anchors', shape=[21, 4],
+                append_batch_size=False, dtype='float32')
+            im_info = fluid.layers.data(name="im_info", shape=[1, 3],
+                append_batch_size=False, dtype='float32')
+            nmsed_outs = fluid.layers.retinanet_detection_output(
+                bboxes=[bboxes, bboxes],
+                scores=[scores, scores],
+                anchors=[anchors, anchors],
+                im_info=im_info,
+                score_threshold=0.05,
+                nms_top_k=1000,
+                keep_top_k=100,
+                nms_threshold=0.3,
+                nms_eta=1.)
+    """
+
+    helper = LayerHelper('retinanet_detection_output', **locals())
+    output = helper.create_variable_for_type_inference(
+        dtype=helper.input_dtype('scores'))
+    helper.append_op(
+        type="retinanet_detection_output",
+        inputs={
+            'BBoxes': bboxes,
+            'Scores': scores,
+            'Anchors': anchors,
+            'ImInfo': im_info
+        },
+        attrs={
+            'score_threshold': score_threshold,
+            'nms_top_k': nms_top_k,
+            'nms_threshold': nms_threshold,
+            'keep_top_k': keep_top_k,
+            'nms_eta': nms_eta,
+        },
+        outputs={'Out': output})
+    output.stop_gradient = True
+    return output
+
+
 def multiclass_nms(bboxes,
                    scores,
                    score_threshold,
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 8c2b4259ea9be8aa6c3328825de7f86e6d7eced1..944b1bb12fe20486777972caffc4d69faebb5bea 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -2093,6 +2093,41 @@ class TestBook(LayerTest):
             x=input, label=label, fg_num=fg_num, gamma=2., alpha=0.25)
         return (out)
 
+    def test_retinanet_detection_output(self):
+        with program_guard(fluid.default_main_program(),
+                           fluid.default_startup_program()):
+            bboxes = layers.data(
+                name='bboxes',
+                shape=[1, 21, 4],
+                append_batch_size=False,
+                dtype='float32')
+            scores = layers.data(
+                name='scores',
+                shape=[1, 21, 10],
+                append_batch_size=False,
+                dtype='float32')
+            anchors = layers.data(
+                name='anchors',
+                shape=[21, 4],
+                append_batch_size=False,
+                dtype='float32')
+            im_info = layers.data(
+                name="im_info",
+                shape=[1, 3],
+                append_batch_size=False,
+                dtype='float32')
+            nmsed_outs = layers.retinanet_detection_output(
+                bboxes=[bboxes, bboxes],
+                scores=[scores, scores],
+                anchors=[anchors, anchors],
+                im_info=im_info,
+                score_threshold=0.05,
+                nms_top_k=1000,
+                keep_top_k=100,
+                nms_threshold=0.3,
+                nms_eta=1.)
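+            # nmsed_outs is a LoDTensor with shape [No, 6]; each row is
+            # [label, score, xmin, ymin, xmax, ymax].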
+            return (nmsed_outs)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_retinanet_detection_output.py b/python/paddle/fluid/tests/unittests/test_retinanet_detection_output.py
new file mode 100644
index 0000000000000000000000000000000000000000..fafc7de33bc2e49dba699bba8466868f8901614d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_retinanet_detection_output.py
@@ -0,0 +1,412 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import numpy as np
+import math
+import copy
+from op_test import OpTest
+from test_anchor_generator_op import anchor_generator_in_python
+from test_multiclass_nms_op import iou
+from test_multiclass_nms_op import nms
+
+
+def multiclass_nms(prediction, class_num, keep_top_k, nms_threshold):
+    selected_indices = {}
+    num_det = 0
+    for c in range(class_num):
+        if c not in prediction:
+            continue
+        cls_dets = prediction[c]
+        all_scores = np.zeros(len(cls_dets))
+        for i in range(all_scores.shape[0]):
+            all_scores[i] = cls_dets[i][4]
+        indices = nms(cls_dets, all_scores, 0.0, nms_threshold, -1, False, 1.0)
+        selected_indices[c] = indices
+        num_det += len(indices)
+
+    score_index = []
+    for c, indices in selected_indices.items():
+        for idx in indices:
+            score_index.append((prediction[c][idx][4], c, idx))
+
+    sorted_score_index = sorted(
+        score_index, key=lambda tup: tup[0], reverse=True)
+    if keep_top_k > -1 and num_det > keep_top_k:
+        sorted_score_index = sorted_score_index[:keep_top_k]
+        num_det = keep_top_k
+    nmsed_outs = []
+    for s, c, idx in sorted_score_index:
+        xmin = prediction[c][idx][0]
+        ymin = prediction[c][idx][1]
+        xmax = prediction[c][idx][2]
+        ymax = prediction[c][idx][3]
+        nmsed_outs.append([c + 1, s, xmin, ymin, xmax, ymax])
+
+    return nmsed_outs, num_det
+
+
+def retinanet_detection_out(boxes_list, scores_list, anchors_list, im_info,
+                            score_threshold, nms_threshold, nms_top_k,
+                            keep_top_k):
+    class_num = scores_list[0].shape[-1]
+    im_height, im_width, im_scale = im_info
+
+    num_level = len(scores_list)
+    prediction = {}
+    for lvl in range(num_level):
+        scores_per_level = scores_list[lvl]
+        scores_per_level = scores_per_level.flatten()
+        bboxes_per_level = boxes_list[lvl]
+        bboxes_per_level = bboxes_per_level.flatten()
+        anchors_per_level = anchors_list[lvl]
+        anchors_per_level = anchors_per_level.flatten()
+
+        thresh = score_threshold if lvl < (num_level - 1) else 0.0
+        selected_indices = np.argwhere(scores_per_level > thresh)
+        scores = scores_per_level[selected_indices]
+        sorted_indices = np.argsort(-scores, axis=0, kind='mergesort')
+        if nms_top_k > -1 and nms_top_k < sorted_indices.shape[0]:
+            sorted_indices = sorted_indices[:nms_top_k]
+
+        for i in range(sorted_indices.shape[0]):
+            idx = selected_indices[sorted_indices[i]]
+            idx = idx[0][0]
+            a = int(idx / class_num)
+            c = int(idx % class_num)
+            box_offset = a * 4
+            anchor_box_width = anchors_per_level[
+                box_offset + 2] - anchors_per_level[box_offset] + 1
+            anchor_box_height = anchors_per_level[
+                box_offset + 3] - anchors_per_level[box_offset + 1] + 1
+            anchor_box_center_x = anchors_per_level[
+                box_offset] + anchor_box_width / 2
+            anchor_box_center_y = anchors_per_level[box_offset +
+                                                    1] + anchor_box_height / 2
+
+            target_box_center_x = bboxes_per_level[
+                box_offset] * anchor_box_width + anchor_box_center_x
+            target_box_center_y = bboxes_per_level[
+                box_offset + 1] * anchor_box_height + anchor_box_center_y
+            target_box_width = math.exp(bboxes_per_level[box_offset +
+                                                         2]) * anchor_box_width
+            target_box_height = math.exp(bboxes_per_level[
+                box_offset + 3]) * anchor_box_height
+
+            pred_box_xmin = target_box_center_x - target_box_width / 2
+            pred_box_ymin = target_box_center_y - target_box_height / 2
+            pred_box_xmax = target_box_center_x + target_box_width / 2 - 1
+            pred_box_ymax = target_box_center_y + target_box_height / 2 - 1
+
+            pred_box_xmin = pred_box_xmin / im_scale
+            pred_box_ymin = pred_box_ymin / im_scale
+            pred_box_xmax = pred_box_xmax / im_scale
+            pred_box_ymax = pred_box_ymax / im_scale
+
+            pred_box_xmin = max(
+                min(pred_box_xmin, np.round(im_width / im_scale) - 1), 0.)
+            pred_box_ymin = max(
+                min(pred_box_ymin, np.round(im_height / im_scale) - 1), 0.)
+            pred_box_xmax = max(
+                min(pred_box_xmax, np.round(im_width / im_scale) - 1), 0.)
+            pred_box_ymax = max(
+                min(pred_box_ymax, np.round(im_height / im_scale) - 1), 0.)
+
+            if c not in prediction:
+                prediction[c] = []
+            prediction[c].append([
+                pred_box_xmin, pred_box_ymin, pred_box_xmax, pred_box_ymax,
+                scores_per_level[idx]
+            ])
+
+    nmsed_outs, nmsed_num = multiclass_nms(prediction, class_num, keep_top_k,
+                                           nms_threshold)
+    return nmsed_outs, nmsed_num
+
+
+def batched_retinanet_detection_out(boxes, scores, anchors, im_info,
+                                    score_threshold, nms_threshold, nms_top_k,
+                                    keep_top_k):
+    batch_size = scores[0].shape[0]
+    det_outs = []
+    lod = []
+
+    for n in range(batch_size):
+        boxes_per_batch = []
+        scores_per_batch = []
+
+        num_level = len(scores)
+        for lvl in range(num_level):
+            boxes_per_batch.append(boxes[lvl][n])
+            scores_per_batch.append(scores[lvl][n])
+
+        nmsed_outs, nmsed_num = retinanet_detection_out(
+            boxes_per_batch, scores_per_batch, anchors, im_info[n],
+            score_threshold, nms_threshold, nms_top_k, keep_top_k)
+        lod.append(nmsed_num)
+        if nmsed_num == 0:
+            continue
+
+        det_outs.extend(nmsed_outs)
+    return det_outs, lod
+
+
+class TestRetinanetDetectionOutOp1(OpTest):
+    def set_argument(self):
+        self.score_threshold = 0.05
+        self.min_level = 3
+        self.max_level = 7
+        self.nms_threshold = 0.3
+        self.nms_top_k = 1000
+        self.keep_top_k = 200
+
+        self.scales_per_octave = 3
+        self.aspect_ratios = [1.0, 2.0, 0.5]
+        self.anchor_scale = 4
+        self.anchor_strides = [8, 16, 32, 64, 128]
+
+        self.box_size = 4
+        self.class_num = 80
+        self.batch_size = 1
+        self.input_channels = 20
+
+        self.layer_h = []
+        self.layer_w = []
+        num_levels = self.max_level - self.min_level + 1
+        for i in range(num_levels):
+            self.layer_h.append(2**(num_levels - i))
+            self.layer_w.append(2**(num_levels - i))
+
+    def init_test_input(self):
+        anchor_num = len(self.aspect_ratios) * self.scales_per_octave
+        num_levels = self.max_level - self.min_level + 1
+        self.scores_list = []
+        self.bboxes_list = []
+        self.anchors_list = []
+
+        for i in range(num_levels):
+            layer_h = self.layer_h[i]
+            layer_w = self.layer_w[i]
+
+            input_feat = np.random.random((self.batch_size,
+                                           self.input_channels, layer_h,
+                                           layer_w)).astype('float32')
+            score = np.random.random(
+                (self.batch_size, self.class_num * anchor_num, layer_h,
+                 layer_w)).astype('float32')
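+            # Scores are generated in [N, C * A, H, W] layout; transpose and
+            # reshape them to [N, H * W * A, C] to match the operator input.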
+            score = np.transpose(score, [0, 2, 3, 1])
+            score = score.reshape((self.batch_size, -1, self.class_num))
+            box = np.random.random((self.batch_size,
+                                    self.box_size * anchor_num, layer_h,
+                                    layer_w)).astype('float32')
+            box = np.transpose(box, [0, 2, 3, 1])
+            box = box.reshape((self.batch_size, -1, self.box_size))
+            anchor_sizes = []
+            for octave in range(self.scales_per_octave):
+                anchor_sizes.append(
+                    float(self.anchor_strides[i] * (2**octave)) /
+                    float(self.scales_per_octave) * self.anchor_scale)
+            anchor, var = anchor_generator_in_python(
+                input_feat=input_feat,
+                anchor_sizes=anchor_sizes,
+                aspect_ratios=self.aspect_ratios,
+                variances=[1.0, 1.0, 1.0, 1.0],
+                stride=[self.anchor_strides[i], self.anchor_strides[i]],
+                offset=0.5)
+            anchor = np.reshape(anchor, [-1, 4])
+            self.scores_list.append(score.astype('float32'))
+            self.bboxes_list.append(box.astype('float32'))
+            self.anchors_list.append(anchor.astype('float32'))
+
+        # im_height, im_width, scale
+        self.im_info = np.array([[256., 256., 1.5]]).astype('float32')
+
+    def setUp(self):
+        self.set_argument()
+        self.init_test_input()
+
+        nmsed_outs, lod = batched_retinanet_detection_out(
+            self.bboxes_list, self.scores_list, self.anchors_list,
+            self.im_info, self.score_threshold, self.nms_threshold,
+            self.nms_top_k, self.keep_top_k)
+        nmsed_outs = np.array(nmsed_outs).astype('float32')
+        self.op_type = 'retinanet_detection_output'
+        self.inputs = {
+            'BBoxes': [('b0', self.bboxes_list[0]), ('b1', self.bboxes_list[1]),
+                       ('b2', self.bboxes_list[2]), ('b3', self.bboxes_list[3]),
+                       ('b4', self.bboxes_list[4])],
+            'Scores': [('s0', self.scores_list[0]), ('s1', self.scores_list[1]),
+                       ('s2', self.scores_list[2]), ('s3', self.scores_list[3]),
+                       ('s4', self.scores_list[4])],
+            'Anchors':
+            [('a0', self.anchors_list[0]), ('a1', self.anchors_list[1]),
+             ('a2', self.anchors_list[2]), ('a3', self.anchors_list[3]),
+             ('a4', self.anchors_list[4])],
+            'ImInfo': (self.im_info, [[1, ]])
+        }
+        self.outputs = {'Out': (nmsed_outs, [lod])}
+        self.attrs = {
+            'score_threshold': self.score_threshold,
+            'nms_top_k': self.nms_top_k,
+            'nms_threshold': self.nms_threshold,
+            'keep_top_k': self.keep_top_k,
+            'nms_eta': 1.,
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestRetinanetDetectionOutOp2(TestRetinanetDetectionOutOp1):
+    def set_argument(self):
+        self.score_threshold = 0.05
+        self.min_level = 3
+        self.max_level = 7
+        self.nms_threshold = 0.3
+        self.nms_top_k = 1000
+        self.keep_top_k = 200
+
+        self.scales_per_octave = 3
+        self.aspect_ratios = [1.0, 2.0, 0.5]
+        self.anchor_scale = 4
+        self.anchor_strides = [8, 16, 32, 64, 128]
+
+        self.box_size = 4
+        self.class_num = 80
+        self.batch_size = 1
+        self.input_channels = 20
+        # Test the case where the spatial shapes of the FPN levels differ.
+        self.layer_h = [1, 4, 8, 8, 16]
+        self.layer_w = [1, 4, 8, 8, 16]
+
+
+class TestRetinanetDetectionOutOpNo3(TestRetinanetDetectionOutOp1):
+    def set_argument(self):
+        # Set score_threshold to 2.0 to test the case where there are no
+        # outputs.
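+        # Scores never exceed 1.0, so nothing survives the filter and the
+        # operator returns an empty output tensor.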
+        # In practical use, 0.0 < score_threshold < 1.0.
+        self.score_threshold = 2.0
+        self.min_level = 3
+        self.max_level = 7
+        self.nms_threshold = 0.3
+        self.nms_top_k = 1000
+        self.keep_top_k = 200
+
+        self.scales_per_octave = 3
+        self.aspect_ratios = [1.0, 2.0, 0.5]
+        self.anchor_scale = 4
+        self.anchor_strides = [8, 16, 32, 64, 128]
+
+        self.box_size = 4
+        self.class_num = 80
+        self.batch_size = 1
+        self.input_channels = 20
+
+        self.layer_h = []
+        self.layer_w = []
+        num_levels = self.max_level - self.min_level + 1
+        for i in range(num_levels):
+            self.layer_h.append(2**(num_levels - i))
+            self.layer_w.append(2**(num_levels - i))
+
+
+class TestRetinanetDetectionOutOpNo4(TestRetinanetDetectionOutOp1):
+    def set_argument(self):
+        self.score_threshold = 0.05
+        self.min_level = 2
+        self.max_level = 5
+        self.nms_threshold = 0.3
+        self.nms_top_k = 1000
+        self.keep_top_k = 200
+
+        self.scales_per_octave = 3
+        self.aspect_ratios = [1.0, 2.0, 0.5]
+        self.anchor_scale = 4
+        self.anchor_strides = [8, 16, 32, 64, 128]
+
+        self.box_size = 4
+        self.class_num = 80
+        self.batch_size = 1
+        self.input_channels = 20
+
+        self.layer_h = []
+        self.layer_w = []
+        num_levels = self.max_level - self.min_level + 1
+        for i in range(num_levels):
+            self.layer_h.append(2**(num_levels - i))
+            self.layer_w.append(2**(num_levels - i))
+
+    def setUp(self):
+        self.set_argument()
+        self.init_test_input()
+
+        nmsed_outs, lod = batched_retinanet_detection_out(
+            self.bboxes_list, self.scores_list, self.anchors_list,
+            self.im_info, self.score_threshold, self.nms_threshold,
+            self.nms_top_k, self.keep_top_k)
+        nmsed_outs = np.array(nmsed_outs).astype('float32')
+        self.op_type = 'retinanet_detection_output'
+        self.inputs = {
+            'BBoxes':
+            [('b0', self.bboxes_list[0]), ('b1', self.bboxes_list[1]),
+             ('b2', self.bboxes_list[2]), ('b3', self.bboxes_list[3])],
+            'Scores': [('s0', self.scores_list[0]), ('s1', self.scores_list[1]),
+                       ('s2', self.scores_list[2]),
+                       ('s3', self.scores_list[3])],
+            'Anchors':
+            [('a0', self.anchors_list[0]), ('a1', self.anchors_list[1]),
+             ('a2', self.anchors_list[2]), ('a3', self.anchors_list[3])],
+            'ImInfo': (self.im_info, [[1, ]])
+        }
+        self.outputs = {'Out': (nmsed_outs, [lod])}
+        self.attrs = {
+            'score_threshold': self.score_threshold,
+            'nms_top_k': self.nms_top_k,
+            'nms_threshold': self.nms_threshold,
+            'keep_top_k': self.keep_top_k,
+            'nms_eta': 1.,
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestRetinanetDetectionOutOpNo5(TestRetinanetDetectionOutOp1):
+    def set_argument(self):
+        self.score_threshold = 0.05
+        self.min_level = 3
+        self.max_level = 7
+        self.nms_threshold = 0.3
+        self.nms_top_k = 100
+        self.keep_top_k = 10
+
+        self.scales_per_octave = 3
+        self.aspect_ratios = [1.0, 2.0, 0.5]
+        self.anchor_scale = 4
+        self.anchor_strides = [8, 16, 32, 64, 128]
+
+        self.box_size = 4
+        self.class_num = 80
+        self.batch_size = 1
+        self.input_channels = 20
+
+        self.layer_h = []
+        self.layer_w = []
+        num_levels = self.max_level - self.min_level + 1
+        for i in range(num_levels):
+            self.layer_h.append(2**(num_levels - i))
+            self.layer_w.append(2**(num_levels - i))
+
+
+if __name__ == '__main__':
+    unittest.main()
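Note for reviewers: the box decoding that both the C++ kernel and the Python reference above implement can be sketched in a few lines of NumPy. This is an illustrative sketch only, not part of the patch; the helper name `decode_retinanet_box` and the sample numbers are invented for the example. It follows the kernel's un-normalized (+1 width) box convention and clips to the rescaled image bounds.

.. code-block:: python

    import numpy as np


    def decode_retinanet_box(anchor, delta, im_scale, im_height, im_width):
        # Anchor width/height use the kernel's un-normalized (+1) convention.
        axmin, aymin, axmax, aymax = anchor
        pw = axmax - axmin + 1
        ph = aymax - aymin + 1
        px = axmin + pw / 2
        py = aymin + ph / 2
        # delta is (dx, dy, dw, dh); anchor variances are 1.0 in this op.
        dx, dy, dw, dh = delta
        cx = dx * pw + px
        cy = dy * ph + py
        ow = np.exp(dw) * pw
        oh = np.exp(dh) * ph
        # Corner form, rescaled back to the original image.
        box = np.array(
            [cx - ow / 2, cy - oh / 2, cx + ow / 2 - 1, cy + oh / 2 - 1])
        box = box / im_scale
        # Clip to the rescaled image bounds, exactly as the kernel does.
        wmax = np.round(im_width / im_scale) - 1
        hmax = np.round(im_height / im_scale) - 1
        return np.clip(box, 0., np.array([wmax, hmax, wmax, hmax]))


    # A 31x31 anchor nudged right/down and enlarged by the deltas:
    print(decode_retinanet_box(
        np.array([32., 32., 62., 62.]), np.array([0.1, 0.1, 0.2, 0.2]),
        im_scale=1.5, im_height=256., im_width=256.))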