From 3305045c23969926273ed82028f61b72c031f2cb Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Wed, 19 Jun 2019 11:49:26 +0800 Subject: [PATCH] Cherry pick retinanet_target_assign_op(#17893), sigmoid_focal_loss_op(#17895) and retinanet_detection_output_op(#17896) for supporting retinanet (#18141) * test=release/1.5 Fix conflicts in test_layers.py when adding target assign operator for supporting retinanet. Cherry pick #17893 * test=release/1.5 Add sigmoid focal loss operator for supporting retinanet. Cherry pick #17895 * test=release/1.5 Add detection output operator for supporting retinanet. Cherry pick #17896 * test=release/1.5 fix wrong code style in test_layers.py when cherry pick retinanet_target_assign #17893 * test=release/1.5 Fix type error of std::pow in sigmoid_focal_loss. Cherry pick #17895 --- paddle/fluid/API.spec | 3 + .../fluid/operators/detection/CMakeLists.txt | 2 + .../retinanet_detection_output_op.cc | 566 ++++++++++++++++++ .../detection/rpn_target_assign_op.cc | 469 ++++++++++++++- .../detection/sigmoid_focal_loss_op.cc | 208 +++++++ .../detection/sigmoid_focal_loss_op.cu | 181 ++++++ .../detection/sigmoid_focal_loss_op.h | 128 ++++ python/paddle/fluid/layers/detection.py | 336 +++++++++++ .../fluid/tests/unittests/test_layers.py | 104 ++++ .../test_retinanet_detection_output.py | 412 +++++++++++++ .../unittests/test_rpn_target_assign_op.py | 159 +++++ .../unittests/test_sigmoid_focal_loss_op.py | 132 ++++ 12 files changed, 2691 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/operators/detection/retinanet_detection_output_op.cc create mode 100644 paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc create mode 100644 paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu create mode 100644 paddle/fluid/operators/detection/sigmoid_focal_loss_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_retinanet_detection_output.py create mode 100644 python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss_op.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 2640ed1815c..722422dcd4c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -348,6 +348,8 @@ paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box' paddle.fluid.layers.ssd_loss (ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)), ('document', '6d5028fd09d01ab82d296adc0ea95aee')) paddle.fluid.layers.detection_map (ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')), ('document', '1467d91b50c22cd52103b4aa1ee9d0a1')) paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '1e164a56fe9376e18a56d22563d9f801')) +paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595')) +paddle.fluid.layers.sigmoid_focal_loss (ArgSpec(args=['x', 'label', 'fg_num', 'gamma', 'alpha'], varargs=None, keywords=None, defaults=(2, 0.25)), ('document', 'aeac6aae100173b3fc7f102cf3023a3d')) paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a')) paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc')) paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'e87c1131e98715d3657a96c44db1b910')) @@ -360,6 +362,7 @@ paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gt_box', 'gt_label', 'ancho paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'f332fb8c5bb581bd1a6b5be450a99990')) paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '04384378ff00a42ade8fabd52e27cbc5')) paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0')) +paddle.fluid.layers.retinanet_detection_output (ArgSpec(args=['bboxes', 'scores', 'anchors', 'im_info', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0.05, 1000, 100, 0.3, 1.0)), ('document', '078d28607ce261a0cba2b965a79f6bb8')) paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d')) paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5')) paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace')) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 2d655c3e3fc..f1c504d6e4b 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -35,6 +35,8 @@ detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu) detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) +detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu) +detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc) if(WITH_GPU) detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub) diff --git a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc new file mode 100644 index 00000000000..4a6dfec12e6 --- /dev/null +++ b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc @@ -0,0 +1,566 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +class RetinanetDetectionOutputOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_GE( + ctx->Inputs("BBoxes").size(), 1UL, + "Input(BBoxes) of RetinanetDetectionOutput should not be null."); + PADDLE_ENFORCE_GE( + ctx->Inputs("Scores").size(), 1UL, + "Input(Scores) of RetinanetDetectionOutput should not be null."); + PADDLE_ENFORCE_GE( + ctx->Inputs("Anchors").size(), 1UL, + "Input(Anchors) of RetinanetDetectionOutput should not be null."); + PADDLE_ENFORCE_EQ( + ctx->Inputs("BBoxes").size(), ctx->Inputs("Scores").size(), + "Input tensors(BBoxes and Scores) should have the same size."); + PADDLE_ENFORCE_EQ( + ctx->Inputs("BBoxes").size(), ctx->Inputs("Anchors").size(), + "Input tensors(BBoxes and Anchors) should have the same size."); + PADDLE_ENFORCE( + ctx->HasInput("ImInfo"), + "Input(ImInfo) of RetinanetDetectionOutput should not be null"); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of RetinanetDetectionOutput should not be null."); + + auto bboxes_dims = ctx->GetInputsDim("BBoxes"); + auto scores_dims = ctx->GetInputsDim("Scores"); + auto anchors_dims = ctx->GetInputsDim("Anchors"); + auto im_info_dims = ctx->GetInputDim("ImInfo"); + + const size_t b_n = bboxes_dims.size(); + PADDLE_ENFORCE_GT(b_n, 0, "Input bbox tensors count should > 0."); + const size_t s_n = scores_dims.size(); + PADDLE_ENFORCE_GT(s_n, 0, "Input score tensors count should > 0."); + const size_t a_n = anchors_dims.size(); + PADDLE_ENFORCE_GT(a_n, 0, "Input anchor tensors count should > 0."); + + auto bbox_dims = bboxes_dims[0]; + auto score_dims = scores_dims[0]; + auto anchor_dims = anchors_dims[0]; + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(score_dims.size(), 3, + "The rank of Input(Scores) must be 3"); + PADDLE_ENFORCE_EQ(bbox_dims.size(), 3, + "The rank of Input(BBoxes) must be 3"); + PADDLE_ENFORCE_EQ(anchor_dims.size(), 2, + "The rank of Input(Anchors) must be 2"); + PADDLE_ENFORCE(bbox_dims[2] == 4, + "The last dimension of Input(BBoxes) must be 4, " + "represents the layout of coordinate " + "[xmin, ymin, xmax, ymax]"); + PADDLE_ENFORCE_EQ(bbox_dims[1], score_dims[1], + "The 2nd dimension of Input(BBoxes) must be equal to " + "2nd dimension of Input(Scores), which represents the " + "number of the predicted boxes."); + + PADDLE_ENFORCE_EQ(anchor_dims[0], bbox_dims[1], + "The 1st dimension of Input(Anchors) must be equal to " + "2nd dimension of Input(BBoxes), which represents the " + "number of the predicted boxes."); + PADDLE_ENFORCE_EQ(im_info_dims.size(), 2, + "The rank of Input(ImInfo) must be 2."); + } + // Here the box_dims[0] is not the real dimension of output. + // It will be rewritten in the computing kernel. + ctx->SetOutputDim("Out", {bbox_dims[1], bbox_dims[2] + 2}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + framework::GetDataTypeOfVar(ctx.MultiInputVar("Scores")[0]); + + return framework::OpKernelType(input_data_type, + platform::CPUPlace()); // ctx.GetPlace()); + } +}; + +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +bool SortScoreTwoPairDescend(const std::pair>& pair1, + const std::pair>& pair2) { + return pair1.first > pair2.first; +} + +template +static inline void GetMaxScoreIndex( + const std::vector& scores, const T threshold, int top_k, + std::vector>* sorted_indices) { + for (size_t i = 0; i < scores.size(); ++i) { + if (scores[i] > threshold) { + sorted_indices->push_back(std::make_pair(scores[i], i)); + } + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices->begin(), sorted_indices->end(), + SortScorePairDescend); + // Keep top_k scores if needed. + if (top_k > -1 && top_k < static_cast(sorted_indices->size())) { + sorted_indices->resize(top_k); + } +} + +template +static inline T BBoxArea(const std::vector& box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static inline T JaccardOverlap(const std::vector& box1, + const std::vector& box2, + const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + T norm = normalized ? static_cast(0.) : static_cast(1.); + T inter_w = inter_xmax - inter_xmin + norm; + T inter_h = inter_ymax - inter_ymin + norm; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +class RetinanetDetectionOutputKernel : public framework::OpKernel { + public: + void NMSFast(const std::vector>& cls_dets, + const T nms_threshold, const T eta, + std::vector* selected_indices) const { + int64_t num_boxes = cls_dets.size(); + std::vector> sorted_indices; + for (int64_t i = 0; i < num_boxes; ++i) { + sorted_indices.push_back(std::make_pair(cls_dets[i][4], i)); + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices.begin(), sorted_indices.end(), + SortScorePairDescend); + selected_indices->clear(); + T adaptive_threshold = nms_threshold; + + while (sorted_indices.size() != 0) { + const int idx = sorted_indices.front().second; + bool keep = true; + for (size_t k = 0; k < selected_indices->size(); ++k) { + if (keep) { + const int kept_idx = (*selected_indices)[k]; + T overlap = T(0.); + + overlap = JaccardOverlap(cls_dets[idx], cls_dets[kept_idx], false); + keep = overlap <= adaptive_threshold; + } else { + break; + } + } + if (keep) { + selected_indices->push_back(idx); + } + sorted_indices.erase(sorted_indices.begin()); + if (keep && eta < 1 && adaptive_threshold > 0.5) { + adaptive_threshold *= eta; + } + } + } + + void DeltaScoreToPrediction( + const std::vector& bboxes_data, const std::vector& anchors_data, + T im_height, T im_width, T im_scale, int class_num, + const std::vector>& sorted_indices, + std::map>>* preds) const { + im_height = static_cast(round(im_height / im_scale)); + im_width = static_cast(round(im_width / im_scale)); + T zero(0); + int i = 0; + for (const auto& it : sorted_indices) { + T score = it.first; + int idx = it.second; + int a = idx / class_num; + int c = idx % class_num; + + int box_offset = a * 4; + T anchor_box_width = + anchors_data[box_offset + 2] - anchors_data[box_offset] + 1; + T anchor_box_height = + anchors_data[box_offset + 3] - anchors_data[box_offset + 1] + 1; + T anchor_box_center_x = anchors_data[box_offset] + anchor_box_width / 2; + T anchor_box_center_y = + anchors_data[box_offset + 1] + anchor_box_height / 2; + T target_box_center_x = 0, target_box_center_y = 0; + T target_box_width = 0, target_box_height = 0; + target_box_center_x = + bboxes_data[box_offset] * anchor_box_width + anchor_box_center_x; + target_box_center_y = + bboxes_data[box_offset + 1] * anchor_box_height + anchor_box_center_y; + target_box_width = + std::exp(bboxes_data[box_offset + 2]) * anchor_box_width; + target_box_height = + std::exp(bboxes_data[box_offset + 3]) * anchor_box_height; + T pred_box_xmin = target_box_center_x - target_box_width / 2; + T pred_box_ymin = target_box_center_y - target_box_height / 2; + T pred_box_xmax = target_box_center_x + target_box_width / 2 - 1; + T pred_box_ymax = target_box_center_y + target_box_height / 2 - 1; + pred_box_xmin = pred_box_xmin / im_scale; + pred_box_ymin = pred_box_ymin / im_scale; + pred_box_xmax = pred_box_xmax / im_scale; + pred_box_ymax = pred_box_ymax / im_scale; + + pred_box_xmin = std::max(std::min(pred_box_xmin, im_width - 1), zero); + pred_box_ymin = std::max(std::min(pred_box_ymin, im_height - 1), zero); + pred_box_xmax = std::max(std::min(pred_box_xmax, im_width - 1), zero); + pred_box_ymax = std::max(std::min(pred_box_ymax, im_height - 1), zero); + + std::vector one_pred; + one_pred.push_back(pred_box_xmin); + one_pred.push_back(pred_box_ymin); + one_pred.push_back(pred_box_xmax); + one_pred.push_back(pred_box_ymax); + one_pred.push_back(score); + (*preds)[c].push_back(one_pred); + i++; + } + } + + void MultiClassNMS(const std::map>>& preds, + int class_num, const int keep_top_k, const T nms_threshold, + const T nms_eta, std::vector>* nmsed_out, + int* num_nmsed_out) const { + std::map> indices; + int num_det = 0; + for (int c = 0; c < class_num; ++c) { + if (static_cast(preds.count(c))) { + const std::vector> cls_dets = preds.at(c); + NMSFast(cls_dets, nms_threshold, nms_eta, &(indices[c])); + num_det += indices[c].size(); + } + } + + std::vector>> score_index_pairs; + for (const auto& it : indices) { + int label = it.first; + const std::vector& label_indices = it.second; + for (size_t j = 0; j < label_indices.size(); ++j) { + int idx = label_indices[j]; + score_index_pairs.push_back(std::make_pair(preds.at(label)[idx][4], + std::make_pair(label, idx))); + } + } + // Keep top k results per image. + std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(), + SortScoreTwoPairDescend); + if (num_det > keep_top_k) { + score_index_pairs.resize(keep_top_k); + } + + // Store the new indices. + std::map> new_indices; + for (const auto& it : score_index_pairs) { + int label = it.second.first; + int idx = it.second.second; + std::vector one_pred; + one_pred.push_back(label); + one_pred.push_back(preds.at(label)[idx][4]); + one_pred.push_back(preds.at(label)[idx][0]); + one_pred.push_back(preds.at(label)[idx][1]); + one_pred.push_back(preds.at(label)[idx][2]); + one_pred.push_back(preds.at(label)[idx][3]); + nmsed_out->push_back(one_pred); + } + + *num_nmsed_out = (num_det > keep_top_k ? keep_top_k : num_det); + } + + void RetinanetDetectionOutput(const framework::ExecutionContext& ctx, + const std::vector& scores, + const std::vector& bboxes, + const std::vector& anchors, + const Tensor& im_info, + std::vector>* nmsed_out, + int* num_nmsed_out) const { + int64_t nms_top_k = ctx.Attr("nms_top_k"); + int64_t keep_top_k = ctx.Attr("keep_top_k"); + T nms_threshold = static_cast(ctx.Attr("nms_threshold")); + T nms_eta = static_cast(ctx.Attr("nms_eta")); + T score_threshold = static_cast(ctx.Attr("score_threshold")); + + int64_t class_num = scores[0].dims()[1]; + std::map>> preds; + for (size_t l = 0; l < scores.size(); ++l) { + // Fetch per level score + Tensor scores_per_level = scores[l]; + // Fetch per level bbox + Tensor bboxes_per_level = bboxes[l]; + // Fetch per level anchor + Tensor anchors_per_level = anchors[l]; + + int64_t scores_num = scores_per_level.numel(); + int64_t bboxes_num = bboxes_per_level.numel(); + std::vector scores_data(scores_num); + std::vector bboxes_data(bboxes_num); + std::vector anchors_data(bboxes_num); + std::copy_n(scores_per_level.data(), scores_num, scores_data.begin()); + std::copy_n(bboxes_per_level.data(), bboxes_num, bboxes_data.begin()); + std::copy_n(anchors_per_level.data(), bboxes_num, + anchors_data.begin()); + std::vector> sorted_indices; + + // For the highest level, we take the threshold 0.0 + T threshold = (l < (scores.size() - 1) ? score_threshold : 0.0); + GetMaxScoreIndex(scores_data, threshold, nms_top_k, &sorted_indices); + auto* im_info_data = im_info.data(); + auto im_height = im_info_data[0]; + auto im_width = im_info_data[1]; + auto im_scale = im_info_data[2]; + DeltaScoreToPrediction(bboxes_data, anchors_data, im_height, im_width, + im_scale, class_num, sorted_indices, &preds); + } + + MultiClassNMS(preds, class_num, keep_top_k, nms_threshold, nms_eta, + nmsed_out, num_nmsed_out); + } + + void MultiClassOutput(const platform::DeviceContext& ctx, + const std::vector>& nmsed_out, + Tensor* outs) const { + auto* odata = outs->data(); + int count = 0; + int64_t out_dim = 6; + for (size_t i = 0; i < nmsed_out.size(); ++i) { + odata[count * out_dim] = nmsed_out[i][0] + 1; // label + odata[count * out_dim + 1] = nmsed_out[i][1]; // score + odata[count * out_dim + 2] = nmsed_out[i][2]; // xmin + odata[count * out_dim + 3] = nmsed_out[i][3]; // xmin + odata[count * out_dim + 4] = nmsed_out[i][4]; // xmin + odata[count * out_dim + 5] = nmsed_out[i][5]; // xmin + count++; + } + } + + void Compute(const framework::ExecutionContext& ctx) const override { + auto boxes = ctx.MultiInput("BBoxes"); + auto scores = ctx.MultiInput("Scores"); + auto anchors = ctx.MultiInput("Anchors"); + auto* im_info = ctx.Input("ImInfo"); + auto* outs = ctx.Output("Out"); + + std::vector boxes_list(boxes.size()); + std::vector scores_list(scores.size()); + std::vector anchors_list(anchors.size()); + for (size_t j = 0; j < boxes_list.size(); ++j) { + boxes_list[j] = *boxes[j]; + scores_list[j] = *scores[j]; + anchors_list[j] = *anchors[j]; + } + auto score_dims = scores_list[0].dims(); + int64_t batch_size = score_dims[0]; + auto box_dims = boxes_list[0].dims(); + int64_t box_dim = box_dims[2]; + int64_t out_dim = box_dim + 2; + + auto& dev_ctx = ctx.template device_context(); + + std::vector>> all_nmsed_out; + std::vector batch_starts = {0}; + for (int i = 0; i < batch_size; ++i) { + int num_nmsed_out = 0; + std::vector box_per_batch_list(boxes_list.size()); + std::vector score_per_batch_list(scores_list.size()); + for (size_t j = 0; j < boxes_list.size(); ++j) { + auto score_dims = scores_list[j].dims(); + score_per_batch_list[j] = scores_list[j].Slice(i, i + 1); + score_per_batch_list[j].Resize({score_dims[1], score_dims[2]}); + box_per_batch_list[j] = boxes_list[j].Slice(i, i + 1); + box_per_batch_list[j].Resize({score_dims[1], box_dim}); + } + Tensor im_info_slice = im_info->Slice(i, i + 1); + + std::vector> nmsed_out; + RetinanetDetectionOutput(ctx, score_per_batch_list, box_per_batch_list, + anchors_list, im_info_slice, &nmsed_out, + &num_nmsed_out); + all_nmsed_out.push_back(nmsed_out); + batch_starts.push_back(batch_starts.back() + num_nmsed_out); + } + + int num_kept = batch_starts.back(); + if (num_kept == 0) { + outs->Resize({0, out_dim}); + } else { + outs->mutable_data({num_kept, out_dim}, ctx.GetPlace()); + for (int i = 0; i < batch_size; ++i) { + int64_t s = batch_starts[i]; + int64_t e = batch_starts[i + 1]; + if (e > s) { + Tensor out = outs->Slice(s, e); + MultiClassOutput(dev_ctx, all_nmsed_out[i], &out); + } + } + } + + framework::LoD lod; + lod.emplace_back(batch_starts); + + outs->set_lod(lod); + } +}; + +class RetinanetDetectionOutputOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("BBoxes", + "(List) A list of tensors from multiple FPN levels. Each " + "element is a 3-D Tensor with shape [N, Mi, 4] represents the " + "predicted locations of Mi bounding boxes, N is the batch size. " + "Mi is the number of bounding boxes from i-th FPN level. Each " + "bounding box has four coordinate values and the layout is " + "[xmin, ymin, xmax, ymax].") + .AsDuplicable(); + AddInput("Scores", + "(List) A list of tensors from multiple FPN levels. Each " + "element is a 3-D Tensor with shape [N, Mi, C] represents the " + "predicted confidence from its FPN level. N is the batch size, " + "C is the class number (excluding background), Mi is the number " + "of bounding boxes from i-th FPN level. For each bounding box, " + "there are total C scores.") + .AsDuplicable(); + AddInput("Anchors", + "(List) A list of tensors from multiple FPN levels. Each" + "element is a 2-D Tensor with shape [Mi, 4] represents the " + "locations of Mi anchor boxes from i-th FPN level. Each " + "bounding box has four coordinate values and the layout is " + "[xmin, ymin, xmax, ymax].") + .AsDuplicable(); + AddInput("ImInfo", + "(LoDTensor) A 2-D LoDTensor with shape [N, 3] represents the " + "image information. N is the batch size, each image information " + "includes height, width and scale."); + AddAttr("score_threshold", + "(float) " + "Threshold to filter out bounding boxes with a confidence " + "score."); + AddAttr("nms_top_k", + "(int64_t) " + "Maximum number of detections per FPN layer to be kept " + "according to the confidence before NMS."); + AddAttr("nms_threshold", + "(float) " + "The threshold to be used in NMS."); + AddAttr("nms_eta", + "(float) " + "The parameter for adaptive NMS."); + AddAttr( + "keep_top_k", + "(int64_t) " + "Number of total bounding boxes to be kept per image after NMS " + "step."); + AddOutput("Out", + "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " + "detections. Each row has 6 values: " + "[label, confidence, xmin, ymin, xmax, ymax]" + "No is the total number of detections in this mini-batch." + "For each instance, " + "the offsets in first dimension are called LoD, the number of " + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " + "no detected bbox."); + AddComment(R"DOC( +This operator is to decode boxes and scores from each FPN layer and do +multi-class non maximum suppression (NMS) on merged predictions. + +Top-scoring predictions per FPN layer are decoded with the anchor +information. This operator greedily selects a subset of detection bounding +boxes from each FPN layer that have high scores larger than score_threshold, +if providing this threshold, then selects the largest nms_top_k confidences +scores per FPN layer, if nms_top_k is larger than -1. +The decoding schema is described below: + +ox = (pw * pxv * tx * + px) - tw / 2 + +oy = (ph * pyv * ty * + py) - th / 2 + +ow = exp(pwv * tw) * pw + tw / 2 + +oh = exp(phv * th) * ph + th / 2 + +where `tx`, `ty`, `tw`, `th` denote the predicted box's center coordinates, width +and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the +anchor's center coordinates, width and height. `pxv`, `pyv`, `pwv`, +`phv` denote the variance of the anchor box and `ox`, `oy`, `ow`, `oh` denote the +decoded coordinates, width and height. + +Then the top decoded prediction from all levels are merged followed by NMS. +In the NMS step, this operator prunes away boxes that have high IOU +(intersection over union) overlap with already selected boxes by adaptive +threshold NMS based on parameters of nms_threshold and nms_eta. +After NMS step, at most keep_top_k number of total bounding boxes are to be kept +per image if keep_top_k is larger than -1. +This operator support multi-class and batched inputs. It applying NMS +independently for each class. The outputs is a 2-D LoDTenosr, for each +image, the offsets in first dimension of LoDTensor are called LoD, the number +of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0, +means there is no detected bounding box for this image. If there is no detected boxes +for all images, all the elements in LoD are set to 0, and the output tensor is +empty (None). +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(retinanet_detection_output, ops::RetinanetDetectionOutputOp, + ops::RetinanetDetectionOutputOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(retinanet_detection_output, + ops::RetinanetDetectionOutputKernel, + ops::RetinanetDetectionOutputKernel); diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 0b8053e8d03..338954346c5 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -202,21 +202,32 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, } // Reservoir Sampling - int fg_num = static_cast(rpn_fg_fraction * rpn_batch_size_per_im); - ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random); + int fg_num = 0; + if (rpn_fg_fraction > 0 && rpn_batch_size_per_im > 0) { + fg_num = static_cast(rpn_fg_fraction * rpn_batch_size_per_im); + ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random); + } else { + fg_num = static_cast(fg_inds_fake.size()); + } int fg_fake_num = static_cast(fg_inds_fake.size()); for (int64_t i = 0; i < fg_fake_num; ++i) { target_label[fg_inds_fake[i]] = 1; } - int bg_num = rpn_batch_size_per_im - fg_fake_num; for (int64_t i = 0; i < anchor_num; ++i) { if (anchor_to_gt_max_data[i] < rpn_negative_overlap) { bg_inds_fake.push_back(i); } } - ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random); - bg_num = static_cast(bg_inds_fake.size()); + int bg_num = 0; + if (rpn_fg_fraction > 0 && rpn_batch_size_per_im > 0) { + bg_num = rpn_batch_size_per_im - fg_fake_num; + ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random); + bg_num = static_cast(bg_inds_fake.size()); + } else { + bg_num = static_cast(bg_inds_fake.size()); + } + int fake_num = 0; for (int64_t i = 0; i < bg_num; ++i) { // fg fake found @@ -492,9 +503,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Anchor", "(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4]."); AddInput("GtBoxes", - "(LoDTensor) input groud-truth bbox with shape [K, 4]."); + "(LoDTensor) input ground-truth bbox with shape [K, 4]."); AddInput("IsCrowd", - "(LoDTensor) input which indicates groud-truth is crowd."); + "(LoDTensor) input which indicates ground-truth is crowd."); AddInput("ImInfo", "(LoDTensor) input image information with shape [N, 3]. " "N is the batch size, each image information includes height, " @@ -536,7 +547,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { "ScoreIndex", "(Tensor), The indexes of foreground and background anchors in all " "RPN anchors(The rest anchors are ignored). The shape of the " - "ScoreIndex is [F + B], F and B are sampled foreground and backgroud " + "ScoreIndex is [F + B], F and B are sampled foreground and background " " number."); AddOutput("TargetBBox", "(Tensor), The target bbox deltas with shape " @@ -544,7 +555,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput( "TargetLabel", "(Tensor), The target labels of each anchor with shape " - "[F + B, 1], F and B are sampled foreground and backgroud number."); + "[F + B, 1], F and B are sampled foreground and background number."); AddOutput("BBoxInsideWeight", "(Tensor), The bbox inside weight with shape " "[F, 4], F is the sampled foreground number."); @@ -573,6 +584,440 @@ negative do not contribute to the training objective. } }; +class RetinanetTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Anchor", + "(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4]."); + AddInput("GtBoxes", + "(LoDTensor) input ground-truth bbox with shape [K, 4]."); + AddInput("GtLabels", + "(LoDTensor) input ground-truth label with shape [K, 1]."); + AddInput("IsCrowd", + "(LoDTensor) input which indicates ground-truth is crowd."); + AddInput("ImInfo", + "(LoDTensor) input image information with shape [N, 3]. " + "N is the batch size, each image information includes height, " + "width and scale."); + AddAttr( + "positive_overlap", + "Minimum overlap required between an anchor and ground-truth " + "box for the (anchor, gt box) pair to be a positive example.") + .SetDefault(0.5); + AddAttr( + "negative_overlap", + "Maximum overlap allowed between an anchor and ground-truth " + "box for the (anchor, gt box) pair to be a negative examples.") + .SetDefault(0.4); + AddOutput( + "LocationIndex", + "(Tensor), The indexes of foreground anchors in all anchors, the " + "shape of the LocationIndex is [F], F depends on the value of input " + "tensor and attributes."); + AddOutput( + "ScoreIndex", + "(Tensor), The indexes of foreground and background anchors in all " + "RPN anchors(The rest anchors are ignored). The shape of the " + "ScoreIndex is [F + B], F and B are foreground and background " + " number."); + AddOutput("TargetBBox", + "(Tensor), The target bbox deltas with shape " + "[F, 4], F is the foreground number."); + AddOutput("TargetLabel", + "(Tensor), The target labels of each anchor with shape " + "[F + B, 1], F and B are foreground and background number."); + AddOutput("BBoxInsideWeight", + "(Tensor), The bbox inside weight with shape " + "[F, 4], F is the foreground number."); + AddOutput("ForegroundNumber", + "(Tensor), The foreground number. " + "[1, 1]."); + AddComment(R"DOC( + This layer can be, for given the Intersection-over-Union (IoU) overlap + between anchors and ground truth boxes, to assign classification and + regression targets to each anchor, these target labels are used for + train retinanet. + + Every anchor is assigned with a length C one-hot vector of + classification targets, and a 4-vector of box regression targets, + where C is the class number. The assignment rules are as followed: + + 1. Anchors are assigned to ground-truth boxes when: (i) it has the highest + IoU overlap with a ground-truth box, or (ii) it has an IoU overlap higher + than positive_overlap(0.5) with any ground-truth box. + + 2. Anchors are assigned to background when its IoU ratio is lower than + negative_overlap (0.4) for all ground-truth boxes. + + When an anchor is assigned with a ground-truth box which is the i-th category, + the i-th entry in its C vector of targets is set to 1 and all other entries + are set to 0. When an anchor is assigned with background, all entries are set + to 0. Anchors that are not assigned do not contribute to the training + objective. The regression targets are the encoded ground-truth boxes + associated with the assigned anchors. + +)DOC"); + } +}; + +class RetinanetTargetAssignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("Anchor"), + "Input(Anchor) of RetinanetTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasInput("GtBoxes"), + "Input(GtBoxes) of RetinanetTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasInput("GtLabels"), + "Input(GtLabels) of RetinanetTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasInput("IsCrowd"), + "Input(Anchor) of RetinanetTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasInput("ImInfo"), + "Input(ImInfo) of RetinanetTargetAssignOp should not be null"); + + PADDLE_ENFORCE( + ctx->HasOutput("LocationIndex"), + "Output(LocationIndex) of RetinanetTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasOutput("ScoreIndex"), + "Output(ScoreIndex) of RetinanetTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasOutput("TargetLabel"), + "Output(TargetLabel) of RetinanetTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasOutput("TargetBBox"), + "Output(TargetBBox) of RetinanetTargetAssignOp should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("BBoxInsideWeight"), + "Output(BBoxInsideWeight) of RetinanetTargetAssignOp should " + "not be null"); + PADDLE_ENFORCE(ctx->HasOutput("ForegroundNumber"), + "Output(ForegroundNumber) of RetinanetTargetAssignOp should " + "not be null"); + + auto anchor_dims = ctx->GetInputDim("Anchor"); + auto gt_boxes_dims = ctx->GetInputDim("GtBoxes"); + auto gt_labels_dims = ctx->GetInputDim("GtLabels"); + auto im_info_dims = ctx->GetInputDim("ImInfo"); + + PADDLE_ENFORCE_EQ(anchor_dims.size(), 2, + "The rank of Input(Anchor) must be 2."); + PADDLE_ENFORCE_EQ(gt_boxes_dims.size(), 2, + "The rank of Input(GtBoxes) must be 2."); + PADDLE_ENFORCE_EQ(gt_labels_dims.size(), 2, + "The rank of Input(GtLabels) must be 2."); + PADDLE_ENFORCE_EQ(im_info_dims.size(), 2, + "The rank of Input(ImInfo) must be 2."); + + ctx->SetOutputDim("LocationIndex", {gt_labels_dims[0]}); + ctx->SetOutputDim("ScoreIndex", {gt_labels_dims[0]}); + ctx->SetOutputDim("TargetBBox", {gt_labels_dims[0], 4}); + ctx->SetOutputDim("TargetLabel", {gt_labels_dims[0], 1}); + ctx->SetOutputDim("BBoxInsideWeight", {gt_labels_dims[0], 4}); + ctx->SetOutputDim("ForegroundNumber", {gt_labels_dims[0], 1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input("Anchor")->type(), + platform::CPUPlace()); + } +}; + +template +std::vector FilterCrowdGtBoxLabel( + const platform::CPUDeviceContext& context, Tensor* gt_boxes, + Tensor* gt_labels, Tensor* is_crowd) { + int gt_num = gt_boxes->dims()[0]; + std::vector not_crowd_inds; + auto* is_crowd_data = is_crowd->data(); + for (int i = 0; i < gt_num; ++i) { + if (is_crowd_data[i] == 0) { + not_crowd_inds.emplace_back(i); + } + } + int ncrowd_num = not_crowd_inds.size(); + Tensor ncrowd_gt_boxes, ncrowd_gt_labels; + T* ncrowd_gt_boxes_data = + ncrowd_gt_boxes.mutable_data({ncrowd_num, 4}, context.GetPlace()); + int* ncrowd_gt_labels_data = + ncrowd_gt_labels.mutable_data({ncrowd_num, 1}, context.GetPlace()); + Gather(gt_boxes->data(), 4, not_crowd_inds.data(), ncrowd_num, + ncrowd_gt_boxes_data); + Gather(gt_labels->data(), 1, not_crowd_inds.data(), ncrowd_num, + ncrowd_gt_labels_data); + std::vector res; + res.emplace_back(ncrowd_gt_boxes); + res.emplace_back(ncrowd_gt_labels); + return res; +} + +template +std::vector GetAllFgBgGt(const platform::CPUDeviceContext& ctx, + const Tensor& anchor_by_gt_overlap, + const Tensor& ncrowd_gt_labels, + const float positive_overlap, + const float negative_overlap, + std::minstd_rand engine) { + auto* anchor_by_gt_overlap_data = anchor_by_gt_overlap.data(); + int anchor_num = anchor_by_gt_overlap.dims()[0]; + int gt_num = anchor_by_gt_overlap.dims()[1]; + + std::vector fg_inds; + std::vector bg_inds; + std::vector gt_inds; + std::vector tgt_lbl; + std::vector fg_fake; + std::vector bbox_inside_weight; + // Calculate the max IoU between anchors and gt boxes + // Map from anchor to gt box that has highest overlap + auto place = ctx.GetPlace(); + Tensor anchor_to_gt_max, anchor_to_gt_argmax, gt_to_anchor_max; + anchor_to_gt_max.mutable_data({anchor_num}, place); + int* argmax = anchor_to_gt_argmax.mutable_data({anchor_num}, place); + gt_to_anchor_max.mutable_data({gt_num}, place); + + auto anchor_by_gt_overlap_et = + framework::EigenMatrix::From(anchor_by_gt_overlap); + auto anchor_to_gt_max_et = + framework::EigenVector::Flatten(anchor_to_gt_max); + auto gt_to_anchor_max_et = + framework::EigenVector::Flatten(gt_to_anchor_max); + auto anchor_to_gt_argmax_et = + framework::EigenVector::Flatten(anchor_to_gt_argmax); + anchor_to_gt_max_et = + anchor_by_gt_overlap_et.maximum(Eigen::DSizes(1)); + anchor_to_gt_argmax_et = + anchor_by_gt_overlap_et.argmax(1).template cast(); + gt_to_anchor_max_et = + anchor_by_gt_overlap_et.maximum(Eigen::DSizes(0)); + + ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max, -1, + -1, positive_overlap, negative_overlap, &fg_inds, &bg_inds, + &tgt_lbl, &fg_fake, &bbox_inside_weight, engine, false); + const int* gt_labels_data = ncrowd_gt_labels.data(); + int64_t fg_num = fg_inds.size(); + for (int64_t i = 0; i < fg_num; ++i) { + int gt_idx = argmax[fg_inds[i]]; + tgt_lbl[i] = gt_labels_data[gt_idx]; + } + + int bg_num = bg_inds.size(); + int fg_fake_num = fg_fake.size(); + gt_inds.reserve(fg_fake_num); + for (int i = 0; i < fg_fake_num; ++i) { + gt_inds.emplace_back(argmax[fg_fake[i]]); + } + + Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t; + Tensor fg_num_t; + int* loc_index_data = loc_index_t.mutable_data({fg_fake_num}, place); + int* score_index_data = + score_index_t.mutable_data({fg_num + bg_num}, place); + int* tgt_lbl_data = tgt_lbl_t.mutable_data({fg_num + bg_num}, place); + int* gt_inds_data = gt_inds_t.mutable_data({fg_fake_num}, place); + int* fg_num_data = fg_num_t.mutable_data({1}, place); + T* bbox_inside_weight_data = + bbox_inside_weight_t.mutable_data({fg_fake_num, 4}, place); + std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data); + std::copy(fg_inds.begin(), fg_inds.end(), score_index_data); + std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num); + std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data); + std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data); + std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(), + bbox_inside_weight_data); + fg_num_data[0] = fg_fake.size() + 1; + std::vector loc_score_tgtlbl_gt; + loc_score_tgtlbl_gt.emplace_back(loc_index_t); + loc_score_tgtlbl_gt.emplace_back(score_index_t); + loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t); + loc_score_tgtlbl_gt.emplace_back(gt_inds_t); + loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t); + loc_score_tgtlbl_gt.emplace_back(fg_num_t); + + return loc_score_tgtlbl_gt; +} + +template +class RetinanetTargetAssignKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* anchor = context.Input("Anchor"); // (H*W*A) * 4 + auto* gt_boxes = context.Input("GtBoxes"); + auto* gt_labels = context.Input("GtLabels"); + auto* is_crowd = context.Input("IsCrowd"); + auto* im_info = context.Input("ImInfo"); + + auto* loc_index = context.Output("LocationIndex"); + auto* score_index = context.Output("ScoreIndex"); + auto* tgt_bbox = context.Output("TargetBBox"); + auto* tgt_lbl = context.Output("TargetLabel"); + auto* bbox_inside_weight = context.Output("BBoxInsideWeight"); + auto* fg_num = context.Output("ForegroundNumber"); + + PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL, + "RetinanetTargetAssignOp gt_boxes needs 1 level of LoD"); + PADDLE_ENFORCE_EQ(gt_labels->lod().size(), 1UL, + "RetinanetTargetAssignOp gt_boxes needs 1 level of LoD"); + PADDLE_ENFORCE_EQ(is_crowd->lod().size(), 1UL, + "RetinanetTargetAssignOp is_crowd needs 1 level of LoD"); + + int64_t anchor_num = static_cast(anchor->dims()[0]); + int64_t batch_num = static_cast(gt_boxes->lod().back().size() - 1); + + float positive_overlap = context.Attr("positive_overlap"); + float negative_overlap = context.Attr("negative_overlap"); + + int64_t max_num = batch_num * anchor_num; + auto place = context.GetPlace(); + + loc_index->mutable_data({max_num}, place); + score_index->mutable_data({max_num}, place); + tgt_bbox->mutable_data({max_num, 4}, place); + tgt_lbl->mutable_data({max_num, 1}, place); + bbox_inside_weight->mutable_data({max_num, 4}, place); + fg_num->mutable_data({batch_num, 1}, place); + auto& dev_ctx = context.device_context(); + + std::random_device rnd; + std::minstd_rand engine; + int seed = rnd(); + engine.seed(seed); + + framework::LoD lod_loc, loc_score, lod_fg; + std::vector lod0_loc(1, 0); + std::vector lod0_score(1, 0); + std::vector lod0_fg(1, 0); + + int total_loc_num = 0; + int total_score_num = 0; + int total_fg_num = 0; + auto gt_boxes_lod = gt_boxes->lod().back(); + auto gt_labels_lod = gt_labels->lod().back(); + auto is_crowd_lod = is_crowd->lod().back(); + for (int i = 0; i < batch_num; ++i) { + Tensor gt_boxes_slice = + gt_boxes->Slice(gt_boxes_lod[i], gt_boxes_lod[i + 1]); + Tensor gt_labels_slice = + gt_labels->Slice(gt_labels_lod[i], gt_labels_lod[i + 1]); + Tensor is_crowd_slice = + is_crowd->Slice(is_crowd_lod[i], is_crowd_lod[i + 1]); + Tensor im_info_slice = im_info->Slice(i, i + 1); + auto* im_info_data = im_info_slice.data(); + auto im_height = im_info_data[0]; + auto im_width = im_info_data[1]; + auto im_scale = im_info_data[2]; + + // Filter straddle anchor + std::vector filter_output = + FilterStraddleAnchor(dev_ctx, anchor, -1, im_height, im_width); + Tensor inds_inside = filter_output[0]; + Tensor inside_anchor = filter_output[1]; + + // Filter crowd gt + std::vector ncrowd_output = FilterCrowdGtBoxLabel( + dev_ctx, >_boxes_slice, >_labels_slice, &is_crowd_slice); + Tensor ncrowd_gt_boxes = ncrowd_output[0]; + Tensor ncrowd_gt_labels = ncrowd_output[1]; + + auto ncrowd_gt_boxes_et = + framework::EigenTensor::From(ncrowd_gt_boxes); + ncrowd_gt_boxes_et = ncrowd_gt_boxes_et * im_scale; + + Tensor anchor_by_gt_overlap; + anchor_by_gt_overlap.mutable_data( + {inside_anchor.dims()[0], ncrowd_gt_boxes.dims()[0]}, place); + BboxOverlaps(inside_anchor, ncrowd_gt_boxes, &anchor_by_gt_overlap); + + auto loc_score_tgtlbl_gt = + GetAllFgBgGt(dev_ctx, anchor_by_gt_overlap, ncrowd_gt_labels, + positive_overlap, negative_overlap, engine); + + Tensor sampled_loc_index = loc_score_tgtlbl_gt[0]; + Tensor sampled_score_index = loc_score_tgtlbl_gt[1]; + Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2]; + Tensor sampled_gt_index = loc_score_tgtlbl_gt[3]; + Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4]; + Tensor sampled_fg_num = loc_score_tgtlbl_gt[5]; + + int loc_num = sampled_loc_index.dims()[0]; + int score_num = sampled_score_index.dims()[0]; + // unmap to all anchor + Tensor sampled_loc_index_unmap, sampled_score_index_unmap; + sampled_loc_index_unmap.mutable_data({loc_num}, place); + sampled_score_index_unmap.mutable_data({score_num}, place); + Gather(inds_inside.data(), 1, sampled_loc_index.data(), + loc_num, sampled_loc_index_unmap.data()); + Gather(inds_inside.data(), 1, sampled_score_index.data(), + score_num, sampled_score_index_unmap.data()); + + // get target bbox deltas + Tensor sampled_anchor, sampled_gt, sampled_tgt_bbox; + auto* sampled_anchor_data = + sampled_anchor.mutable_data({loc_num, 4}, place); + auto* sampled_gt_data = sampled_gt.mutable_data({loc_num, 4}, place); + Gather(anchor->data(), 4, sampled_loc_index_unmap.data(), + loc_num, sampled_anchor_data); + Gather(ncrowd_gt_boxes.data(), 4, sampled_gt_index.data(), + loc_num, sampled_gt_data); + sampled_tgt_bbox.mutable_data({loc_num, 4}, place); + BoxToDelta(loc_num, sampled_anchor, sampled_gt, nullptr, false, + &sampled_tgt_bbox); + + // Add anchor offset + int anchor_offset = i * anchor_num; + auto sampled_loc_index_unmap_et = + framework::EigenTensor::From(sampled_loc_index_unmap); + sampled_loc_index_unmap_et = sampled_loc_index_unmap_et + anchor_offset; + auto sampled_score_index_unmap_et = + framework::EigenTensor::From(sampled_score_index_unmap); + sampled_score_index_unmap_et = + sampled_score_index_unmap_et + anchor_offset; + AppendRpns(loc_index, total_loc_num, &sampled_loc_index_unmap); + AppendRpns(score_index, total_score_num, &sampled_score_index_unmap); + AppendRpns(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox); + AppendRpns(tgt_lbl, total_score_num, &sampled_tgtlbl); + AppendRpns(bbox_inside_weight, total_loc_num * 4, + &sampled_bbox_inside_weight); + AppendRpns(fg_num, total_fg_num, &sampled_fg_num); + + total_loc_num += loc_num; + total_score_num += score_num; + total_fg_num += 1; + lod0_loc.emplace_back(total_loc_num); + lod0_score.emplace_back(total_score_num); + lod0_fg.emplace_back(total_fg_num); + } + + PADDLE_ENFORCE_LE(total_loc_num, max_num); + PADDLE_ENFORCE_LE(total_score_num, max_num); + PADDLE_ENFORCE_LE(total_fg_num, batch_num); + + lod_loc.emplace_back(lod0_loc); + loc_score.emplace_back(lod0_score); + lod_fg.emplace_back(lod0_fg); + loc_index->set_lod(lod_loc); + score_index->set_lod(loc_score); + tgt_bbox->set_lod(lod_loc); + tgt_lbl->set_lod(loc_score); + bbox_inside_weight->set_lod(lod_loc); + fg_num->set_lod(lod_fg); + loc_index->Resize({total_loc_num}); + score_index->Resize({total_score_num}); + tgt_bbox->Resize({total_loc_num, 4}); + tgt_lbl->Resize({total_score_num, 1}); + bbox_inside_weight->Resize({total_loc_num, 4}); + fg_num->Resize({total_fg_num, 1}); + } +}; + } // namespace operators } // namespace paddle @@ -582,3 +1027,9 @@ REGISTER_OPERATOR(rpn_target_assign, ops::RpnTargetAssignOp, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(rpn_target_assign, ops::RpnTargetAssignKernel, ops::RpnTargetAssignKernel); +REGISTER_OPERATOR(retinanet_target_assign, ops::RetinanetTargetAssignOp, + ops::RetinanetTargetAssignOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(retinanet_target_assign, + ops::RetinanetTargetAssignKernel, + ops::RetinanetTargetAssignKernel); diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc new file mode 100644 index 00000000000..50ff3cb120e --- /dev/null +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc @@ -0,0 +1,208 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h" +#include +#include +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class SigmoidFocalLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto labels_dims = ctx->GetInputDim("Label"); + auto fg_dims = ctx->GetInputDim("FgNum"); + + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, labels_dims.size(), + "Input(X) and Input(Label) shall have the same rank."); + PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1."); + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(labels_dims) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(labels_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + } + + PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL, + "The last dimension of input(Label) should be 1."); + + ctx->ShareDim("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class SigmoidFocalLossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto labels_dims = ctx->GetInputDim("Label"); + auto fg_dims = ctx->GetInputDim("FgNum"); + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, labels_dims.size(), + "Input(X) and Input(Label) shall have the same rank."); + PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1."); + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(labels_dims) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(labels_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape."); + + PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL, + "The last dimension of input(Label) should be 1."); + + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, 0, rank), + framework::slice_ddim(dout_dims, 0, rank), + "Input(X) and Input(Out@Grad) shall have the same shape."); + } + + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class SigmoidFocalLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape [N, D], " + "where N is the batch size and D is the number of classes " + "(excluding background). This input is a tensor of logits " + "computed by the previous operator."); + AddInput("Label", + "(Tensor, default Tensor), a 2-D tensor with shape [N, 1]. " + "This input is a tensor of probabilistic labels."); + AddInput("FgNum", + "(Tensor, default Tensor), a 1-D tensor with shape [1]. " + "This input is the number of foreground."); + AddOutput( + "Out", + "(Tensor, default Tensor), a 2-D tensor with shape [N, D]. " + "This output is the focal loss."); + AddAttr( + "gamma", + "Hyper-parameter of sigmoid focal loss op, which is to balance the " + "easy and hard examples. " + "A float scalar with default value 2.0.") + .SetDefault(2.0); + AddAttr( + "alpha", + "Hyper-parameter of sigmoid focal loss op, which is to balance the " + "positive and negative examples. " + "A float scalar with default value 0.5.") + .SetDefault(0.25); + AddComment(R"DOC( +Sigmoid Focal Loss Operator. + +Focal loss is used to address the foreground-background class imbalance existed +on the training phase of one-stage detectors. This operator computes the sigmoid +value for each element in the input tensor, after which focal loss is measured. + +The focal loss is given as follows: + +$$Loss_j = (-Label_j * alpha * \pow(1 - \sigma(X_j), gamma) * \log(\sigma(X_j)) - +(1 - Labels_j) * (1 - alpha) * \pow(\sigma(X_j), gamma) * \log(1 - \sigma(X_j))) +/ FgNum, j = 1,...,K$$ + +We know that $$\sigma(X_j) = \\frac{1}{1 + \exp(-X_j)}$$. + +)DOC"); + } +}; + +class SigmoidFocalLossGradOpDescMaker + : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("sigmoid_focal_loss_grad"); + op->SetInput("X", Input("X")); + op->SetInput("Label", Input("Label")); + op->SetInput("FgNum", Input("FgNum")); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sigmoid_focal_loss, ops::SigmoidFocalLossOp, + ops::SigmoidFocalLossOpMaker, + ops::SigmoidFocalLossGradOpDescMaker); +REGISTER_OPERATOR(sigmoid_focal_loss_grad, ops::SigmoidFocalLossGradOp); +REGISTER_OP_CPU_KERNEL( + sigmoid_focal_loss, + ops::SigmoidFocalLossKernel, + ops::SigmoidFocalLossKernel); +REGISTER_OP_CPU_KERNEL( + sigmoid_focal_loss_grad, + ops::SigmoidFocalLossGradKernel, + ops::SigmoidFocalLossGradKernel); diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu new file mode 100644 index 00000000000..4031554aa72 --- /dev/null +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu @@ -0,0 +1,181 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "cub/cub.cuh" +#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h" +#include "paddle/fluid/operators/math.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void GPUSigmoidFocalLossForward(const T *x_data, + const int *label_data, + const int *fg_num_data, + const T gamma, const T alpha, + const int num_classes, + const int limit, T *out_data) { + CUDA_1D_KERNEL_LOOP(i, limit) { + T x = x_data[i]; + int a = i / num_classes; // current sample + int d = i % num_classes; // current class + int g = label_data[a]; // target + + // check whether the input data is positive or negative + // the target classes are in range 1-81 + // and the d is in range 0-80 + T c_pos = static_cast(g == (d + 1)); + T c_neg = static_cast((g != -1) & (g != (d + 1))); + + T fg_num = static_cast((fg_num_data[0] > 1) ? fg_num_data[0] : 1); + T s_neg = (1.0 - alpha) / fg_num; + T s_pos = alpha / fg_num; + + // p = 1. / 1. + expf(-x) + T p = 1. / (1. + real_exp(-x)); + + // (1 - p)**gamma * log(p) + T term_pos = std::pow(static_cast(1. - p), gamma) * + real_log(p > FLT_MIN ? p : FLT_MIN); + // p**gamma * log(1 - p) + T term_neg = + std::pow(p, gamma) * + (-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0)))); + + out_data[i] = 0.0; + out_data[i] += -c_pos * term_pos * s_pos; + out_data[i] += -c_neg * term_neg * s_neg; + } +} + +template +__global__ void GPUSigmoidFocalLossBackward( + const T *x_data, const int *label_data, const int *fg_num_data, + const T gamma, const T alpha, const int num_classes, const T *dout_data, + const int limit, T *dx_data) { + CUDA_1D_KERNEL_LOOP(i, limit) { + T x = x_data[i]; + T dout = dout_data[i]; + + int a = i / num_classes; // current sample + int d = i % num_classes; // current class + + T fg_num = static_cast((fg_num_data[0] > 1) ? fg_num_data[0] : 1); + T s_neg = (1.0 - alpha) / fg_num; + T s_pos = alpha / fg_num; + + int g = label_data[a]; + T c_pos = static_cast(g == (d + 1)); + T c_neg = static_cast((g != -1) & (g != (d + 1))); + + T p = 1. / (1. + real_exp(-x)); + + // (1-p)**g * (1 - p - g*p*log(p)) + T term_pos = std::pow(static_cast(1. - p), gamma) * + (1. - p - (p * gamma * real_log(p > FLT_MIN ? p : FLT_MIN))); + // (p**g) * (g*(1-p)*log(1-p) - p) + T term_neg = + std::pow(p, gamma) * + ((-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0)))) * + (1. - p) * gamma - + p); + + dx_data[i] = 0.0; + dx_data[i] += -c_pos * s_pos * term_pos; + dx_data[i] += -c_neg * s_neg * term_neg; + dx_data[i] = dx_data[i] * dout; + } +} + +template +class GPUSigmoidFocalLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Labels = context.Input("Label"); + const Tensor *FgNum = context.Input("FgNum"); + Tensor *Out = context.Output("Out"); + T gamma = static_cast(context.Attr("gamma")); + T alpha = static_cast(context.Attr("alpha")); + auto x_dims = X->dims(); + int num_classes = static_cast(x_dims[1]); + auto out_data = Out->mutable_data(context.GetPlace()); + + auto &dev_ctx = context.cuda_device_context(); + + int limit = Out->numel(); + int blocks = NumBlocks(limit); + int threads = kNumCUDAThreads; + GPUSigmoidFocalLossForward<<>>( + X->data(), Labels->data(), FgNum->data(), gamma, alpha, + num_classes, limit, out_data); + } +}; + +template +class GPUSigmoidFocalLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Labels = context.Input("Label"); + const Tensor *FgNum = context.Input("FgNum"); + const Tensor *dOut = context.Input(framework::GradVarName("Out")); + Tensor *dX = context.Output(framework::GradVarName("X")); + auto dx_data = dX->mutable_data(context.GetPlace()); + T gamma = static_cast(context.Attr("gamma")); + T alpha = static_cast(context.Attr("alpha")); + auto x_dims = X->dims(); + int num_classes = static_cast(x_dims[1]); + + auto &dev_ctx = context.cuda_device_context(); + + int limit = dX->numel(); + int blocks = NumBlocks(limit); + int threads = kNumCUDAThreads; + GPUSigmoidFocalLossBackward<<>>( + X->data(), Labels->data(), FgNum->data(), gamma, alpha, + num_classes, dOut->data(), limit, dx_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + sigmoid_focal_loss, + ops::GPUSigmoidFocalLossKernel, + ops::GPUSigmoidFocalLossKernel); +REGISTER_OP_CUDA_KERNEL( + sigmoid_focal_loss_grad, + ops::GPUSigmoidFocalLossGradKernel, + ops::GPUSigmoidFocalLossGradKernel); diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.h b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.h new file mode 100644 index 00000000000..c4d44c1456f --- /dev/null +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.h @@ -0,0 +1,128 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class SigmoidFocalLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Labels = context.Input("Label"); + const Tensor *FgNum = context.Input("FgNum"); + Tensor *Out = context.Output("Out"); + T gamma = static_cast(context.Attr("gamma")); + T alpha = static_cast(context.Attr("alpha")); + auto out_data = Out->mutable_data(context.GetPlace()); + int limit = Out->numel(); + auto x_data = X->data(); + auto label_data = Labels->data(); + auto fg_num_data = FgNum->data(); + auto x_dims = X->dims(); + int num_classes = static_cast(x_dims[1]); + + for (int idx = 0; idx < limit; ++idx) { + T x = x_data[idx]; + int a = idx / num_classes; // current sample + int d = idx % num_classes; // current class + int g = label_data[a]; // target + + // Check whether the input data is positive or negative + // The target classes are in range 1-81 + // and the d is in range 0-80 + T c_pos = static_cast(g == (d + 1)); + T c_neg = static_cast((g != -1) & (g != (d + 1))); + T fg_num = static_cast((fg_num_data[0] > 1) ? fg_num_data[0] : 1); + T s_neg = (1.0 - alpha) / fg_num; + T s_pos = alpha / fg_num; + + // p = 1. / 1. + expf(-x) + T p = 1. / (1. + std::exp(-x)); + + // (1 - p)**gamma * log(p) where + T term_pos = std::pow(static_cast(1. - p), gamma) * + std::log(p > FLT_MIN ? p : FLT_MIN); + // p**gamma * log(1 - p) + T term_neg = + std::pow(p, gamma) * + (-1. * x * (x >= 0) - std::log(1. + std::exp(x - 2. * x * (x >= 0)))); + out_data[idx] = 0.0; + out_data[idx] += -c_pos * term_pos * s_pos; + out_data[idx] += -c_neg * term_neg * s_neg; + } + } +}; + +template +class SigmoidFocalLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Labels = context.Input("Label"); + const Tensor *FgNum = context.Input("FgNum"); + const Tensor *dOut = context.Input(framework::GradVarName("Out")); + Tensor *dX = context.Output(framework::GradVarName("X")); + auto dx_data = dX->mutable_data(context.GetPlace()); + T gamma = static_cast(context.Attr("gamma")); + T alpha = static_cast(context.Attr("alpha")); + auto x_dims = X->dims(); + int num_classes = static_cast(x_dims[1]); + + int limit = dX->numel(); + auto x_data = X->data(); + auto label_data = Labels->data(); + auto fg_num_data = FgNum->data(); + auto dout_data = dOut->data(); + for (int idx = 0; idx < limit; ++idx) { + T x = x_data[idx]; + int a = idx / num_classes; // current sample + int d = idx % num_classes; // current class + + T fg_num = static_cast((fg_num_data[0] > 1) ? fg_num_data[0] : 1); + T s_neg = static_cast((1.0 - alpha) / fg_num); + T s_pos = alpha / fg_num; + int g = label_data[a]; + + T c_pos = static_cast(g == (d + 1)); + T c_neg = static_cast((g != -1) & (g != (d + 1))); + T p = 1. / (1. + std::exp(-x)); + + // (1-p)**g * (1 - p - g*p*log(p)) + T term_pos = std::pow(static_cast(1. - p), gamma) * + (1. - p - (p * gamma * std::log(p > FLT_MIN ? p : FLT_MIN))); + // (p**g) * (g*(1-p)*log(1-p) - p) + T term_neg = std::pow(p, gamma) * + ((-1. * x * (x >= 0) - + std::log(1. + std::exp(x - 2. * x * (x >= 0)))) * + (1. - p) * gamma - + p); + + dx_data[idx] = 0.0; + dx_data[idx] += -c_pos * s_pos * term_pos; + dx_data[idx] += -c_neg * s_neg * term_neg; + dx_data[idx] = dx_data[idx] * dout_data[idx]; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index fa85350adcd..b2e3aff3063 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -40,6 +40,8 @@ __all__ = [ 'ssd_loss', 'detection_map', 'rpn_target_assign', + 'retinanet_target_assign', + 'sigmoid_focal_loss', 'anchor_generator', 'roi_perspective_transform', 'generate_proposal_labels', @@ -52,12 +54,171 @@ __all__ = [ 'yolo_box', 'box_clip', 'multiclass_nms', + 'retinanet_detection_output', 'distribute_fpn_proposals', 'box_decoder_and_assign', 'collect_fpn_proposals', ] +def retinanet_target_assign(bbox_pred, + cls_logits, + anchor_box, + anchor_var, + gt_boxes, + gt_labels, + is_crowd, + im_info, + num_classes=1, + positive_overlap=0.5, + negative_overlap=0.4): + """ + **Target Assign Layer for Retinanet .** + + This layer can be, for given the Intersection-over-Union (IoU) overlap + between anchors and ground truth boxes, to assign classification and + regression targets to each anchor, these target labels are used for training + retinanet. Every anchor is assigned with a length :attr:`num_classes` + one-hot vector of classification targets, and a 4-vector of box regression + targets. The assignment rules are as followed: + + 1. Anchors are assigned to ground-truth boxes when: (i) it has the highest + IoU overlap with a ground-truth box, or (ii) it has an IoU overlap higher + than positive_overlap(0.5) with any ground-truth box. + + 2. Anchors are assigned to background when its IoU ratio is lower than + negative_overlap (0.4) for all ground-truth boxes. + + When an anchor is assigned with a ground-truth box which is the i-th category, + the i-th entry in its C vector of targets is set to 1 and all other entries + are set to 0. When an anchor is assigned with background, all entries are set + to 0. Anchors that are not assigned do not contribute to the training + objective. The regression targets are the encoded ground-truth boxes + associated with the assigned anchors. + + Args: + bbox_pred(Variable): A 3-D Tensor with shape [N, M, 4] represents the + predicted locations of M bounding bboxes. N is the batch size, + and each bounding box has four coordinate values and the layout + is [xmin, ymin, xmax, ymax]. + cls_logits(Variable): A 3-D Tensor with shape [N, M, C] represents the + predicted confidence predictions. N is the batch size, C is the + number of classes (excluding background), M is number of bounding boxes. + anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes, + each box is represented as [xmin, ymin, xmax, ymax], + [xmin, ymin] is the left top coordinate of the anchor box, + if the input is image feature map, they are close to the origin + of the coordinate system. [xmax, ymax] is the right bottom + coordinate of the anchor box. + anchor_var(Variable): A 2-D Tensor with shape [M,4] holds expanded + variances of anchors. + gt_boxes(Variable): The ground-truth bounding boxes (bboxes) are a 2D + LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth + bboxes of mini-batch input. + gt_labels(variable): The ground-truth labels are a 2D LoDTensor with + shape [Ng, 1], Ng is the total number of ground-truth labels of + mini-batch input. + is_crowd(Variable): A 1-D LoDTensor which indicates ground-truth is crowd. + im_info(Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch size, + 3 is the height, width and scale. + num_classes(int32): The number of classes. + positive_overlap(float): Minimum overlap required between an anchor + and ground-truth box for the (anchor, gt box) pair to be a positive + example. + negative_overlap(float): Maximum overlap allowed between an anchor + and ground-truth box for the (anchor, gt box) pair to be a negative + examples. + + Returns: + tuple: + A tuple(predicted_scores, predicted_location, target_label, + target_bbox, bbox_inside_weight, fg_num) is returned. The + predicted_scores and predicted_location are the predicted result + of the retinanet.The target_label and target_bbox are the ground + truth, respectively. The predicted_location is a 2D Tensor with + shape [F, 4], and the shape of target_bbox is same as the shape of + the predicted_location, F is the number of the foreground + anchors. The predicted_scores is a 2D Tensor with shape + [F + B, C], and the shape of target_label is [F + B, 1], B is the + number of the background anchors, the F and B is depends on the + input of this operator. Bbox_inside_weight represents whether the + predicted location is fake foreground or not and the shape is [F, 4]. + Fg_num is the foreground number (including fake foreground) which + is needed by focal loss. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + bbox_pred = layers.data(name='bbox_pred', shape=[1, 100, 4], + append_batch_size=False, dtype='float32') + cls_logits = layers.data(name='cls_logits', shape=[1, 100, 10], + append_batch_size=False, dtype='float32') + anchor_box = layers.data(name='anchor_box', shape=[100, 4], + append_batch_size=False, dtype='float32') + anchor_var = layers.data(name='anchor_var', shape=[100, 4], + append_batch_size=False, dtype='float32') + gt_boxes = layers.data(name='gt_boxes', shape=[10, 4], + append_batch_size=False, dtype='float32') + gt_labels = layers.data(name='gt_labels', shape=[10, 1], + append_batch_size=False, dtype='float32') + is_crowd = fluid.layers.data(name='is_crowd', shape=[1], + append_batch_size=False, dtype='float32') + im_info = fluid.layers.data(name='im_infoss', shape=[1, 3], + append_batch_size=False, dtype='float32') + loc_pred, score_pred, loc_target, score_target, bbox_inside_weight, fg_num = + fluid.layers.retinanet_target_assign(bbox_pred, cls_logits, anchor_box, + anchor_var, gt_boxes, gt_labels, is_crowd, im_info, 10) + + """ + + helper = LayerHelper('retinanet_target_assign', **locals()) + # Assign target label to anchors + loc_index = helper.create_variable_for_type_inference(dtype='int32') + score_index = helper.create_variable_for_type_inference(dtype='int32') + target_label = helper.create_variable_for_type_inference(dtype='int32') + target_bbox = helper.create_variable_for_type_inference( + dtype=anchor_box.dtype) + bbox_inside_weight = helper.create_variable_for_type_inference( + dtype=anchor_box.dtype) + fg_num = helper.create_variable_for_type_inference(dtype='int32') + helper.append_op( + type="retinanet_target_assign", + inputs={ + 'Anchor': anchor_box, + 'GtBoxes': gt_boxes, + 'GtLabels': gt_labels, + 'IsCrowd': is_crowd, + 'ImInfo': im_info + }, + outputs={ + 'LocationIndex': loc_index, + 'ScoreIndex': score_index, + 'TargetLabel': target_label, + 'TargetBBox': target_bbox, + 'BBoxInsideWeight': bbox_inside_weight, + 'ForegroundNumber': fg_num + }, + attrs={ + 'positive_overlap': positive_overlap, + 'negative_overlap': negative_overlap + }) + + loc_index.stop_gradient = True + score_index.stop_gradient = True + target_label.stop_gradient = True + target_bbox.stop_gradient = True + bbox_inside_weight.stop_gradient = True + fg_num.stop_gradient = True + + cls_logits = nn.reshape(x=cls_logits, shape=(-1, num_classes)) + bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4)) + predicted_cls_logits = nn.gather(cls_logits, score_index) + predicted_bbox_pred = nn.gather(bbox_pred, loc_index) + + return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight, fg_num + + def rpn_target_assign(bbox_pred, cls_logits, anchor_box, @@ -210,6 +371,74 @@ def rpn_target_assign(bbox_pred, return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight +def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25): + """ + **Sigmoid Focal Loss Operator.** + + Focal loss is used to address the foreground-background class imbalance existed + on the training phase of one-stage detectors. This operator computes the sigmoid + value for each element in the input tensor, after which focal loss is measured. + + The focal loss is given as followed: + + .. math:: + loss_j = (-label_j * alpha * {(1 - \\sigma(x_j))}^{gamma} * \\log(\\sigma(x_j)) - + (1 - labels_j) * (1 - alpha) * {(\sigma(x_j)}^{ gamma} * \\log(1 - \\sigma(x_j))) + / fg\_num, j = 1,...,K + + We know that + + .. math:: + \\sigma(x_j) = \\frac{1}{1 + \\exp(-x_j)} + + Args: + x(Variable): A 2-D tensor with shape [N, D], where N is the batch size and D is the number + of classes (excluding background). This input is a tensor of logits computed by the + previous operator. + label(Variable): A 2-D tensor with shape [N, 1], which is the probabilistic labels. + fg_num(Variable): A 1-D tensor with shape [1], which is the number of foreground. + + gamma(float): Hyper-parameter to balance the easy and hard examples. Default value is + set to 2.0. + alpha(float): Hyper-parameter to balance the positive and negative example. Default value + is set to 0.25. + + Returns: + out(Variable): A 2-D tensor with shape [N, D], which is the focal loss. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + input = fluid.layers.data( + name='data', shape=[10,80], append_batch_size=False, dtype='float32') + label = fluid.layers.data( + name='label', shape=[10,1], append_batch_size=False, dtype='int32') + fg_num = fluid.layers.data( + name='fg_num', shape=[1], append_batch_size=False, dtype='int32') + loss = fluid.layers.sigmoid_focal_loss(x=input, + label=label, + fg_num=fg_num, + gamma=2., + alpha=0.25) + """ + + helper = LayerHelper("sigmoid_focal_loss", **locals()) + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type="sigmoid_focal_loss", + inputs={"X": x, + "Label": label, + "FgNum": fg_num}, + attrs={"gamma": gamma, + 'alpha': alpha}, + outputs={"Out": out}) + return out + + def detection_output(loc, scores, prior_box, @@ -2320,6 +2549,113 @@ def box_clip(input, im_info, name=None): return output +def retinanet_detection_output(bboxes, + scores, + anchors, + im_info, + score_threshold=0.05, + nms_top_k=1000, + keep_top_k=100, + nms_threshold=0.3, + nms_eta=1.): + """ + **Detection Output Layer for Retinanet.** + + This operation is to get the detection results by performing following + steps: + + 1. Decode top-scoring bounding box predictions per FPN level according + to the anchor boxes. + 2. Merge top predictions from all levels and apply multi-class non + maximum suppression (NMS) on them to get the final detections. + + Args: + bboxes(List): A list of tensors from multiple FPN levels. Each + element is a 3-D Tensor with shape [N, Mi, 4] representing the + predicted locations of Mi bounding boxes. N is the batch size, + Mi is the number of bounding boxes from i-th FPN level and each + bounding box has four coordinate values and the layout is + [xmin, ymin, xmax, ymax]. + scores(List): A list of tensors from multiple FPN levels. Each + element is a 3-D Tensor with shape [N, Mi, C] representing the + predicted confidence predictions. N is the batch size, C is the + class number (excluding background), Mi is the number of bounding + boxes from i-th FPN level. For each bounding box, there are total + C scores. + anchors(List): A 2-D Tensor with shape [Mi, 4] represents the locations + of Mi anchor boxes from all FPN level. Each bounding box has four + coordinate values and the layout is [xmin, ymin, xmax, ymax]. + im_info(Variable): A 2-D LoDTensor with shape [N, 3] represents the + image information. N is the batch size, each image information + includes height, width and scale. + score_threshold(float): Threshold to filter out bounding boxes + with a confidence score. + nms_top_k(int): Maximum number of detections per FPN layer to be + kept according to the confidences before NMS. + keep_top_k(int): Number of total bounding boxes to be kept per image after + NMS step. -1 means keeping all bounding boxes after NMS step. + nms_threshold(float): The threshold to be used in NMS. + nms_eta(float): The parameter for adaptive NMS. + + Returns: + Variable: + The detection output is a LoDTensor with shape [No, 6]. + Each row has six values: [label, confidence, xmin, ymin, xmax, ymax]. + `No` is the total number of detections in this mini-batch. For each + instance, the offsets in first dimension are called LoD, the offset + number is N + 1, N is the batch size. The i-th image has + `LoD[i + 1] - LoD[i]` detected results, if it is 0, the i-th image + has no detected results. If all images have no detected results, + LoD will be set to 0, and the output tensor is empty (None). + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + bboxes = layers.data(name='bboxes', shape=[1, 21, 4], + append_batch_size=False, dtype='float32') + scores = layers.data(name='scores', shape=[1, 21, 10], + append_batch_size=False, dtype='float32') + anchors = layers.data(name='anchors', shape=[21, 4], + append_batch_size=False, dtype='float32') + im_info = layers.data(name="im_info", shape=[1, 3], + append_batch_size=False, dtype='float32') + nmsed_outs = fluid.layers.retinanet_detection_output( + bboxes=[bboxes, bboxes], + scores=[scores, scores], + anchors=[anchors, anchors], + im_info=im_info, + score_threshold=0.05, + nms_top_k=1000, + keep_top_k=100, + nms_threshold=0.3, + nms_eta=1.) + """ + + helper = LayerHelper('retinanet_detection_output', **locals()) + output = helper.create_variable_for_type_inference( + dtype=helper.input_dtype('scores')) + helper.append_op( + type="retinanet_detection_output", + inputs={ + 'BBoxes': bboxes, + 'Scores': scores, + 'Anchors': anchors, + 'ImInfo': im_info + }, + attrs={ + 'score_threshold': score_threshold, + 'nms_top_k': nms_top_k, + 'nms_threshold': nms_threshold, + 'keep_top_k': keep_top_k, + 'nms_eta': 1., + }, + outputs={'Out': output}) + output.stop_gradient = True + return output + + def multiclass_nms(bboxes, scores, score_threshold, diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 2d4ddb01d4d..0a35f42a642 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -2018,6 +2018,110 @@ class TestBook(LayerTest): trans_std=0.1) return (out) + def test_retinanet_target_assign(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + bbox_pred = layers.data( + name='bbox_pred', + shape=[1, 100, 4], + append_batch_size=False, + dtype='float32') + cls_logits = layers.data( + name='cls_logits', + shape=[1, 100, 10], + append_batch_size=False, + dtype='float32') + anchor_box = layers.data( + name='anchor_box', + shape=[100, 4], + append_batch_size=False, + dtype='float32') + anchor_var = layers.data( + name='anchor_var', + shape=[100, 4], + append_batch_size=False, + dtype='float32') + gt_boxes = layers.data( + name='gt_boxes', + shape=[10, 4], + append_batch_size=False, + dtype='float32') + gt_labels = layers.data( + name='gt_labels', + shape=[10, 1], + append_batch_size=False, + dtype='float32') + is_crowd = layers.data( + name='is_crowd', + shape=[1], + append_batch_size=False, + dtype='float32') + im_info = layers.data( + name='im_info', + shape=[1, 3], + append_batch_size=False, + dtype='float32') + return (layers.retinanet_target_assign( + bbox_pred, cls_logits, anchor_box, anchor_var, gt_boxes, + gt_labels, is_crowd, im_info, 10)) + + def test_sigmoid_focal_loss(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = layers.data( + name='data', + shape=[10, 80], + append_batch_size=False, + dtype='float32') + label = layers.data( + name='label', + shape=[10, 1], + append_batch_size=False, + dtype='int32') + fg_num = layers.data( + name='fg_num', + shape=[1], + append_batch_size=False, + dtype='int32') + out = fluid.layers.sigmoid_focal_loss( + x=input, label=label, fg_num=fg_num, gamma=2., alpha=0.25) + return (out) + + def test_retinanet_detection_output(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + bboxes = layers.data( + name='bboxes', + shape=[1, 21, 4], + append_batch_size=False, + dtype='float32') + scores = layers.data( + name='scores', + shape=[1, 21, 10], + append_batch_size=False, + dtype='float32') + anchors = layers.data( + name='anchors', + shape=[21, 4], + append_batch_size=False, + dtype='float32') + im_info = layers.data( + name="im_info", + shape=[1, 3], + append_batch_size=False, + dtype='float32') + nmsed_outs = layers.retinanet_detection_output( + bboxes=[bboxes, bboxes], + scores=[scores, scores], + anchors=[anchors, anchors], + im_info=im_info, + score_threshold=0.05, + nms_top_k=1000, + keep_top_k=100, + nms_threshold=0.3, + nms_eta=1.) + return (nmsed_outs) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_retinanet_detection_output.py b/python/paddle/fluid/tests/unittests/test_retinanet_detection_output.py new file mode 100644 index 00000000000..fafc7de33bc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_retinanet_detection_output.py @@ -0,0 +1,412 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +#Licensed under the Apache License, Version 2.0 (the "License") +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import math +import copy +from op_test import OpTest +from test_anchor_generator_op import anchor_generator_in_python +from test_multiclass_nms_op import iou +from test_multiclass_nms_op import nms + + +def multiclass_nms(prediction, class_num, keep_top_k, nms_threshold): + selected_indices = {} + num_det = 0 + for c in range(class_num): + if c not in prediction.keys(): + continue + cls_dets = prediction[c] + all_scores = np.zeros(len(cls_dets)) + for i in range(all_scores.shape[0]): + all_scores[i] = cls_dets[i][4] + indices = nms(cls_dets, all_scores, 0.0, nms_threshold, -1, False, 1.0) + selected_indices[c] = indices + num_det += len(indices) + + score_index = [] + for c, indices in selected_indices.items(): + for idx in indices: + score_index.append((prediction[c][idx][4], c, idx)) + + sorted_score_index = sorted( + score_index, key=lambda tup: tup[0], reverse=True) + if keep_top_k > -1 and num_det > keep_top_k: + sorted_score_index = sorted_score_index[:keep_top_k] + num_det = keep_top_k + nmsed_outs = [] + for s, c, idx in sorted_score_index: + xmin = prediction[c][idx][0] + ymin = prediction[c][idx][1] + xmax = prediction[c][idx][2] + ymax = prediction[c][idx][3] + nmsed_outs.append([c + 1, s, xmin, ymin, xmax, ymax]) + + return nmsed_outs, num_det + + +def retinanet_detection_out(boxes_list, scores_list, anchors_list, im_info, + score_threshold, nms_threshold, nms_top_k, + keep_top_k): + class_num = scores_list[0].shape[-1] + im_height, im_width, im_scale = im_info + + num_level = len(scores_list) + prediction = {} + for lvl in range(num_level): + scores_per_level = scores_list[lvl] + scores_per_level = scores_per_level.flatten() + bboxes_per_level = boxes_list[lvl] + bboxes_per_level = bboxes_per_level.flatten() + anchors_per_level = anchors_list[lvl] + anchors_per_level = anchors_per_level.flatten() + + thresh = score_threshold if lvl < (num_level - 1) else 0.0 + selected_indices = np.argwhere(scores_per_level > thresh) + scores = scores_per_level[selected_indices] + sorted_indices = np.argsort(-scores, axis=0, kind='mergesort') + if nms_top_k > -1 and nms_top_k < sorted_indices.shape[0]: + sorted_indices = sorted_indices[:nms_top_k] + + for i in range(sorted_indices.shape[0]): + idx = selected_indices[sorted_indices[i]] + idx = idx[0][0] + a = int(idx / class_num) + c = int(idx % class_num) + box_offset = a * 4 + anchor_box_width = anchors_per_level[ + box_offset + 2] - anchors_per_level[box_offset] + 1 + anchor_box_height = anchors_per_level[ + box_offset + 3] - anchors_per_level[box_offset + 1] + 1 + anchor_box_center_x = anchors_per_level[ + box_offset] + anchor_box_width / 2 + anchor_box_center_y = anchors_per_level[box_offset + + 1] + anchor_box_height / 2 + + target_box_center_x = bboxes_per_level[ + box_offset] * anchor_box_width + anchor_box_center_x + target_box_center_y = bboxes_per_level[ + box_offset + 1] * anchor_box_height + anchor_box_center_y + target_box_width = math.exp(bboxes_per_level[box_offset + + 2]) * anchor_box_width + target_box_height = math.exp(bboxes_per_level[ + box_offset + 3]) * anchor_box_height + + pred_box_xmin = target_box_center_x - target_box_width / 2 + pred_box_ymin = target_box_center_y - target_box_height / 2 + pred_box_xmax = target_box_center_x + target_box_width / 2 - 1 + pred_box_ymax = target_box_center_y + target_box_height / 2 - 1 + + pred_box_xmin = pred_box_xmin / im_scale + pred_box_ymin = pred_box_ymin / im_scale + pred_box_xmax = pred_box_xmax / im_scale + pred_box_ymax = pred_box_ymax / im_scale + + pred_box_xmin = max( + min(pred_box_xmin, np.round(im_width / im_scale) - 1), 0.) + pred_box_ymin = max( + min(pred_box_ymin, np.round(im_height / im_scale) - 1), 0.) + pred_box_xmax = max( + min(pred_box_xmax, np.round(im_width / im_scale) - 1), 0.) + pred_box_ymax = max( + min(pred_box_ymax, np.round(im_height / im_scale) - 1), 0.) + + if c not in prediction.keys(): + prediction[c] = [] + prediction[c].append([ + pred_box_xmin, pred_box_ymin, pred_box_xmax, pred_box_ymax, + scores_per_level[idx] + ]) + + nmsed_outs, nmsed_num = multiclass_nms(prediction, class_num, keep_top_k, + nms_threshold) + return nmsed_outs, nmsed_num + + +def batched_retinanet_detection_out(boxes, scores, anchors, im_info, + score_threshold, nms_threshold, nms_top_k, + keep_top_k): + batch_size = scores[0].shape[0] + det_outs = [] + lod = [] + + for n in range(batch_size): + boxes_per_batch = [] + scores_per_batch = [] + + num_level = len(scores) + for lvl in range(num_level): + boxes_per_batch.append(boxes[lvl][n]) + scores_per_batch.append(scores[lvl][n]) + + nmsed_outs, nmsed_num = retinanet_detection_out( + boxes_per_batch, scores_per_batch, anchors, im_info[n], + score_threshold, nms_threshold, nms_top_k, keep_top_k) + lod.append(nmsed_num) + if nmsed_num == 0: + continue + + det_outs.extend(nmsed_outs) + return det_outs, lod + + +class TestRetinanetDetectionOutOp1(OpTest): + def set_argument(self): + self.score_threshold = 0.05 + self.min_level = 3 + self.max_level = 7 + self.nms_threshold = 0.3 + self.nms_top_k = 1000 + self.keep_top_k = 200 + + self.scales_per_octave = 3 + self.aspect_ratios = [1.0, 2.0, 0.5] + self.anchor_scale = 4 + self.anchor_strides = [8, 16, 32, 64, 128] + + self.box_size = 4 + self.class_num = 80 + self.batch_size = 1 + self.input_channels = 20 + + self.layer_h = [] + self.layer_w = [] + num_levels = self.max_level - self.min_level + 1 + for i in range(num_levels): + self.layer_h.append(2**(num_levels - i)) + self.layer_w.append(2**(num_levels - i)) + + def init_test_input(self): + anchor_num = len(self.aspect_ratios) * self.scales_per_octave + num_levels = self.max_level - self.min_level + 1 + self.scores_list = [] + self.bboxes_list = [] + self.anchors_list = [] + + for i in range(num_levels): + layer_h = self.layer_h[i] + layer_w = self.layer_w[i] + + input_feat = np.random.random((self.batch_size, self.input_channels, + layer_h, layer_w)).astype('float32') + score = np.random.random( + (self.batch_size, self.class_num * anchor_num, layer_h, + layer_w)).astype('float32') + score = np.transpose(score, [0, 2, 3, 1]) + score = score.reshape((self.batch_size, -1, self.class_num)) + box = np.random.random((self.batch_size, self.box_size * anchor_num, + layer_h, layer_w)).astype('float32') + box = np.transpose(box, [0, 2, 3, 1]) + box = box.reshape((self.batch_size, -1, self.box_size)) + anchor_sizes = [] + for octave in range(self.scales_per_octave): + anchor_sizes.append( + float(self.anchor_strides[i] * (2**octave)) / + float(self.scales_per_octave) * self.anchor_scale) + anchor, var = anchor_generator_in_python( + input_feat=input_feat, + anchor_sizes=anchor_sizes, + aspect_ratios=self.aspect_ratios, + variances=[1.0, 1.0, 1.0, 1.0], + stride=[self.anchor_strides[i], self.anchor_strides[i]], + offset=0.5) + anchor = np.reshape(anchor, [-1, 4]) + self.scores_list.append(score.astype('float32')) + self.bboxes_list.append(box.astype('float32')) + self.anchors_list.append(anchor.astype('float32')) + + self.im_info = np.array([[256., 256., 1.5]]).astype( + 'float32') #im_height, im_width, scale + + def setUp(self): + self.set_argument() + self.init_test_input() + + nmsed_outs, lod = batched_retinanet_detection_out( + self.bboxes_list, self.scores_list, self.anchors_list, self.im_info, + self.score_threshold, self.nms_threshold, self.nms_top_k, + self.keep_top_k) + nmsed_outs = np.array(nmsed_outs).astype('float32') + self.op_type = 'retinanet_detection_output' + self.inputs = { + 'BBoxes': [('b0', self.bboxes_list[0]), ('b1', self.bboxes_list[1]), + ('b2', self.bboxes_list[2]), ('b3', self.bboxes_list[3]), + ('b4', self.bboxes_list[4])], + 'Scores': [('s0', self.scores_list[0]), ('s1', self.scores_list[1]), + ('s2', self.scores_list[2]), ('s3', self.scores_list[3]), + ('s4', self.scores_list[4])], + 'Anchors': + [('a0', self.anchors_list[0]), ('a1', self.anchors_list[1]), + ('a2', self.anchors_list[2]), ('a3', self.anchors_list[3]), + ('a4', self.anchors_list[4])], + 'ImInfo': (self.im_info, [[1, ]]) + } + self.outputs = {'Out': (nmsed_outs, [lod])} + self.attrs = { + 'score_threshold': self.score_threshold, + 'nms_top_k': self.nms_top_k, + 'nms_threshold': self.nms_threshold, + 'keep_top_k': self.keep_top_k, + 'nms_eta': 1., + } + + def test_check_output(self): + self.check_output() + + +class TestRetinanetDetectionOutOp2(OpTest): + def set_argument(self): + self.score_threshold = 0.05 + self.min_level = 3 + self.max_level = 7 + self.nms_threshold = 0.3 + self.nms_top_k = 1000 + self.keep_top_k = 200 + + self.scales_per_octave = 3 + self.aspect_ratios = [1.0, 2.0, 0.5] + self.anchor_scale = 4 + self.anchor_strides = [8, 16, 32, 64, 128] + + self.box_size = 4 + self.class_num = 80 + self.batch_size = 1 + self.input_channels = 20 + # Here test the case there the shape of each FPN level + # is irrelevant. + self.layer_h = [1, 4, 8, 8, 16] + self.layer_w = [1, 4, 8, 8, 16] + + +class TestRetinanetDetectionOutOpNo3(TestRetinanetDetectionOutOp1): + def set_argument(self): + # Here set 2.0 to test the case there is no outputs. + # In practical use, 0.0 < score_threshold < 1.0 + self.score_threshold = 2.0 + self.min_level = 3 + self.max_level = 7 + self.nms_threshold = 0.3 + self.nms_top_k = 1000 + self.keep_top_k = 200 + + self.scales_per_octave = 3 + self.aspect_ratios = [1.0, 2.0, 0.5] + self.anchor_scale = 4 + self.anchor_strides = [8, 16, 32, 64, 128] + + self.box_size = 4 + self.class_num = 80 + self.batch_size = 1 + self.input_channels = 20 + + self.layer_h = [] + self.layer_w = [] + num_levels = self.max_level - self.min_level + 1 + for i in range(num_levels): + self.layer_h.append(2**(num_levels - i)) + self.layer_w.append(2**(num_levels - i)) + + +class TestRetinanetDetectionOutOpNo4(TestRetinanetDetectionOutOp1): + def set_argument(self): + self.score_threshold = 0.05 + self.min_level = 2 + self.max_level = 5 + self.nms_threshold = 0.3 + self.nms_top_k = 1000 + self.keep_top_k = 200 + + self.scales_per_octave = 3 + self.aspect_ratios = [1.0, 2.0, 0.5] + self.anchor_scale = 4 + self.anchor_strides = [8, 16, 32, 64, 128] + + self.box_size = 4 + self.class_num = 80 + self.batch_size = 1 + self.input_channels = 20 + + self.layer_h = [] + self.layer_w = [] + num_levels = self.max_level - self.min_level + 1 + for i in range(num_levels): + self.layer_h.append(2**(num_levels - i)) + self.layer_w.append(2**(num_levels - i)) + + def setUp(self): + self.set_argument() + self.init_test_input() + + nmsed_outs, lod = batched_retinanet_detection_out( + self.bboxes_list, self.scores_list, self.anchors_list, self.im_info, + self.score_threshold, self.nms_threshold, self.nms_top_k, + self.keep_top_k) + nmsed_outs = np.array(nmsed_outs).astype('float32') + self.op_type = 'retinanet_detection_output' + self.inputs = { + 'BBoxes': + [('b0', self.bboxes_list[0]), ('b1', self.bboxes_list[1]), + ('b2', self.bboxes_list[2]), ('b3', self.bboxes_list[3])], + 'Scores': [('s0', self.scores_list[0]), ('s1', self.scores_list[1]), + ('s2', self.scores_list[2]), + ('s3', self.scores_list[3])], + 'Anchors': + [('a0', self.anchors_list[0]), ('a1', self.anchors_list[1]), + ('a2', self.anchors_list[2]), ('a3', self.anchors_list[3])], + 'ImInfo': (self.im_info, [[1, ]]) + } + self.outputs = {'Out': (nmsed_outs, [lod])} + self.attrs = { + 'score_threshold': self.score_threshold, + 'nms_top_k': self.nms_top_k, + 'nms_threshold': self.nms_threshold, + 'keep_top_k': self.keep_top_k, + 'nms_eta': 1., + } + + def test_check_output(self): + self.check_output() + + +class TestRetinanetDetectionOutOpNo5(TestRetinanetDetectionOutOp1): + def set_argument(self): + self.score_threshold = 0.05 + self.min_level = 3 + self.max_level = 7 + self.nms_threshold = 0.3 + self.nms_top_k = 100 + self.keep_top_k = 10 + + self.scales_per_octave = 3 + self.aspect_ratios = [1.0, 2.0, 0.5] + self.anchor_scale = 4 + self.anchor_strides = [8, 16, 32, 64, 128] + + self.box_size = 4 + self.class_num = 80 + self.batch_size = 1 + self.input_channels = 20 + + self.layer_h = [] + self.layer_w = [] + num_levels = self.max_level - self.min_level + 1 + for i in range(num_levels): + self.layer_h.append(2**(num_levels - i)) + self.layer_w.append(2**(num_levels - i)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py index 1a2c9bb5f43..3dba961dc9d 100644 --- a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py @@ -167,6 +167,105 @@ def rpn_target_assign_in_python(all_anchors, return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights +def retinanet_target_assign(anchor_by_gt_overlap, gt_labels, positive_overlap, + negative_overlap): + anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(axis=1) + anchor_to_gt_max = anchor_by_gt_overlap[np.arange( + anchor_by_gt_overlap.shape[0]), anchor_to_gt_argmax] + + gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(axis=0) + gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, np.arange( + anchor_by_gt_overlap.shape[1])] + anchors_with_max_overlap = np.where( + anchor_by_gt_overlap == gt_to_anchor_max)[0] + + labels = np.ones((anchor_by_gt_overlap.shape[0], ), dtype=np.int32) * -1 + labels[anchors_with_max_overlap] = 1 + labels[anchor_to_gt_max >= positive_overlap] = 1 + + fg_inds = np.where(labels == 1)[0] + bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32) + + bg_inds = np.where(anchor_to_gt_max < negative_overlap)[0] + enable_inds = bg_inds + + fg_fake_inds = np.array([], np.int32) + fg_value = np.array([fg_inds[0]], np.int32) + fake_num = 0 + for bg_id in enable_inds: + if bg_id in fg_inds: + fake_num += 1 + fg_fake_inds = np.hstack([fg_fake_inds, fg_value]) + labels[enable_inds] = 0 + + bbox_inside_weight[fake_num:, :] = 1 + fg_inds = np.where(labels == 1)[0] + bg_inds = np.where(labels == 0)[0] + loc_index = np.hstack([fg_fake_inds, fg_inds]) + score_index = np.hstack([fg_inds, bg_inds]) + score_index_tmp = np.hstack([fg_inds]) + labels = labels[score_index] + + gt_inds = anchor_to_gt_argmax[loc_index] + label_inds = anchor_to_gt_argmax[score_index_tmp] + labels[0:len(fg_inds)] = np.squeeze(gt_labels[label_inds]) + fg_num = len(fg_fake_inds) + len(fg_inds) + 1 + assert not np.any(labels == -1), "Wrong labels with -1" + return loc_index, score_index, labels, gt_inds, bbox_inside_weight, fg_num + + +def retinanet_target_assign_in_python(all_anchors, gt_boxes, gt_labels, + is_crowd, im_info, lod, positive_overlap, + negative_overlap): + anchor_num = all_anchors.shape[0] + batch_size = len(lod) - 1 + for i in range(batch_size): + im_scale = im_info[i][2] + + inds_inside = np.arange(all_anchors.shape[0]) + inside_anchors = all_anchors + b, e = lod[i], lod[i + 1] + gt_boxes_slice = gt_boxes[b:e, :] * im_scale + gt_labels_slice = gt_labels[b:e, :] + is_crowd_slice = is_crowd[b:e] + + not_crowd_inds = np.where(is_crowd_slice == 0)[0] + gt_boxes_slice = gt_boxes_slice[not_crowd_inds] + gt_labels_slice = gt_labels_slice[not_crowd_inds] + iou = _bbox_overlaps(inside_anchors, gt_boxes_slice) + + loc_inds, score_inds, labels, gt_inds, bbox_inside_weight, fg_num = \ + retinanet_target_assign(iou, gt_labels_slice, + positive_overlap, negative_overlap) + # unmap to all anchor + loc_inds = inds_inside[loc_inds] + score_inds = inds_inside[score_inds] + + sampled_gt = gt_boxes_slice[gt_inds] + sampled_anchor = all_anchors[loc_inds] + box_deltas = _box_to_delta(sampled_anchor, sampled_gt, [1., 1., 1., 1.]) + + if i == 0: + loc_indexes = loc_inds + score_indexes = score_inds + tgt_labels = labels + tgt_bboxes = box_deltas + bbox_inside_weights = bbox_inside_weight + fg_nums = [[fg_num]] + else: + loc_indexes = np.concatenate( + [loc_indexes, loc_inds + i * anchor_num]) + score_indexes = np.concatenate( + [score_indexes, score_inds + i * anchor_num]) + tgt_labels = np.concatenate([tgt_labels, labels]) + tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) + bbox_inside_weights = np.vstack([bbox_inside_weights, \ + bbox_inside_weight]) + fg_nums = np.concatenate([fg_nums, [[fg_num]]]) + + return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights, fg_nums + + class TestRpnTargetAssignOp(OpTest): def setUp(self): n, c, h, w = 2, 4, 14, 14 @@ -234,5 +333,65 @@ class TestRpnTargetAssignOp(OpTest): self.check_output() +class TestRetinanetTargetAssignOp(OpTest): + def setUp(self): + n, c, h, w = 2, 4, 14, 14 + all_anchors = get_anchor(n, c, h, w) + gt_num = 10 + all_anchors = all_anchors.reshape(-1, 4) + anchor_num = all_anchors.shape[0] + + images_shape = [[64, 64], [64, 64]] + groundtruth, lod = _generate_groundtruth(images_shape, 3, 4) + lod = [0, 4, 8] + + im_info = np.ones((len(images_shape), 3)).astype(np.float32) + for i in range(len(images_shape)): + im_info[i, 0] = images_shape[i][0] + im_info[i, 1] = images_shape[i][1] + im_info[i, 2] = 0.8 #scale + gt_boxes = np.vstack([v['boxes'] for v in groundtruth]) + is_crowd = np.hstack([v['is_crowd'] for v in groundtruth]) + gt_labels = np.vstack([ + v['gt_classes'].reshape(len(v['gt_classes']), 1) + for v in groundtruth + ]) + gt_labels = gt_labels.reshape(len(gt_labels), 1) + all_anchors = all_anchors.astype('float32') + gt_boxes = gt_boxes.astype('float32') + gt_labels = gt_labels.astype('int32') + + positive_overlap = 0.5 + negative_overlap = 0.4 + + loc_index, score_index, tgt_bbox, labels, bbox_inside_weights, fg_num = \ + retinanet_target_assign_in_python(all_anchors, gt_boxes, gt_labels, is_crowd, + im_info, lod, positive_overlap, negative_overlap) + labels = labels[:, np.newaxis] + self.op_type = "retinanet_target_assign" + self.inputs = { + 'Anchor': all_anchors, + 'GtBoxes': (gt_boxes, [[4, 4]]), + 'GtLabels': (gt_labels, [[4, 4]]), + 'IsCrowd': (is_crowd, [[4, 4]]), + 'ImInfo': (im_info, [[1, 1]]) + } + self.attrs = { + 'positive_overlap': positive_overlap, + 'negative_overlap': negative_overlap + } + self.outputs = { + 'LocationIndex': loc_index.astype('int32'), + 'ScoreIndex': score_index.astype('int32'), + 'TargetBBox': tgt_bbox.astype('float32'), + 'TargetLabel': labels.astype('int32'), + 'BBoxInsideWeight': bbox_inside_weights.astype('float32'), + 'ForegroundNumber': fg_num.astype('int32') + } + + def test_check_output(self): + self.check_output() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss_op.py b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss_op.py new file mode 100644 index 00000000000..0e846521d0a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss_op.py @@ -0,0 +1,132 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +import copy +from op_test import OpTest +from paddle.fluid import core + + +def sigmoid_focal_loss_forward(x_data, label_data, fg_num_data, gamma, alpha, + num_classes): + x_data_t = copy.deepcopy(x_data) + out_data = copy.deepcopy(x_data) + x_width = len(x_data) + x_height = len(x_data[0, :]) + x_data_t = x_data_t.flatten() + out_data = out_data.flatten() + for idx in range(len(x_data_t)): + x = x_data_t[idx] + a = int(idx / num_classes) + d = int(idx % num_classes) + label = label_data[a] + c_pos = float((int(label) == int(d + 1))) + c_neg = float(((int(label) != -1) & (int(label) != (d + 1)))) + fg_num = max(fg_num_data, 1) + z_neg = (1.0 - alpha) / fg_num + z_pos = alpha / fg_num + + p = 1. / (1. + math.exp(-x)) + FLT_MIN = 1.175494351e-38 + term_pos = math.pow((1. - p), gamma) * math.log(max(FLT_MIN, p)) + term_neg = math.pow(p, gamma) * ( + -1. * x * (x >= 0) - math.log(1. + math.exp(x - 2. * x * (x >= 0)))) + out_data[idx] = 0.0 + out_data[idx] += -c_pos * term_pos * z_pos + out_data[idx] += -c_neg * term_neg * z_neg + + out_data = out_data.reshape(x_width, x_height) + return out_data + + +class TestSigmoidFocalLossOp1(OpTest): + def set_argument(self): + self.num_anchors = 10 + self.num_classes = 10 + self.gamma = 2.0 + self.alpha = 0.25 + + def setUp(self): + self.set_argument() + + dims = (self.num_anchors, self.num_classes) + X = np.random.standard_normal(dims).astype("float32") + L = np.random.randint(0, self.num_classes + 1, + (dims[0], 1)).astype("int32") + F = np.zeros(1) + F[0] = len(np.where(L > 0)[0]) + F = F.astype("int32") + + self.op_type = "sigmoid_focal_loss" + self.inputs = { + 'X': X, + 'Label': L, + 'FgNum': F, + } + self.attrs = { + 'gamma': self.gamma, + 'alpha': self.alpha, + } + loss = sigmoid_focal_loss_forward( + self.inputs['X'], self.inputs['Label'], self.inputs['FgNum'], + self.gamma, self.alpha, self.num_classes) + self.outputs = {'Out': loss.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSigmoidFocalLossOp2(TestSigmoidFocalLossOp1): + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=2e-3) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', max_relative_error=0.002) + + +class TestSigmoidFocalLossOp3(TestSigmoidFocalLossOp1): + def set_argument(self): + self.num_anchors = 200 + self.num_classes = 10 + self.gamma = 1.0 + self.alpha = 0.5 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSigmoidFocalLossOp4(TestSigmoidFocalLossOp3): + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=2e-3) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', max_relative_error=0.002) + + +if __name__ == '__main__': + unittest.main() -- GitLab