未验证 提交 3305045c 编写于 作者: F FlyingQianMM 提交者: GitHub

Cherry pick retinanet_target_assign_op(#17893), sigmoid_focal_loss_op(#17895)...

Cherry pick retinanet_target_assign_op(#17893), sigmoid_focal_loss_op(#17895) and retinanet_detection_output_op(#17896) for supporting retinanet (#18141)

* test=release/1.5
Fix conflicts in test_layers.py when adding target assign operator for supporting retinanet. Cherry pick #17893

* test=release/1.5
Add sigmoid focal loss operator for supporting retinanet. Cherry pick #17895

* test=release/1.5
Add detection output operator for supporting retinanet. Cherry pick #17896

* test=release/1.5
fix wrong code style in test_layers.py when cherry pick retinanet_target_assign #17893

* test=release/1.5
Fix type error of std::pow in sigmoid_focal_loss. Cherry pick #17895
上级 7c7afef7
...@@ -348,6 +348,8 @@ paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box' ...@@ -348,6 +348,8 @@ paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box'
paddle.fluid.layers.ssd_loss (ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)), ('document', '6d5028fd09d01ab82d296adc0ea95aee')) paddle.fluid.layers.ssd_loss (ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)), ('document', '6d5028fd09d01ab82d296adc0ea95aee'))
paddle.fluid.layers.detection_map (ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')), ('document', '1467d91b50c22cd52103b4aa1ee9d0a1')) paddle.fluid.layers.detection_map (ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')), ('document', '1467d91b50c22cd52103b4aa1ee9d0a1'))
paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '1e164a56fe9376e18a56d22563d9f801')) paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '1e164a56fe9376e18a56d22563d9f801'))
paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
paddle.fluid.layers.sigmoid_focal_loss (ArgSpec(args=['x', 'label', 'fg_num', 'gamma', 'alpha'], varargs=None, keywords=None, defaults=(2, 0.25)), ('document', 'aeac6aae100173b3fc7f102cf3023a3d'))
paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a')) paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a'))
paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc')) paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc'))
paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'e87c1131e98715d3657a96c44db1b910')) paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'e87c1131e98715d3657a96c44db1b910'))
...@@ -360,6 +362,7 @@ paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gt_box', 'gt_label', 'ancho ...@@ -360,6 +362,7 @@ paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gt_box', 'gt_label', 'ancho
paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'f332fb8c5bb581bd1a6b5be450a99990')) paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'f332fb8c5bb581bd1a6b5be450a99990'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '04384378ff00a42ade8fabd52e27cbc5')) paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '04384378ff00a42ade8fabd52e27cbc5'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0')) paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.retinanet_detection_output (ArgSpec(args=['bboxes', 'scores', 'anchors', 'im_info', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0.05, 1000, 100, 0.3, 1.0)), ('document', '078d28607ce261a0cba2b965a79f6bb8'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d')) paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5')) paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5'))
paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace')) paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace'))
......
...@@ -35,6 +35,8 @@ detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) ...@@ -35,6 +35,8 @@ detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu) detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
if(WITH_GPU) if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub) detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
// Operator definition for retinanet_detection_output.
// InferShape validates the BBoxes / Scores / Anchors / ImInfo inputs and sets
// a provisional shape for Out; GetExpectedKernelType always selects a CPU
// kernel, keyed on the data type of the first Scores input.
class RetinanetDetectionOutputOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
// Checks that every required input/output is present and that the three
// list inputs have the same number of entries. Rank and dimension checks are
// only applied at runtime and only to the first entry of each list.
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(
ctx->Inputs("BBoxes").size(), 1UL,
"Input(BBoxes) of RetinanetDetectionOutput should not be null.");
PADDLE_ENFORCE_GE(
ctx->Inputs("Scores").size(), 1UL,
"Input(Scores) of RetinanetDetectionOutput should not be null.");
PADDLE_ENFORCE_GE(
ctx->Inputs("Anchors").size(), 1UL,
"Input(Anchors) of RetinanetDetectionOutput should not be null.");
// BBoxes, Scores and Anchors must all contain the same number of tensors.
PADDLE_ENFORCE_EQ(
ctx->Inputs("BBoxes").size(), ctx->Inputs("Scores").size(),
"Input tensors(BBoxes and Scores) should have the same size.");
PADDLE_ENFORCE_EQ(
ctx->Inputs("BBoxes").size(), ctx->Inputs("Anchors").size(),
"Input tensors(BBoxes and Anchors) should have the same size.");
PADDLE_ENFORCE(
ctx->HasInput("ImInfo"),
"Input(ImInfo) of RetinanetDetectionOutput should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of RetinanetDetectionOutput should not be null.");
// One dims entry per tensor in each list input.
auto bboxes_dims = ctx->GetInputsDim("BBoxes");
auto scores_dims = ctx->GetInputsDim("Scores");
auto anchors_dims = ctx->GetInputsDim("Anchors");
auto im_info_dims = ctx->GetInputDim("ImInfo");
const size_t b_n = bboxes_dims.size();
PADDLE_ENFORCE_GT(b_n, 0, "Input bbox tensors count should > 0.");
const size_t s_n = scores_dims.size();
PADDLE_ENFORCE_GT(s_n, 0, "Input score tensors count should > 0.");
const size_t a_n = anchors_dims.size();
PADDLE_ENFORCE_GT(a_n, 0, "Input anchor tensors count should > 0.");
// Only the first tensor of each list is inspected below.
auto bbox_dims = bboxes_dims[0];
auto score_dims = scores_dims[0];
auto anchor_dims = anchors_dims[0];
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(score_dims.size(), 3,
"The rank of Input(Scores) must be 3");
PADDLE_ENFORCE_EQ(bbox_dims.size(), 3,
"The rank of Input(BBoxes) must be 3");
PADDLE_ENFORCE_EQ(anchor_dims.size(), 2,
"The rank of Input(Anchors) must be 2");
PADDLE_ENFORCE(bbox_dims[2] == 4,
"The last dimension of Input(BBoxes) must be 4, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax]");
PADDLE_ENFORCE_EQ(bbox_dims[1], score_dims[1],
"The 2nd dimension of Input(BBoxes) must be equal to "
"2nd dimension of Input(Scores), which represents the "
"number of the predicted boxes.");
PADDLE_ENFORCE_EQ(anchor_dims[0], bbox_dims[1],
"The 1st dimension of Input(Anchors) must be equal to "
"2nd dimension of Input(BBoxes), which represents the "
"number of the predicted boxes.");
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
"The rank of Input(ImInfo) must be 2.");
}
// bbox_dims[1] is only a placeholder for the first output dimension; the
// kernel resizes Out to the actual number of kept detections. The second
// dimension is bbox_dims[2] + 2 (label and score precede the coordinates).
ctx->SetOutputDim("Out", {bbox_dims[1], bbox_dims[2] + 2});
}
protected:
// Picks the kernel from the data type of the first Scores tensor and always
// places it on the CPU, regardless of the place of the execution context.
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::GetDataTypeOfVar(ctx.MultiInputVar("Scores")[0]);
return framework::OpKernelType(input_data_type,
platform::CPUPlace()); // ctx.GetPlace());
}
};
// Ordering predicate for (score, payload) pairs: returns true when `lhs`
// carries a strictly higher score than `rhs`, so sorting with it yields
// descending scores. The payload (second) is ignored.
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& lhs,
                          const std::pair<float, T>& rhs) {
  return rhs.first < lhs.first;
}
// Ordering predicate for (score, (a, b)) entries: returns true when `lhs`
// carries a strictly higher score than `rhs`, so sorting with it yields
// descending scores. The nested pair is ignored.
template <class T>
bool SortScoreTwoPairDescend(const std::pair<float, std::pair<T, T>>& lhs,
                             const std::pair<float, std::pair<T, T>>& rhs) {
  return rhs.first < lhs.first;
}
// Collects the (score, index) pairs of all entries in `scores` that are
// strictly greater than `threshold`, ordered by descending score.
//
//   scores          flat list of scores; the stored index is the position in
//                   this list.
//   threshold       entries with score <= threshold are dropped.
//   top_k           if > -1, at most top_k pairs are kept after sorting.
//   sorted_indices  output; new pairs are appended to whatever it already
//                   holds and the whole vector is then sorted and truncated.
//
// Entries with equal scores keep their original relative order.
template <class T>
static inline void GetMaxScoreIndex(
    const std::vector<T>& scores, const T threshold, int top_k,
    std::vector<std::pair<T, int>>* sorted_indices) {
  for (size_t i = 0; i < scores.size(); ++i) {
    if (scores[i] > threshold) {
      sorted_indices->push_back(
          std::make_pair(scores[i], static_cast<int>(i)));
    }
  }
  // Compare scores in their own type T. A comparator taking
  // std::pair<float, int> would silently narrow double scores to float on
  // every comparison and could mis-order scores that differ only beyond
  // float precision.
  std::stable_sort(
      sorted_indices->begin(), sorted_indices->end(),
      [](const std::pair<T, int>& lhs, const std::pair<T, int>& rhs) {
        return lhs.first > rhs.first;
      });
  // Keep top_k scores if needed.
  if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
    sorted_indices->resize(top_k);
  }
}
// Area of a box given as [xmin, ymin, xmax, ymax].
// Returns 0 when xmax < xmin or ymax < ymin. For normalized boxes the area
// is width * height; otherwise one is added to each side length first.
template <class T>
static inline T BBoxArea(const std::vector<T>& box, const bool normalized) {
  const bool invalid = box[2] < box[0] || box[3] < box[1];
  if (invalid) {
    return static_cast<T>(0.);
  }
  const T width = box[2] - box[0];
  const T height = box[3] - box[1];
  if (!normalized) {
    // Coordinates are not within [0, 1]: side lengths are inclusive.
    return (width + 1) * (height + 1);
  }
  return width * height;
}
template <class T>
static inline T JaccardOverlap(const std::vector<T>& box1,
const std::vector<T>& box2,
const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
T inter_w = inter_xmax - inter_xmin + norm;
T inter_h = inter_ymax - inter_ymin + norm;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T>
class RetinanetDetectionOutputKernel : public framework::OpKernel<T> {
public:
void NMSFast(const std::vector<std::vector<T>>& cls_dets,
const T nms_threshold, const T eta,
std::vector<int>* selected_indices) const {
int64_t num_boxes = cls_dets.size();
std::vector<std::pair<T, int>> sorted_indices;
for (int64_t i = 0; i < num_boxes; ++i) {
sorted_indices.push_back(std::make_pair(cls_dets[i][4], i));
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
SortScorePairDescend<int>);
selected_indices->clear();
T adaptive_threshold = nms_threshold;
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = T(0.);
overlap = JaccardOverlap<T>(cls_dets[idx], cls_dets[kept_idx], false);
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
void DeltaScoreToPrediction(
const std::vector<T>& bboxes_data, const std::vector<T>& anchors_data,
T im_height, T im_width, T im_scale, int class_num,
const std::vector<std::pair<T, int>>& sorted_indices,
std::map<int, std::vector<std::vector<T>>>* preds) const {
im_height = static_cast<T>(round(im_height / im_scale));
im_width = static_cast<T>(round(im_width / im_scale));
T zero(0);
int i = 0;
for (const auto& it : sorted_indices) {
T score = it.first;
int idx = it.second;
int a = idx / class_num;
int c = idx % class_num;
int box_offset = a * 4;
T anchor_box_width =
anchors_data[box_offset + 2] - anchors_data[box_offset] + 1;
T anchor_box_height =
anchors_data[box_offset + 3] - anchors_data[box_offset + 1] + 1;
T anchor_box_center_x = anchors_data[box_offset] + anchor_box_width / 2;
T anchor_box_center_y =
anchors_data[box_offset + 1] + anchor_box_height / 2;
T target_box_center_x = 0, target_box_center_y = 0;
T target_box_width = 0, target_box_height = 0;
target_box_center_x =
bboxes_data[box_offset] * anchor_box_width + anchor_box_center_x;
target_box_center_y =
bboxes_data[box_offset + 1] * anchor_box_height + anchor_box_center_y;
target_box_width =
std::exp(bboxes_data[box_offset + 2]) * anchor_box_width;
target_box_height =
std::exp(bboxes_data[box_offset + 3]) * anchor_box_height;
T pred_box_xmin = target_box_center_x - target_box_width / 2;
T pred_box_ymin = target_box_center_y - target_box_height / 2;
T pred_box_xmax = target_box_center_x + target_box_width / 2 - 1;
T pred_box_ymax = target_box_center_y + target_box_height / 2 - 1;
pred_box_xmin = pred_box_xmin / im_scale;
pred_box_ymin = pred_box_ymin / im_scale;
pred_box_xmax = pred_box_xmax / im_scale;
pred_box_ymax = pred_box_ymax / im_scale;
pred_box_xmin = std::max(std::min(pred_box_xmin, im_width - 1), zero);
pred_box_ymin = std::max(std::min(pred_box_ymin, im_height - 1), zero);
pred_box_xmax = std::max(std::min(pred_box_xmax, im_width - 1), zero);
pred_box_ymax = std::max(std::min(pred_box_ymax, im_height - 1), zero);
std::vector<T> one_pred;
one_pred.push_back(pred_box_xmin);
one_pred.push_back(pred_box_ymin);
one_pred.push_back(pred_box_xmax);
one_pred.push_back(pred_box_ymax);
one_pred.push_back(score);
(*preds)[c].push_back(one_pred);
i++;
}
}
void MultiClassNMS(const std::map<int, std::vector<std::vector<T>>>& preds,
int class_num, const int keep_top_k, const T nms_threshold,
const T nms_eta, std::vector<std::vector<T>>* nmsed_out,
int* num_nmsed_out) const {
std::map<int, std::vector<int>> indices;
int num_det = 0;
for (int c = 0; c < class_num; ++c) {
if (static_cast<bool>(preds.count(c))) {
const std::vector<std::vector<T>> cls_dets = preds.at(c);
NMSFast(cls_dets, nms_threshold, nms_eta, &(indices[c]));
num_det += indices[c].size();
}
}
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : indices) {
int label = it.first;
const std::vector<int>& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
score_index_pairs.push_back(std::make_pair(preds.at(label)[idx][4],
std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScoreTwoPairDescend<int>);
if (num_det > keep_top_k) {
score_index_pairs.resize(keep_top_k);
}
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (const auto& it : score_index_pairs) {
int label = it.second.first;
int idx = it.second.second;
std::vector<T> one_pred;
one_pred.push_back(label);
one_pred.push_back(preds.at(label)[idx][4]);
one_pred.push_back(preds.at(label)[idx][0]);
one_pred.push_back(preds.at(label)[idx][1]);
one_pred.push_back(preds.at(label)[idx][2]);
one_pred.push_back(preds.at(label)[idx][3]);
nmsed_out->push_back(one_pred);
}
*num_nmsed_out = (num_det > keep_top_k ? keep_top_k : num_det);
}
// Produces the detections of a single image.
// For every FPN level the scores above the threshold are selected (at most
// nms_top_k per level) and decoded against that level's anchors; the merged
// predictions of all levels then go through MultiClassNMS.
//
//   scores / bboxes  per-level tensors for this image.
//   anchors          per-level anchor tensors.
//   im_info          image info; elements 0, 1, 2 are read as height, width
//                    and scale.
//   nmsed_out        output rows, see MultiClassNMS.
//   num_nmsed_out    output row count.
void RetinanetDetectionOutput(const framework::ExecutionContext& ctx,
                              const std::vector<Tensor>& scores,
                              const std::vector<Tensor>& bboxes,
                              const std::vector<Tensor>& anchors,
                              const Tensor& im_info,
                              std::vector<std::vector<T>>* nmsed_out,
                              int* num_nmsed_out) const {
  const int64_t nms_top_k = ctx.Attr<int>("nms_top_k");
  const int64_t keep_top_k = ctx.Attr<int>("keep_top_k");
  const T nms_threshold = static_cast<T>(ctx.Attr<float>("nms_threshold"));
  const T nms_eta = static_cast<T>(ctx.Attr<float>("nms_eta"));
  const T score_threshold = static_cast<T>(ctx.Attr<float>("score_threshold"));
  const int64_t class_num = scores[0].dims()[1];

  std::map<int, std::vector<std::vector<T>>> preds;
  const size_t num_levels = scores.size();
  for (size_t level = 0; level < num_levels; ++level) {
    const Tensor& level_scores = scores[level];
    const Tensor& level_bboxes = bboxes[level];
    const Tensor& level_anchors = anchors[level];

    const int64_t scores_num = level_scores.numel();
    const int64_t bboxes_num = level_bboxes.numel();

    // Copy the raw buffers into plain vectors. The anchors are read with the
    // same element count as the boxes.
    const T* scores_ptr = level_scores.data<T>();
    const std::vector<T> scores_data(scores_ptr, scores_ptr + scores_num);
    const T* bboxes_ptr = level_bboxes.data<T>();
    const std::vector<T> bboxes_data(bboxes_ptr, bboxes_ptr + bboxes_num);
    const T* anchors_ptr = level_anchors.data<T>();
    const std::vector<T> anchors_data(anchors_ptr, anchors_ptr + bboxes_num);

    // The last level in the list uses a threshold of 0 instead of
    // score_threshold.
    const bool is_last_level = (level + 1 == num_levels);
    const T threshold = is_last_level ? static_cast<T>(0.0) : score_threshold;

    std::vector<std::pair<T, int>> sorted_indices;
    GetMaxScoreIndex(scores_data, threshold, nms_top_k, &sorted_indices);

    const T* im_info_data = im_info.data<T>();
    const T im_height = im_info_data[0];
    const T im_width = im_info_data[1];
    const T im_scale = im_info_data[2];
    DeltaScoreToPrediction(bboxes_data, anchors_data, im_height, im_width,
                           im_scale, class_num, sorted_indices, &preds);
  }
  MultiClassNMS(preds, class_num, keep_top_k, nms_threshold, nms_eta,
                nmsed_out, num_nmsed_out);
}
// Writes the kept detections of one image into `outs`, six values per row:
// [label + 1, score, xmin, ymin, xmax, ymax]. The class index stored in
// nmsed_out is incremented by one when written. `ctx` is not used.
void MultiClassOutput(const platform::DeviceContext& ctx,
                      const std::vector<std::vector<T>>& nmsed_out,
                      Tensor* outs) const {
  const int64_t row_width = 6;
  T* row = outs->data<T>();
  for (const auto& det : nmsed_out) {
    row[0] = det[0] + 1;  // label
    row[1] = det[1];      // score
    row[2] = det[2];      // xmin
    row[3] = det[3];      // ymin
    row[4] = det[4];      // xmax
    row[5] = det[5];      // ymax
    row += row_width;
  }
}
// Kernel entry point. Splits every per-level input along the batch
// dimension, computes the detections of each image, then packs all rows into
// Out and records the per-image row offsets as its LoD.
void Compute(const framework::ExecutionContext& ctx) const override {
  auto bbox_inputs = ctx.MultiInput<Tensor>("BBoxes");
  auto score_inputs = ctx.MultiInput<Tensor>("Scores");
  auto anchor_inputs = ctx.MultiInput<Tensor>("Anchors");
  auto* im_info = ctx.Input<LoDTensor>("ImInfo");
  auto* outs = ctx.Output<LoDTensor>("Out");

  // Tensor copies of the inputs, one entry per FPN level.
  std::vector<Tensor> level_boxes(bbox_inputs.size());
  std::vector<Tensor> level_scores(score_inputs.size());
  std::vector<Tensor> level_anchors(anchor_inputs.size());
  for (size_t lvl = 0; lvl < level_boxes.size(); ++lvl) {
    level_boxes[lvl] = *bbox_inputs[lvl];
    level_scores[lvl] = *score_inputs[lvl];
    level_anchors[lvl] = *anchor_inputs[lvl];
  }

  const int64_t batch_size = level_scores[0].dims()[0];
  const int64_t box_dim = level_boxes[0].dims()[2];
  const int64_t out_dim = box_dim + 2;
  auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();

  std::vector<std::vector<std::vector<T>>> per_image_out;
  // offsets[i] .. offsets[i + 1] is the row range of image i in Out.
  std::vector<size_t> offsets = {0};
  for (int i = 0; i < batch_size; ++i) {
    std::vector<Tensor> image_boxes(level_boxes.size());
    std::vector<Tensor> image_scores(level_scores.size());
    for (size_t lvl = 0; lvl < level_boxes.size(); ++lvl) {
      // Take image i of this level and drop the leading batch dimension.
      auto lvl_score_dims = level_scores[lvl].dims();
      image_scores[lvl] = level_scores[lvl].Slice(i, i + 1);
      image_scores[lvl].Resize({lvl_score_dims[1], lvl_score_dims[2]});
      image_boxes[lvl] = level_boxes[lvl].Slice(i, i + 1);
      image_boxes[lvl].Resize({lvl_score_dims[1], box_dim});
    }
    Tensor image_info = im_info->Slice(i, i + 1);

    int kept = 0;
    std::vector<std::vector<T>> image_out;
    RetinanetDetectionOutput(ctx, image_scores, image_boxes, level_anchors,
                             image_info, &image_out, &kept);
    per_image_out.push_back(image_out);
    offsets.push_back(offsets.back() + kept);
  }

  int num_kept = offsets.back();
  if (num_kept != 0) {
    outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
    for (int i = 0; i < batch_size; ++i) {
      const int64_t row_begin = offsets[i];
      const int64_t row_end = offsets[i + 1];
      if (row_end > row_begin) {
        Tensor image_rows = outs->Slice(row_begin, row_end);
        MultiClassOutput(dev_ctx, per_image_out[i], &image_rows);
      }
    }
  } else {
    // Nothing detected in the whole batch: Out has zero rows.
    outs->Resize({0, out_dim});
  }

  framework::LoD lod;
  lod.emplace_back(offsets);
  outs->set_lod(lod);
}
};
// Declares the inputs, attributes, output and user-facing documentation of
// the retinanet_detection_output operator.
class RetinanetDetectionOutputOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("BBoxes",
             "(List) A list of tensors from multiple FPN levels. Each "
             "element is a 3-D Tensor with shape [N, Mi, 4] represents the "
             "predicted locations of Mi bounding boxes, N is the batch size. "
             "Mi is the number of bounding boxes from i-th FPN level. Each "
             "bounding box has four coordinate values and the layout is "
             "[xmin, ymin, xmax, ymax].")
        .AsDuplicable();
    AddInput("Scores",
             "(List) A list of tensors from multiple FPN levels. Each "
             "element is a 3-D Tensor with shape [N, Mi, C] represents the "
             "predicted confidence from its FPN level. N is the batch size, "
             "C is the class number (excluding background), Mi is the number "
             "of bounding boxes from i-th FPN level. For each bounding box, "
             "there are total C scores.")
        .AsDuplicable();
    // Note the trailing space after "Each": adjacent literals are
    // concatenated verbatim.
    AddInput("Anchors",
             "(List) A list of tensors from multiple FPN levels. Each "
             "element is a 2-D Tensor with shape [Mi, 4] represents the "
             "locations of Mi anchor boxes from i-th FPN level. Each "
             "bounding box has four coordinate values and the layout is "
             "[xmin, ymin, xmax, ymax].")
        .AsDuplicable();
    AddInput("ImInfo",
             "(LoDTensor) A 2-D LoDTensor with shape [N, 3] represents the "
             "image information. N is the batch size, each image information "
             "includes height, width and scale.");
    AddAttr<float>("score_threshold",
                   "(float) "
                   "Threshold to filter out bounding boxes with a confidence "
                   "score.");
    // The attribute is registered as int, so it is documented as int.
    AddAttr<int>("nms_top_k",
                 "(int) "
                 "Maximum number of detections per FPN layer to be kept "
                 "according to the confidence before NMS.");
    AddAttr<float>("nms_threshold",
                   "(float) "
                   "The threshold to be used in NMS.");
    AddAttr<float>("nms_eta",
                   "(float) "
                   "The parameter for adaptive NMS.");
    AddAttr<int>(
        "keep_top_k",
        "(int) "
        "Number of total bounding boxes to be kept per image after NMS "
        "step.");
    AddOutput("Out",
              "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the "
              "detections. Each row has 6 values: "
              "[label, confidence, xmin, ymin, xmax, ymax]. "
              "No is the total number of detections in this mini-batch. "
              "For each instance, "
              "the offsets in first dimension are called LoD, the number of "
              "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
              "no detected bbox.");
    // The decoding formulas below mirror the kernel's
    // DeltaScoreToPrediction; no anchor variance is applied.
    AddComment(R"DOC(
This operator is to decode boxes and scores from each FPN layer and do
multi-class non maximum suppression (NMS) on merged predictions.
Top-scoring predictions per FPN layer are decoded with the anchor
information. This operator greedily selects a subset of detection bounding
boxes from each FPN layer that have high scores larger than score_threshold,
if providing this threshold, then selects the largest nms_top_k confidences
scores per FPN layer, if nms_top_k is larger than -1.
The decoding schema is described below:
pw = axmax - axmin + 1
ph = aymax - aymin + 1
px = axmin + pw / 2
py = aymin + ph / 2
ow = exp(tw) * pw
oh = exp(th) * ph
oxmin = (tx * pw + px) - ow / 2
oymin = (ty * ph + py) - oh / 2
oxmax = (tx * pw + px) + ow / 2 - 1
oymax = (ty * ph + py) + oh / 2 - 1
where `tx`, `ty`, `tw`, `th` denote the four values predicted for a box in
BBoxes, `axmin`, `aymin`, `axmax`, `aymax` denote the corner coordinates of
its anchor, `px`, `py`, `pw`, `ph` denote the anchor's center coordinates,
width and height, `ow`, `oh` denote the decoded width and height, and
`oxmin`, `oymin`, `oxmax`, `oymax` denote the decoded corner coordinates.
The decoded coordinates are then divided by the image scale and clipped to
the image boundary.
Then the top decoded prediction from all levels are merged followed by NMS.
In the NMS step, this operator prunes away boxes that have high IOU
(intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta.
After NMS step, at most keep_top_k number of total bounding boxes are to be kept
per image if keep_top_k is larger than -1.
This operator support multi-class and batched inputs. It applying NMS
independently for each class. The outputs is a 2-D LoDTensor, for each
image, the offsets in first dimension of LoDTensor are called LoD, the number
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
means there is no detected bounding box for this image. If there is no detected boxes
for all images, all the elements in LoD are set to 0, and the output tensor is
empty (None).
)DOC");
  }
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macros below.
namespace ops = paddle::operators;
// Registers the operator under the name "retinanet_detection_output" with
// its shape-inference class and proto maker. EmptyGradOpMaker is passed as
// the gradient maker (presumably meaning no backward op is generated --
// confirm against the framework's definition).
REGISTER_OPERATOR(retinanet_detection_output, ops::RetinanetDetectionOutputOp,
ops::RetinanetDetectionOutputOpMaker,
paddle::framework::EmptyGradOpMaker);
// CPU kernels are registered for float and double element types.
REGISTER_OP_CPU_KERNEL(retinanet_detection_output,
ops::RetinanetDetectionOutputKernel<float>,
ops::RetinanetDetectionOutputKernel<double>);
...@@ -202,21 +202,32 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, ...@@ -202,21 +202,32 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
} }
// Reservoir Sampling // Reservoir Sampling
int fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im); int fg_num = 0;
ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random); if (rpn_fg_fraction > 0 && rpn_batch_size_per_im > 0) {
fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im);
ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random);
} else {
fg_num = static_cast<int>(fg_inds_fake.size());
}
int fg_fake_num = static_cast<int>(fg_inds_fake.size()); int fg_fake_num = static_cast<int>(fg_inds_fake.size());
for (int64_t i = 0; i < fg_fake_num; ++i) { for (int64_t i = 0; i < fg_fake_num; ++i) {
target_label[fg_inds_fake[i]] = 1; target_label[fg_inds_fake[i]] = 1;
} }
int bg_num = rpn_batch_size_per_im - fg_fake_num;
for (int64_t i = 0; i < anchor_num; ++i) { for (int64_t i = 0; i < anchor_num; ++i) {
if (anchor_to_gt_max_data[i] < rpn_negative_overlap) { if (anchor_to_gt_max_data[i] < rpn_negative_overlap) {
bg_inds_fake.push_back(i); bg_inds_fake.push_back(i);
} }
} }
ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random); int bg_num = 0;
bg_num = static_cast<int>(bg_inds_fake.size()); if (rpn_fg_fraction > 0 && rpn_batch_size_per_im > 0) {
bg_num = rpn_batch_size_per_im - fg_fake_num;
ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random);
bg_num = static_cast<int>(bg_inds_fake.size());
} else {
bg_num = static_cast<int>(bg_inds_fake.size());
}
int fake_num = 0; int fake_num = 0;
for (int64_t i = 0; i < bg_num; ++i) { for (int64_t i = 0; i < bg_num; ++i) {
// fg fake found // fg fake found
...@@ -492,9 +503,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -492,9 +503,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Anchor", AddInput("Anchor",
"(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4]."); "(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4].");
AddInput("GtBoxes", AddInput("GtBoxes",
"(LoDTensor) input groud-truth bbox with shape [K, 4]."); "(LoDTensor) input ground-truth bbox with shape [K, 4].");
AddInput("IsCrowd", AddInput("IsCrowd",
"(LoDTensor) input which indicates groud-truth is crowd."); "(LoDTensor) input which indicates ground-truth is crowd.");
AddInput("ImInfo", AddInput("ImInfo",
"(LoDTensor) input image information with shape [N, 3]. " "(LoDTensor) input image information with shape [N, 3]. "
"N is the batch size, each image information includes height, " "N is the batch size, each image information includes height, "
...@@ -536,7 +547,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -536,7 +547,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"ScoreIndex", "ScoreIndex",
"(Tensor), The indexes of foreground and background anchors in all " "(Tensor), The indexes of foreground and background anchors in all "
"RPN anchors(The rest anchors are ignored). The shape of the " "RPN anchors(The rest anchors are ignored). The shape of the "
"ScoreIndex is [F + B], F and B are sampled foreground and backgroud " "ScoreIndex is [F + B], F and B are sampled foreground and background "
" number."); " number.");
AddOutput("TargetBBox", AddOutput("TargetBBox",
"(Tensor), The target bbox deltas with shape " "(Tensor), The target bbox deltas with shape "
...@@ -544,7 +555,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -544,7 +555,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput( AddOutput(
"TargetLabel", "TargetLabel",
"(Tensor<int>), The target labels of each anchor with shape " "(Tensor<int>), The target labels of each anchor with shape "
"[F + B, 1], F and B are sampled foreground and backgroud number."); "[F + B, 1], F and B are sampled foreground and background number.");
AddOutput("BBoxInsideWeight", AddOutput("BBoxInsideWeight",
"(Tensor), The bbox inside weight with shape " "(Tensor), The bbox inside weight with shape "
"[F, 4], F is the sampled foreground number."); "[F, 4], F is the sampled foreground number.");
...@@ -573,6 +584,440 @@ negative do not contribute to the training objective. ...@@ -573,6 +584,440 @@ negative do not contribute to the training objective.
} }
}; };
// Declares the interface (inputs, attributes, outputs and user-facing
// description) of the retinanet_target_assign operator.
class RetinanetTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
// Registers five inputs, two float attributes and six outputs, followed by
// the long-form operator comment.
void Make() override {
// ---- Inputs ----
AddInput("Anchor",
"(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4].");
AddInput("GtBoxes",
"(LoDTensor) input ground-truth bbox with shape [K, 4].");
AddInput("GtLabels",
"(LoDTensor) input ground-truth label with shape [K, 1].");
AddInput("IsCrowd",
"(LoDTensor) input which indicates ground-truth is crowd.");
AddInput("ImInfo",
"(LoDTensor) input image information with shape [N, 3]. "
"N is the batch size, each image information includes height, "
"width and scale.");
// ---- Attributes: IoU thresholds, defaulting to 0.5 (positive) and 0.4
// (negative) ----
AddAttr<float>(
"positive_overlap",
"Minimum overlap required between an anchor and ground-truth "
"box for the (anchor, gt box) pair to be a positive example.")
.SetDefault(0.5);
AddAttr<float>(
"negative_overlap",
"Maximum overlap allowed between an anchor and ground-truth "
"box for the (anchor, gt box) pair to be a negative examples.")
.SetDefault(0.4);
// ---- Outputs ----
AddOutput(
"LocationIndex",
"(Tensor), The indexes of foreground anchors in all anchors, the "
"shape of the LocationIndex is [F], F depends on the value of input "
"tensor and attributes.");
AddOutput(
"ScoreIndex",
"(Tensor), The indexes of foreground and background anchors in all "
"RPN anchors(The rest anchors are ignored). The shape of the "
"ScoreIndex is [F + B], F and B are foreground and background "
" number.");
AddOutput("TargetBBox",
"(Tensor), The target bbox deltas with shape "
"[F, 4], F is the foreground number.");
AddOutput("TargetLabel",
"(Tensor<int>), The target labels of each anchor with shape "
"[F + B, 1], F and B are foreground and background number.");
AddOutput("BBoxInsideWeight",
"(Tensor), The bbox inside weight with shape "
"[F, 4], F is the foreground number.");
AddOutput("ForegroundNumber",
"(Tensor), The foreground number. "
"[1, 1].");
// Long-form description attached to the operator proto. The text below is
// part of the raw string literal and therefore runtime data.
AddComment(R"DOC(
This layer can be, for given the Intersection-over-Union (IoU) overlap
between anchors and ground truth boxes, to assign classification and
regression targets to each anchor, these target labels are used for
train retinanet.
Every anchor is assigned with a length C one-hot vector of
classification targets, and a 4-vector of box regression targets,
where C is the class number. The assignment rules are as followed:
1. Anchors are assigned to ground-truth boxes when: (i) it has the highest
IoU overlap with a ground-truth box, or (ii) it has an IoU overlap higher
than positive_overlap(0.5) with any ground-truth box.
2. Anchors are assigned to background when its IoU ratio is lower than
negative_overlap (0.4) for all ground-truth boxes.
When an anchor is assigned with a ground-truth box which is the i-th category,
the i-th entry in its C vector of targets is set to 1 and all other entries
are set to 0. When an anchor is assigned with background, all entries are set
to 0. Anchors that are not assigned do not contribute to the training
objective. The regression targets are the encoded ground-truth boxes
associated with the assigned anchors.
)DOC");
}
};
// Operator definition for retinanet_target_assign: validates that all inputs
// and outputs are wired up, checks input ranks, sets compile-time output
// shapes and selects the kernel type.
class RetinanetTargetAssignOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks presence of every input/output and that Anchor, GtBoxes, GtLabels
  // and ImInfo are rank-2, then assigns output dims derived from the first
  // dimension of GtLabels.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(
        ctx->HasInput("Anchor"),
        "Input(Anchor) of RetinanetTargetAssignOp should not be null");
    PADDLE_ENFORCE(
        ctx->HasInput("GtBoxes"),
        "Input(GtBoxes) of RetinanetTargetAssignOp should not be null");
    PADDLE_ENFORCE(
        ctx->HasInput("GtLabels"),
        "Input(GtLabels) of RetinanetTargetAssignOp should not be null");
    // The message previously named Input(Anchor); report the input that is
    // actually being checked.
    PADDLE_ENFORCE(
        ctx->HasInput("IsCrowd"),
        "Input(IsCrowd) of RetinanetTargetAssignOp should not be null");
    PADDLE_ENFORCE(
        ctx->HasInput("ImInfo"),
        "Input(ImInfo) of RetinanetTargetAssignOp should not be null");
    PADDLE_ENFORCE(
        ctx->HasOutput("LocationIndex"),
        "Output(LocationIndex) of RetinanetTargetAssignOp should not be null");
    PADDLE_ENFORCE(
        ctx->HasOutput("ScoreIndex"),
        "Output(ScoreIndex) of RetinanetTargetAssignOp should not be null");
    PADDLE_ENFORCE(
        ctx->HasOutput("TargetLabel"),
        "Output(TargetLabel) of RetinanetTargetAssignOp should not be null");
    PADDLE_ENFORCE(
        ctx->HasOutput("TargetBBox"),
        "Output(TargetBBox) of RetinanetTargetAssignOp should not be null");
    PADDLE_ENFORCE(ctx->HasOutput("BBoxInsideWeight"),
                   "Output(BBoxInsideWeight) of RetinanetTargetAssignOp should "
                   "not be null");
    PADDLE_ENFORCE(ctx->HasOutput("ForegroundNumber"),
                   "Output(ForegroundNumber) of RetinanetTargetAssignOp should "
                   "not be null");

    auto anchor_dims = ctx->GetInputDim("Anchor");
    auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
    auto gt_labels_dims = ctx->GetInputDim("GtLabels");
    auto im_info_dims = ctx->GetInputDim("ImInfo");

    PADDLE_ENFORCE_EQ(anchor_dims.size(), 2,
                      "The rank of Input(Anchor) must be 2.");
    PADDLE_ENFORCE_EQ(gt_boxes_dims.size(), 2,
                      "The rank of Input(GtBoxes) must be 2.");
    PADDLE_ENFORCE_EQ(gt_labels_dims.size(), 2,
                      "The rank of Input(GtLabels) must be 2.");
    PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
                      "The rank of Input(ImInfo) must be 2.");

    // Every output's leading dimension is set from GtLabels' first dimension.
    // NOTE(review): the CPU kernel in this file resizes all outputs to their
    // real sizes at run time, so these look like placeholders -- confirm.
    ctx->SetOutputDim("LocationIndex", {gt_labels_dims[0]});
    ctx->SetOutputDim("ScoreIndex", {gt_labels_dims[0]});
    ctx->SetOutputDim("TargetBBox", {gt_labels_dims[0], 4});
    ctx->SetOutputDim("TargetLabel", {gt_labels_dims[0], 1});
    ctx->SetOutputDim("BBoxInsideWeight", {gt_labels_dims[0], 4});
    ctx->SetOutputDim("ForegroundNumber", {gt_labels_dims[0], 1});
  }

 protected:
  // The kernel data type follows the Anchor input; the place is always
  // CPUPlace regardless of the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input<framework::LoDTensor>("Anchor")->type(),
        platform::CPUPlace());
  }
};
// Selects the ground-truth entries whose is_crowd flag equals 0.
//
// context   - CPU device context; its place is used for the new tensors.
// gt_boxes  - ground-truth boxes; the first dimension is the number of boxes
//             and rows are gathered with a width of 4.
// gt_labels - ground-truth labels, read as int with a row width of 1.
// is_crowd  - per-box int flags; only rows with flag 0 are kept.
//
// Returns a two-element vector: {filtered boxes [kept, 4],
// filtered labels [kept, 1]}.
template <typename T>
std::vector<Tensor> FilterCrowdGtBoxLabel(
    const platform::CPUDeviceContext& context, Tensor* gt_boxes,
    Tensor* gt_labels, Tensor* is_crowd) {
  const int total_gt = gt_boxes->dims()[0];

  // Indices of the rows that are not marked as crowd.
  std::vector<int> kept_rows;
  const int* crowd_flags = is_crowd->data<int>();
  for (int row = 0; row < total_gt; ++row) {
    if (crowd_flags[row] != 0) {
      continue;
    }
    kept_rows.emplace_back(row);
  }
  const int kept_num = kept_rows.size();

  // Allocate destination tensors, then copy the selected rows into them.
  Tensor kept_boxes;
  Tensor kept_labels;
  T* kept_boxes_ptr =
      kept_boxes.mutable_data<T>({kept_num, 4}, context.GetPlace());
  int* kept_labels_ptr =
      kept_labels.mutable_data<int>({kept_num, 1}, context.GetPlace());
  Gather<T>(gt_boxes->data<T>(), 4, kept_rows.data(), kept_num,
            kept_boxes_ptr);
  Gather<int>(gt_labels->data<int>(), 1, kept_rows.data(), kept_num,
              kept_labels_ptr);

  std::vector<Tensor> filtered;
  filtered.emplace_back(kept_boxes);
  filtered.emplace_back(kept_labels);
  return filtered;
}
// Builds the per-image assignment results from an anchor-by-gt IoU matrix.
//
// ctx                  - CPU device context; its place is used for all
//                        tensors allocated here.
// anchor_by_gt_overlap - 2-D overlap matrix; dim 0 is read as the anchor
//                        count and dim 1 as the ground-truth count.
// ncrowd_gt_labels     - int labels indexed by ground-truth column.
// positive_overlap / negative_overlap - thresholds forwarded to ScoreAssign.
// engine               - random engine forwarded (by value) to ScoreAssign.
//
// Returns six tensors in this order: location index, score index, target
// label, matched gt index, bbox inside weight, foreground number.
template <typename T>
std::vector<Tensor> GetAllFgBgGt(const platform::CPUDeviceContext& ctx,
const Tensor& anchor_by_gt_overlap,
const Tensor& ncrowd_gt_labels,
const float positive_overlap,
const float negative_overlap,
std::minstd_rand engine) {
auto* anchor_by_gt_overlap_data = anchor_by_gt_overlap.data<T>();
int anchor_num = anchor_by_gt_overlap.dims()[0];
int gt_num = anchor_by_gt_overlap.dims()[1];
// Output containers filled by ScoreAssign (gt_inds is filled further below).
std::vector<int> fg_inds;
std::vector<int> bg_inds;
std::vector<int> gt_inds;
std::vector<int> tgt_lbl;
std::vector<int> fg_fake;
std::vector<T> bbox_inside_weight;
// Calculate the max IoU between anchors and gt boxes
// Map from anchor to gt box that has highest overlap
auto place = ctx.GetPlace();
Tensor anchor_to_gt_max, anchor_to_gt_argmax, gt_to_anchor_max;
anchor_to_gt_max.mutable_data<T>({anchor_num}, place);
int* argmax = anchor_to_gt_argmax.mutable_data<int>({anchor_num}, place);
gt_to_anchor_max.mutable_data<T>({gt_num}, place);
auto anchor_by_gt_overlap_et =
framework::EigenMatrix<T>::From(anchor_by_gt_overlap);
auto anchor_to_gt_max_et =
framework::EigenVector<T>::Flatten(anchor_to_gt_max);
auto gt_to_anchor_max_et =
framework::EigenVector<T>::Flatten(gt_to_anchor_max);
auto anchor_to_gt_argmax_et =
framework::EigenVector<int>::Flatten(anchor_to_gt_argmax);
// Reduce along dim 1 (gt axis): best overlap and best gt per anchor.
anchor_to_gt_max_et =
anchor_by_gt_overlap_et.maximum(Eigen::DSizes<int, 1>(1));
anchor_to_gt_argmax_et =
anchor_by_gt_overlap_et.argmax(1).template cast<int>();
// Reduce along dim 0 (anchor axis): best overlap per gt box.
gt_to_anchor_max_et =
anchor_by_gt_overlap_et.maximum(Eigen::DSizes<int, 1>(0));
// NOTE(review): the two -1 arguments and the trailing `false` presumably
// disable sub-sampling and randomness inside ScoreAssign -- confirm against
// its parameter list, which is not visible from this function.
ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max, -1,
-1, positive_overlap, negative_overlap, &fg_inds, &bg_inds,
&tgt_lbl, &fg_fake, &bbox_inside_weight, engine, false);
// Overwrite the first fg_num target labels with the label of the gt box each
// foreground anchor is matched to (its per-anchor argmax).
const int* gt_labels_data = ncrowd_gt_labels.data<int>();
int64_t fg_num = fg_inds.size();
for (int64_t i = 0; i < fg_num; ++i) {
int gt_idx = argmax[fg_inds[i]];
tgt_lbl[i] = gt_labels_data[gt_idx];
}
int bg_num = bg_inds.size();
int fg_fake_num = fg_fake.size();
// For every entry of fg_fake, record the gt box it is matched to.
gt_inds.reserve(fg_fake_num);
for (int i = 0; i < fg_fake_num; ++i) {
gt_inds.emplace_back(argmax[fg_fake[i]]);
}
// Copy the std::vector results into freshly allocated tensors.
Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t;
Tensor fg_num_t;
int* loc_index_data = loc_index_t.mutable_data<int>({fg_fake_num}, place);
int* score_index_data =
score_index_t.mutable_data<int>({fg_num + bg_num}, place);
int* tgt_lbl_data = tgt_lbl_t.mutable_data<int>({fg_num + bg_num}, place);
int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_fake_num}, place);
int* fg_num_data = fg_num_t.mutable_data<int>({1}, place);
T* bbox_inside_weight_data =
bbox_inside_weight_t.mutable_data<T>({fg_fake_num, 4}, place);
std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data);
// Score index layout: foreground indices first, background indices after.
std::copy(fg_inds.begin(), fg_inds.end(), score_index_data);
std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num);
std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data);
std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data);
std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(),
bbox_inside_weight_data);
// NOTE(review): the stored value is fg_fake.size() + 1, not the plain count;
// the +1 keeps it non-zero when there is no foreground -- confirm intent.
fg_num_data[0] = fg_fake.size() + 1;
// Pack results; callers index this vector positionally.
std::vector<Tensor> loc_score_tgtlbl_gt;
loc_score_tgtlbl_gt.emplace_back(loc_index_t);
loc_score_tgtlbl_gt.emplace_back(score_index_t);
loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t);
loc_score_tgtlbl_gt.emplace_back(gt_inds_t);
loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t);
loc_score_tgtlbl_gt.emplace_back(fg_num_t);
return loc_score_tgtlbl_gt;
}
// CPU kernel of retinanet_target_assign.
//
// For every image in the batch (delimited by the level-0 LoD of GtBoxes) it
// filters anchors and crowd ground-truth boxes, computes anchor/gt overlaps,
// derives foreground/background assignments through GetAllFgBgGt, converts
// matched boxes to regression deltas and appends the per-image results to
// the batch-level outputs. Outputs carry a one-level LoD with one sequence
// per image.
template <typename T>
class RetinanetTargetAssignKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* anchor = context.Input<Tensor>("Anchor");  // (H*W*A) * 4
    auto* gt_boxes = context.Input<LoDTensor>("GtBoxes");
    auto* gt_labels = context.Input<LoDTensor>("GtLabels");
    auto* is_crowd = context.Input<LoDTensor>("IsCrowd");
    auto* im_info = context.Input<LoDTensor>("ImInfo");

    auto* loc_index = context.Output<LoDTensor>("LocationIndex");
    auto* score_index = context.Output<LoDTensor>("ScoreIndex");
    auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
    auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
    auto* bbox_inside_weight = context.Output<LoDTensor>("BBoxInsideWeight");
    auto* fg_num = context.Output<LoDTensor>("ForegroundNumber");

    // All three LoD inputs must carry exactly one LoD level. The second
    // message previously named gt_boxes although it checks gt_labels.
    PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
                      "RetinanetTargetAssignOp gt_boxes needs 1 level of LoD");
    PADDLE_ENFORCE_EQ(gt_labels->lod().size(), 1UL,
                      "RetinanetTargetAssignOp gt_labels needs 1 level of LoD");
    PADDLE_ENFORCE_EQ(is_crowd->lod().size(), 1UL,
                      "RetinanetTargetAssignOp is_crowd needs 1 level of LoD");

    int64_t anchor_num = static_cast<int64_t>(anchor->dims()[0]);
    int64_t batch_num = static_cast<int64_t>(gt_boxes->lod().back().size() - 1);

    float positive_overlap = context.Attr<float>("positive_overlap");
    float negative_overlap = context.Attr<float>("negative_overlap");

    // Allocate outputs for the worst case (every anchor of every image);
    // they are resized to the real totals at the end.
    int64_t max_num = batch_num * anchor_num;
    auto place = context.GetPlace();

    loc_index->mutable_data<int>({max_num}, place);
    score_index->mutable_data<int>({max_num}, place);
    tgt_bbox->mutable_data<T>({max_num, 4}, place);
    tgt_lbl->mutable_data<int>({max_num, 1}, place);
    bbox_inside_weight->mutable_data<T>({max_num, 4}, place);
    fg_num->mutable_data<int>({batch_num, 1}, place);
    auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();

    // Random engine seeded from std::random_device; it is passed by value to
    // GetAllFgBgGt for every image.
    std::random_device rnd;
    std::minstd_rand engine;
    int seed = rnd();
    engine.seed(seed);

    // Running LoD offsets and element totals for the three output groups.
    framework::LoD lod_loc, loc_score, lod_fg;
    std::vector<size_t> lod0_loc(1, 0);
    std::vector<size_t> lod0_score(1, 0);
    std::vector<size_t> lod0_fg(1, 0);

    int total_loc_num = 0;
    int total_score_num = 0;
    int total_fg_num = 0;
    auto gt_boxes_lod = gt_boxes->lod().back();
    auto gt_labels_lod = gt_labels->lod().back();
    auto is_crowd_lod = is_crowd->lod().back();
    for (int i = 0; i < batch_num; ++i) {
      // Per-image views of the LoD inputs.
      Tensor gt_boxes_slice =
          gt_boxes->Slice(gt_boxes_lod[i], gt_boxes_lod[i + 1]);
      Tensor gt_labels_slice =
          gt_labels->Slice(gt_labels_lod[i], gt_labels_lod[i + 1]);
      Tensor is_crowd_slice =
          is_crowd->Slice(is_crowd_lod[i], is_crowd_lod[i + 1]);
      Tensor im_info_slice = im_info->Slice(i, i + 1);
      auto* im_info_data = im_info_slice.data<T>();
      auto im_height = im_info_data[0];
      auto im_width = im_info_data[1];
      auto im_scale = im_info_data[2];

      // Filter straddle anchor
      // NOTE(review): the -1 argument presumably is the straddle threshold
      // and disables the filtering -- confirm against FilterStraddleAnchor.
      std::vector<Tensor> filter_output =
          FilterStraddleAnchor<T>(dev_ctx, anchor, -1, im_height, im_width);
      Tensor inds_inside = filter_output[0];
      Tensor inside_anchor = filter_output[1];

      // Filter crowd gt
      std::vector<Tensor> ncrowd_output = FilterCrowdGtBoxLabel<T>(
          dev_ctx, &gt_boxes_slice, &gt_labels_slice, &is_crowd_slice);
      Tensor ncrowd_gt_boxes = ncrowd_output[0];
      Tensor ncrowd_gt_labels = ncrowd_output[1];

      // Scale the filtered gt boxes in place by the image scale factor.
      auto ncrowd_gt_boxes_et =
          framework::EigenTensor<T, 2>::From(ncrowd_gt_boxes);
      ncrowd_gt_boxes_et = ncrowd_gt_boxes_et * im_scale;

      // IoU matrix with shape [inside anchors, non-crowd gt boxes].
      Tensor anchor_by_gt_overlap;
      anchor_by_gt_overlap.mutable_data<T>(
          {inside_anchor.dims()[0], ncrowd_gt_boxes.dims()[0]}, place);
      BboxOverlaps<T>(inside_anchor, ncrowd_gt_boxes, &anchor_by_gt_overlap);

      auto loc_score_tgtlbl_gt =
          GetAllFgBgGt<T>(dev_ctx, anchor_by_gt_overlap, ncrowd_gt_labels,
                          positive_overlap, negative_overlap, engine);

      // Positional unpacking; order is fixed by GetAllFgBgGt.
      Tensor sampled_loc_index = loc_score_tgtlbl_gt[0];
      Tensor sampled_score_index = loc_score_tgtlbl_gt[1];
      Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2];
      Tensor sampled_gt_index = loc_score_tgtlbl_gt[3];
      Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4];
      Tensor sampled_fg_num = loc_score_tgtlbl_gt[5];

      int loc_num = sampled_loc_index.dims()[0];
      int score_num = sampled_score_index.dims()[0];
      // unmap to all anchor
      Tensor sampled_loc_index_unmap, sampled_score_index_unmap;
      sampled_loc_index_unmap.mutable_data<int>({loc_num}, place);
      sampled_score_index_unmap.mutable_data<int>({score_num}, place);
      Gather<int>(inds_inside.data<int>(), 1, sampled_loc_index.data<int>(),
                  loc_num, sampled_loc_index_unmap.data<int>());
      Gather<int>(inds_inside.data<int>(), 1, sampled_score_index.data<int>(),
                  score_num, sampled_score_index_unmap.data<int>());

      // get target bbox deltas
      Tensor sampled_anchor, sampled_gt, sampled_tgt_bbox;
      auto* sampled_anchor_data =
          sampled_anchor.mutable_data<T>({loc_num, 4}, place);
      auto* sampled_gt_data = sampled_gt.mutable_data<T>({loc_num, 4}, place);
      Gather<T>(anchor->data<T>(), 4, sampled_loc_index_unmap.data<int>(),
                loc_num, sampled_anchor_data);
      Gather<T>(ncrowd_gt_boxes.data<T>(), 4, sampled_gt_index.data<int>(),
                loc_num, sampled_gt_data);
      sampled_tgt_bbox.mutable_data<T>({loc_num, 4}, place);
      BoxToDelta<T>(loc_num, sampled_anchor, sampled_gt, nullptr, false,
                    &sampled_tgt_bbox);

      // Add anchor offset
      // Shifts indices by i * anchor_num so they address the flattened
      // [batch * anchors] layout.
      int anchor_offset = i * anchor_num;
      auto sampled_loc_index_unmap_et =
          framework::EigenTensor<int, 1>::From(sampled_loc_index_unmap);
      sampled_loc_index_unmap_et = sampled_loc_index_unmap_et + anchor_offset;
      auto sampled_score_index_unmap_et =
          framework::EigenTensor<int, 1>::From(sampled_score_index_unmap);
      sampled_score_index_unmap_et =
          sampled_score_index_unmap_et + anchor_offset;

      // Append this image's results at the current running offsets (element
      // offsets, hence the * 4 for the [*, 4] tensors).
      AppendRpns<int>(loc_index, total_loc_num, &sampled_loc_index_unmap);
      AppendRpns<int>(score_index, total_score_num, &sampled_score_index_unmap);
      AppendRpns<T>(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox);
      AppendRpns<int>(tgt_lbl, total_score_num, &sampled_tgtlbl);
      AppendRpns<T>(bbox_inside_weight, total_loc_num * 4,
                    &sampled_bbox_inside_weight);
      AppendRpns<int>(fg_num, total_fg_num, &sampled_fg_num);

      total_loc_num += loc_num;
      total_score_num += score_num;
      total_fg_num += 1;  // exactly one foreground-number entry per image
      lod0_loc.emplace_back(total_loc_num);
      lod0_score.emplace_back(total_score_num);
      lod0_fg.emplace_back(total_fg_num);
    }

    // Sanity checks: totals must fit in the pre-allocated buffers.
    PADDLE_ENFORCE_LE(total_loc_num, max_num);
    PADDLE_ENFORCE_LE(total_score_num, max_num);
    PADDLE_ENFORCE_LE(total_fg_num, batch_num);

    lod_loc.emplace_back(lod0_loc);
    loc_score.emplace_back(lod0_score);
    lod_fg.emplace_back(lod0_fg);
    loc_index->set_lod(lod_loc);
    score_index->set_lod(loc_score);
    tgt_bbox->set_lod(lod_loc);
    tgt_lbl->set_lod(loc_score);
    bbox_inside_weight->set_lod(lod_loc);
    fg_num->set_lod(lod_fg);

    // Shrink outputs from the worst-case allocation to the real sizes.
    loc_index->Resize({total_loc_num});
    score_index->Resize({total_score_num});
    tgt_bbox->Resize({total_loc_num, 4});
    tgt_lbl->Resize({total_score_num, 1});
    bbox_inside_weight->Resize({total_loc_num, 4});
    fg_num->Resize({total_fg_num, 1});
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -582,3 +1027,9 @@ REGISTER_OPERATOR(rpn_target_assign, ops::RpnTargetAssignOp, ...@@ -582,3 +1027,9 @@ REGISTER_OPERATOR(rpn_target_assign, ops::RpnTargetAssignOp,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(rpn_target_assign, ops::RpnTargetAssignKernel<float>, REGISTER_OP_CPU_KERNEL(rpn_target_assign, ops::RpnTargetAssignKernel<float>,
ops::RpnTargetAssignKernel<double>); ops::RpnTargetAssignKernel<double>);
// Register the retinanet_target_assign operator with its operator class and
// proto maker. EmptyGradOpMaker is supplied as the gradient maker (presumably
// meaning the op has no backward pass -- confirm against the framework).
REGISTER_OPERATOR(retinanet_target_assign, ops::RetinanetTargetAssignOp,
ops::RetinanetTargetAssignOpMaker,
paddle::framework::EmptyGradOpMaker);
// CPU kernels are instantiated for float and double element types.
REGISTER_OP_CPU_KERNEL(retinanet_target_assign,
ops::RetinanetTargetAssignKernel<float>,
ops::RetinanetTargetAssignKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
// Forward operator definition of sigmoid_focal_loss: validates inputs X,
// Label and FgNum and makes Out share the dims and LoD of X.
class SigmoidFocalLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Shape inference. Requires X and Label to have equal rank, FgNum to be
  // rank 1 and the last dimension of Label to be 1. The leading dimensions
  // of X and Label are compared unless this is a compile-time pass with a
  // non-positive element product on either tensor.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto labels_dims = ctx->GetInputDim("Label");
    auto fg_dims = ctx->GetInputDim("FgNum");
    int rank = x_dims.size();

    PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
                      "Input(X) and Input(Label) shall have the same rank.");
    PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1.");

    // At run time the comparison is always made; at compile time only when
    // both shapes have a positive element product.
    const bool compare_leading_dims =
        ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
                             framework::product(labels_dims) > 0);
    if (compare_leading_dims) {
      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                        framework::slice_ddim(labels_dims, 0, rank - 1),
                        "Input(X) and Input(Label) shall have the same shape "
                        "except the last dimension.");
    }

    // This check is unconditional, unlike the comparison above.
    PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
                      "The last dimension of input(Label) should be 1.");

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // Kernel data type follows X; the device comes from the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.device_context());
  }
};
// Backward operator definition of sigmoid_focal_loss: validates X, Label,
// FgNum and Out@GRAD and gives X@GRAD the dims of X.
class SigmoidFocalLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Shape inference. Rank checks are unconditional; the shape comparisons
  // (leading dims of X vs Label, last dim of Label, X vs Out@GRAD) are
  // skipped on a compile-time pass when either X or Label has a non-positive
  // element product.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto labels_dims = ctx->GetInputDim("Label");
    auto fg_dims = ctx->GetInputDim("FgNum");
    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    int rank = x_dims.size();

    PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
                      "Input(X) and Input(Label) shall have the same rank.");
    PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1.");

    // At run time the comparisons are always made; at compile time only when
    // both shapes have a positive element product.
    const bool compare_shapes =
        ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
                             framework::product(labels_dims) > 0);
    if (compare_shapes) {
      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                        framework::slice_ddim(labels_dims, 0, rank - 1),
                        "Input(X) and Input(Label) shall have the same shape.");

      PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
                        "The last dimension of input(Label) should be 1.");

      PADDLE_ENFORCE_EQ(
          framework::slice_ddim(x_dims, 0, rank),
          framework::slice_ddim(dout_dims, 0, rank),
          "Input(X) and Input(Out@Grad) shall have the same shape.");
    }

    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
  }

 protected:
  // Kernel data type follows X; the device comes from the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.device_context());
  }
};
// Declares the interface (inputs, output, attributes and user-facing
// description) of the sigmoid_focal_loss operator.
class SigmoidFocalLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers inputs X, Label and FgNum, output Out, and the float
  // attributes gamma (default 2.0) and alpha (default 0.25).
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape [N, D], "
             "where N is the batch size and D is the number of classes "
             "(excluding background). This input is a tensor of logits "
             "computed by the previous operator.");
    AddInput("Label",
             "(Tensor, default Tensor<int>), a 2-D tensor with shape [N, 1]. "
             "This input is a tensor of probabilistic labels.");
    AddInput("FgNum",
             "(Tensor, default Tensor<int>), a 1-D tensor with shape [1]. "
             "This input is the number of foreground.");
    AddOutput(
        "Out",
        "(Tensor, default Tensor<float>), a 2-D tensor with shape [N, D]. "
        "This output is the focal loss.");
    AddAttr<float>(
        "gamma",
        "Hyper-parameter of sigmoid focal loss op, which is to balance the "
        "easy and hard examples. "
        "A float scalar with default value 2.0.")
        .SetDefault(2.0);
    // The description previously advertised a default of 0.5 while the
    // registered default is 0.25; the text now matches SetDefault.
    AddAttr<float>(
        "alpha",
        "Hyper-parameter of sigmoid focal loss op, which is to balance the "
        "positive and negative examples. "
        "A float scalar with default value 0.25.")
        .SetDefault(0.25);
    // Long-form description attached to the operator proto. The text below
    // is part of the raw string literal and therefore runtime data.
    AddComment(R"DOC(
Sigmoid Focal Loss Operator.
Focal loss is used to address the foreground-background class imbalance existed
on the training phase of one-stage detectors. This operator computes the sigmoid
value for each element in the input tensor, after which focal loss is measured.
The focal loss is given as follows:
$$Loss_j = (-Label_j * alpha * \pow(1 - \sigma(X_j), gamma) * \log(\sigma(X_j)) -
(1 - Labels_j) * (1 - alpha) * \pow(\sigma(X_j), gamma) * \log(1 - \sigma(X_j)))
/ FgNum, j = 1,...,K$$
We know that $$\sigma(X_j) = \\frac{1}{1 + \exp(-X_j)}$$.
)DOC");
  }
};
// Builds the description of the backward op for sigmoid_focal_loss.
class SigmoidFocalLossGradOpDescMaker
    : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Creates a "sigmoid_focal_loss_grad" op description that takes the forward
  // inputs X, Label and FgNum plus Out@GRAD, produces X@GRAD and inherits the
  // forward op's attributes.
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> grad_desc(new framework::OpDesc());
    framework::OpDesc& desc = *grad_desc;

    desc.SetType("sigmoid_focal_loss_grad");

    // Forward inputs are passed through unchanged.
    desc.SetInput("X", Input("X"));
    desc.SetInput("Label", Input("Label"));
    desc.SetInput("FgNum", Input("FgNum"));

    // Gradient flowing in from Out, gradient flowing out to X.
    desc.SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    desc.SetOutput(framework::GradVarName("X"), InputGrad("X"));

    desc.SetAttrMap(Attrs());
    return grad_desc;
  }
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macros below.
namespace ops = paddle::operators;
// Forward op: operator class, proto maker and the maker that describes its
// gradient op.
REGISTER_OPERATOR(sigmoid_focal_loss, ops::SigmoidFocalLossOp,
ops::SigmoidFocalLossOpMaker,
ops::SigmoidFocalLossGradOpDescMaker);
// Backward op: only the operator class is registered.
REGISTER_OPERATOR(sigmoid_focal_loss_grad, ops::SigmoidFocalLossGradOp);
// CPU kernels for the forward op, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
sigmoid_focal_loss,
ops::SigmoidFocalLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::SigmoidFocalLossKernel<paddle::platform::CPUDeviceContext, double>);
// CPU kernels for the backward op, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
sigmoid_focal_loss_grad,
ops::SigmoidFocalLossGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SigmoidFocalLossGradKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Launch configuration shared by the sigmoid focal loss CUDA kernels.
static constexpr int kNumCUDAThreads = 512;        // threads per block
static constexpr int kNumMaxinumNumBlocks = 4096;  // upper bound on the grid

// Returns how many blocks of kNumCUDAThreads threads are needed to cover N
// elements (rounded up), capped at kNumMaxinumNumBlocks.
static inline int NumBlocks(const int N) {
  const int blocks_needed = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  if (kNumMaxinumNumBlocks < blocks_needed) {
    return kNumMaxinumNumBlocks;
  }
  return blocks_needed;
}
// Grid-stride loop header for 1-D kernels: declares `int i`, starting at the
// thread's global index (blockIdx.x * blockDim.x + threadIdx.x) and advancing
// by the total number of launched threads (blockDim.x * gridDim.x) until it
// reaches n. Use it as `CUDA_1D_KERNEL_LOOP(i, n) { ... }`.
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
// CUDA kernel computing the element-wise sigmoid focal loss.
//
// x_data      - logits, indexed as sample * num_classes + class.
// label_data  - one int label per sample; a label equal to (class + 1) marks
//               that class as positive, and a label of -1 contributes
//               neither a positive nor a negative term.
// fg_num_data - fg_num_data[0] is the normalizer; values below 1 are treated
//               as 1.
// gamma/alpha - focal loss hyper-parameters.
// limit       - total number of elements to process.
// out_data    - receives one loss value per element.
template <typename T>
__global__ void GPUSigmoidFocalLossForward(const T *x_data,
                                           const int *label_data,
                                           const int *fg_num_data,
                                           const T gamma, const T alpha,
                                           const int num_classes,
                                           const int limit, T *out_data) {
  CUDA_1D_KERNEL_LOOP(idx, limit) {
    T x = x_data[idx];
    const int sample = idx / num_classes;
    const int cls = idx % num_classes;
    const int target = label_data[sample];

    // Labels are 1-based while the class index is 0-based: the element is a
    // positive when target == cls + 1 and a negative for any other label
    // except -1.
    T c_pos = static_cast<T>(target == (cls + 1));
    T c_neg = static_cast<T>((target != -1) && (target != (cls + 1)));

    // Normalizer clamped to at least 1, folded into the alpha weights.
    T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
    T s_pos = alpha / fg_num;
    T s_neg = (1.0 - alpha) / fg_num;

    // p = 1 / (1 + exp(-x))
    T p = 1. / (1. + real_exp(-x));

    // (1 - p)**gamma * log(p), with p clamped from below by FLT_MIN.
    T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
                 real_log(p > FLT_MIN ? p : FLT_MIN);
    // p**gamma * log(1 - p); the log is written as
    // -x * (x >= 0) - log(1 + exp(x - 2 * x * (x >= 0))).
    T term_neg =
        std::pow(p, gamma) *
        (-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0))));

    // Accumulate locally and store the result once.
    T loss = 0.0;
    loss += -c_pos * term_pos * s_pos;
    loss += -c_neg * term_neg * s_neg;
    out_data[idx] = loss;
  }
}
// CUDA kernel computing the gradient of the sigmoid focal loss with respect
// to the logits.
//
// x_data      - logits, indexed as sample * num_classes + class.
// label_data  - one int label per sample (same convention as the forward
//               kernel: label == class + 1 is positive, -1 contributes
//               nothing).
// fg_num_data - fg_num_data[0] is the normalizer; values below 1 are treated
//               as 1.
// gamma/alpha - focal loss hyper-parameters.
// dout_data   - upstream gradient, one value per element.
// limit       - total number of elements to process.
// dx_data     - receives the gradient for each element.
template <typename T>
__global__ void GPUSigmoidFocalLossBackward(
    const T *x_data, const int *label_data, const int *fg_num_data,
    const T gamma, const T alpha, const int num_classes, const T *dout_data,
    const int limit, T *dx_data) {
  CUDA_1D_KERNEL_LOOP(idx, limit) {
    T x = x_data[idx];
    T dout = dout_data[idx];
    const int sample = idx / num_classes;
    const int cls = idx % num_classes;

    // Normalizer clamped to at least 1, folded into the alpha weights.
    T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
    T s_pos = alpha / fg_num;
    T s_neg = (1.0 - alpha) / fg_num;

    // Positive when the 1-based label matches this class; negative for any
    // other label except -1.
    const int target = label_data[sample];
    T c_pos = static_cast<T>(target == (cls + 1));
    T c_neg = static_cast<T>((target != -1) && (target != (cls + 1)));

    T p = 1. / (1. + real_exp(-x));

    // (1-p)**g * (1 - p - g*p*log(p))
    T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
                 (1. - p - (p * gamma * real_log(p > FLT_MIN ? p : FLT_MIN)));
    // (p**g) * (g*(1-p)*log(1-p) - p)
    T term_neg =
        std::pow(p, gamma) *
        ((-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0)))) *
             (1. - p) * gamma -
         p);

    // Accumulate locally, scale by the upstream gradient and store once.
    T grad = 0.0;
    grad += -c_pos * s_pos * term_pos;
    grad += -c_neg * s_neg * term_neg;
    grad = grad * dout;
    dx_data[idx] = grad;
  }
}
// CUDA op kernel for the forward pass of sigmoid_focal_loss: reads X, Label
// and FgNum, allocates Out and launches GPUSigmoidFocalLossForward on the
// context's stream.
template <typename DeviceContext, typename T>
class GPUSigmoidFocalLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *logits = context.Input<Tensor>("X");
    const Tensor *targets = context.Input<Tensor>("Label");
    const Tensor *fg_count = context.Input<Tensor>("FgNum");
    Tensor *loss = context.Output<Tensor>("Out");

    // Attributes are stored as float and converted to the kernel's type.
    const T gamma = static_cast<T>(context.Attr<float>("gamma"));
    const T alpha = static_cast<T>(context.Attr<float>("alpha"));

    // The class count is read from the second dimension of X.
    const int num_classes = static_cast<int>(logits->dims()[1]);

    T *loss_ptr = loss->mutable_data<T>(context.GetPlace());
    auto &dev_ctx = context.cuda_device_context();

    // One element per output value; grid size is capped by NumBlocks.
    const int limit = loss->numel();
    const int grid = NumBlocks(limit);
    const int block = kNumCUDAThreads;

    GPUSigmoidFocalLossForward<T><<<grid, block, 0, dev_ctx.stream()>>>(
        logits->data<T>(), targets->data<int>(), fg_count->data<int>(), gamma,
        alpha, num_classes, limit, loss_ptr);
  }
};
// Backward operator kernel for `sigmoid_focal_loss` on CUDA devices.
//
// Reads "X", "Label", "FgNum" and the gradient of "Out", allocates the
// gradient of "X" on the execution place and launches
// GPUSigmoidFocalLossBackward over every element of that gradient on the
// device context's stream.
template <typename DeviceContext, typename T>
class GPUSigmoidFocalLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *logits = context.Input<Tensor>("X");
    const Tensor *labels = context.Input<Tensor>("Label");
    const Tensor *fg_count = context.Input<Tensor>("FgNum");
    const Tensor *loss_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor *logits_grad = context.Output<Tensor>(framework::GradVarName("X"));
    T *logits_grad_data = logits_grad->mutable_data<T>(context.GetPlace());

    const T gamma = static_cast<T>(context.Attr<float>("gamma"));
    const T alpha = static_cast<T>(context.Attr<float>("alpha"));

    // The second dimension of X is the number of classes.
    auto logits_dims = logits->dims();
    const int num_classes = static_cast<int>(logits_dims[1]);

    auto &dev_ctx = context.cuda_device_context();

    // One kernel index per gradient element.
    const int total = logits_grad->numel();
    const int blocks = NumBlocks(total);
    const int threads = kNumCUDAThreads;
    GPUSigmoidFocalLossBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        logits->data<T>(), labels->data<int>(), fg_count->data<int>(), gamma,
        alpha, num_classes, loss_grad->data<T>(), total, logits_grad_data);
  }
};
} // namespace operators
} // namespace paddle
// Short alias for the operator namespace, used by the registrations below.
namespace ops = paddle::operators;
// Register the CUDA forward kernel of `sigmoid_focal_loss` for float and
// double element types.
REGISTER_OP_CUDA_KERNEL(
sigmoid_focal_loss,
ops::GPUSigmoidFocalLossKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUSigmoidFocalLossKernel<paddle::platform::CUDADeviceContext,
double>);
// Register the CUDA backward kernel (`sigmoid_focal_loss_grad`) for the same
// two element types.
REGISTER_OP_CUDA_KERNEL(
sigmoid_focal_loss_grad,
ops::GPUSigmoidFocalLossGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::GPUSigmoidFocalLossGradKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <limits>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// CPU forward kernel for `sigmoid_focal_loss`.
//
// For every (sample, class) entry of the logits "X" it writes
//   -c_pos * (1 - p)^gamma * log(p)     * alpha       / fg_num
//   -c_neg * p^gamma       * log(1 - p) * (1 - alpha) / fg_num
// to "Out", where p = sigmoid(x), c_pos is 1 when the sample's label equals
// class + 1, c_neg is 1 when the label is neither -1 nor class + 1, and
// fg_num is element 0 of "FgNum" clamped below at 1.
template <typename DeviceContext, typename T>
class SigmoidFocalLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *logits = context.Input<Tensor>("X");
    const Tensor *labels = context.Input<Tensor>("Label");
    const Tensor *fg_count = context.Input<Tensor>("FgNum");
    Tensor *loss = context.Output<Tensor>("Out");

    const T gamma = static_cast<T>(context.Attr<float>("gamma"));
    const T alpha = static_cast<T>(context.Attr<float>("alpha"));

    T *loss_data = loss->mutable_data<T>(context.GetPlace());
    const int total = loss->numel();

    const T *logit_data = logits->data<T>();
    const int *label_data = labels->data<int>();
    const int *fg_num_data = fg_count->data<int>();

    // The second dimension of X is the number of classes.
    auto logits_dims = logits->dims();
    const int num_classes = static_cast<int>(logits_dims[1]);

    for (int i = 0; i < total; ++i) {
      const T x = logit_data[i];
      const int sample_idx = i / num_classes;
      const int class_idx = i % num_classes;
      const int target = label_data[sample_idx];

      // Labels are 1-based class ids while class_idx is 0-based, hence the
      // comparison against class_idx + 1. A label of -1 matches neither
      // indicator, so such samples contribute zero loss.
      const T c_pos = static_cast<T>(target == (class_idx + 1));
      const T c_neg =
          static_cast<T>((target != -1) & (target != (class_idx + 1)));

      // Normalisation by the foreground count, clamped below at 1.
      const T fg_num =
          static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
      const T s_pos = alpha / fg_num;
      const T s_neg = (1.0 - alpha) / fg_num;

      // p = 1 / (1 + exp(-x))
      const T p = 1. / (1. + std::exp(-x));

      // (1 - p)^gamma * log(p), with p clamped at FLT_MIN inside the log.
      const T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
                         std::log(p > FLT_MIN ? p : FLT_MIN);

      // p^gamma * log(1 - p), where log(1 - p) is evaluated as
      // -max(x, 0) - log(1 + exp(-|x|)).
      const T term_neg =
          std::pow(p, gamma) *
          (-1. * x * (x >= 0) - std::log(1. + std::exp(x - 2. * x * (x >= 0))));

      T value = 0.0;
      value += -c_pos * term_pos * s_pos;
      value += -c_neg * term_neg * s_neg;
      loss_data[i] = value;
    }
  }
};
// CPU backward kernel for `sigmoid_focal_loss`.
//
// For every (sample, class) entry it writes to the gradient of "X" the
// derivative of the focal loss with respect to the logit, multiplied by the
// matching entry of the gradient of "Out". Label and FgNum handling is the
// same as in the forward kernel: label == class + 1 is positive, label == -1
// contributes nothing, and FgNum[0] is clamped below at 1.
template <typename DeviceContext, typename T>
class SigmoidFocalLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *logits = context.Input<Tensor>("X");
    const Tensor *labels = context.Input<Tensor>("Label");
    const Tensor *fg_count = context.Input<Tensor>("FgNum");
    const Tensor *loss_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor *logits_grad = context.Output<Tensor>(framework::GradVarName("X"));
    T *grad_data = logits_grad->mutable_data<T>(context.GetPlace());

    const T gamma = static_cast<T>(context.Attr<float>("gamma"));
    const T alpha = static_cast<T>(context.Attr<float>("alpha"));

    // The second dimension of X is the number of classes.
    auto logits_dims = logits->dims();
    const int num_classes = static_cast<int>(logits_dims[1]);
    const int total = logits_grad->numel();

    const T *logit_data = logits->data<T>();
    const int *label_data = labels->data<int>();
    const int *fg_num_data = fg_count->data<int>();
    const T *upstream_data = loss_grad->data<T>();

    for (int i = 0; i < total; ++i) {
      const T x = logit_data[i];
      const int sample_idx = i / num_classes;
      const int class_idx = i % num_classes;
      const int target = label_data[sample_idx];

      // Normalisation by the foreground count, clamped below at 1.
      const T fg_num =
          static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
      const T s_neg = static_cast<T>((1.0 - alpha) / fg_num);
      const T s_pos = alpha / fg_num;

      // Indicators: positive entry / negative entry (label -1 is neither).
      const T c_pos = static_cast<T>(target == (class_idx + 1));
      const T c_neg =
          static_cast<T>((target != -1) & (target != (class_idx + 1)));

      // p = sigmoid(x)
      const T p = 1. / (1. + std::exp(-x));

      // Positive part: (1 - p)^gamma * (1 - p - gamma * p * log(p)),
      // with p clamped at FLT_MIN inside the logarithm.
      const T term_pos =
          std::pow(static_cast<T>(1. - p), gamma) *
          (1. - p - (p * gamma * std::log(p > FLT_MIN ? p : FLT_MIN)));

      // Negative part: p^gamma * (gamma * (1 - p) * log(1 - p) - p), where
      // log(1 - p) is evaluated as -max(x, 0) - log(1 + exp(-|x|)).
      const T term_neg =
          std::pow(p, gamma) *
          ((-1. * x * (x >= 0) -
            std::log(1. + std::exp(x - 2. * x * (x >= 0)))) *
               (1. - p) * gamma -
           p);

      T grad = 0.0;
      grad += -c_pos * s_pos * term_pos;
      grad += -c_neg * s_neg * term_neg;
      grad_data[i] = grad * upstream_data[i];
    }
  }
};
} // namespace operators
} // namespace paddle
...@@ -40,6 +40,8 @@ __all__ = [ ...@@ -40,6 +40,8 @@ __all__ = [
'ssd_loss', 'ssd_loss',
'detection_map', 'detection_map',
'rpn_target_assign', 'rpn_target_assign',
'retinanet_target_assign',
'sigmoid_focal_loss',
'anchor_generator', 'anchor_generator',
'roi_perspective_transform', 'roi_perspective_transform',
'generate_proposal_labels', 'generate_proposal_labels',
...@@ -52,12 +54,171 @@ __all__ = [ ...@@ -52,12 +54,171 @@ __all__ = [
'yolo_box', 'yolo_box',
'box_clip', 'box_clip',
'multiclass_nms', 'multiclass_nms',
'retinanet_detection_output',
'distribute_fpn_proposals', 'distribute_fpn_proposals',
'box_decoder_and_assign', 'box_decoder_and_assign',
'collect_fpn_proposals', 'collect_fpn_proposals',
] ]
def retinanet_target_assign(bbox_pred,
                            cls_logits,
                            anchor_box,
                            anchor_var,
                            gt_boxes,
                            gt_labels,
                            is_crowd,
                            im_info,
                            num_classes=1,
                            positive_overlap=0.5,
                            negative_overlap=0.4):
    """
    **Target Assign Layer for RetinaNet.**

    Based on the Intersection-over-Union (IoU) overlap between anchors and
    ground-truth boxes, this layer assigns classification and regression
    targets to each anchor. These targets are used to train RetinaNet. Every
    anchor receives a one-hot classification target of length
    :attr:`num_classes` and a 4-vector of box regression targets. The
    assignment rules are:

    1. An anchor is assigned to a ground-truth box when (i) it has the
       highest IoU overlap with that ground-truth box, or (ii) its IoU
       overlap with any ground-truth box is higher than
       :attr:`positive_overlap` (0.5 by default).

    2. An anchor is assigned to background when its IoU overlap is lower
       than :attr:`negative_overlap` (0.4 by default) for all ground-truth
       boxes.

    When an anchor is assigned to a ground-truth box of the i-th category,
    the i-th entry of its target vector is set to 1 and all other entries
    are set to 0. When an anchor is assigned to background, all entries are
    set to 0. Anchors that are not assigned do not contribute to the
    training objective. The regression targets are the encoded ground-truth
    boxes associated with the assigned anchors.

    Args:
        bbox_pred(Variable): A 3-D Tensor with shape [N, M, 4] holding the
            predicted locations of M bounding boxes. N is the batch size and
            each box has the layout [xmin, ymin, xmax, ymax].
        cls_logits(Variable): A 3-D Tensor with shape [N, M, C] holding the
            predicted confidences. N is the batch size, C is the number of
            classes (excluding background) and M is the number of bounding
            boxes.
        anchor_box(Variable): A 2-D Tensor with shape [M, 4] holding M
            anchor boxes, each represented as [xmin, ymin, xmax, ymax].
            [xmin, ymin] is the left-top coordinate and [xmax, ymax] is the
            right-bottom coordinate of the anchor box.
        anchor_var(Variable): A 2-D Tensor with shape [M, 4] holding the
            expanded variances of the anchors. This function does not pass
            it to the operator.
        gt_boxes(Variable): A 2-D LoDTensor with shape [Ng, 4] holding the
            ground-truth bounding boxes. Ng is the total number of
            ground-truth boxes in the mini-batch.
        gt_labels(Variable): A 2-D LoDTensor with shape [Ng, 1] holding the
            ground-truth labels. Ng is the total number of ground-truth
            labels in the mini-batch.
        is_crowd(Variable): A 1-D LoDTensor indicating whether each
            ground-truth box is a crowd.
        im_info(Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch
            size and the 3 values are height, width and scale.
        num_classes(int32): The number of classes. It is used as the last
            dimension when flattening :attr:`cls_logits`.
        positive_overlap(float): Minimum overlap required between an anchor
            and a ground-truth box for the pair to be a positive example.
        negative_overlap(float): Maximum overlap allowed between an anchor
            and a ground-truth box for the pair to be a negative example.

    Returns:
        tuple:
            A tuple (predicted_scores, predicted_location, target_label,
            target_bbox, bbox_inside_weight, fg_num), in that order.
            predicted_scores and predicted_location are the predictions
            gathered for the selected anchors; target_label and target_bbox
            are the corresponding ground-truth targets. predicted_location
            is a 2-D Tensor with shape [F, 4] and target_bbox has the same
            shape, where F is the number of foreground anchors.
            predicted_scores is a 2-D Tensor with shape [F + B, C] and
            target_label has shape [F + B, 1], where B is the number of
            background anchors. F and B depend on the input of this
            operator. bbox_inside_weight has shape [F, 4] and marks whether
            a predicted location is a fake foreground or not. fg_num is the
            number of foreground anchors (including fake foreground), which
            is needed by the focal loss.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          bbox_pred = fluid.layers.data(name='bbox_pred', shape=[1, 100, 4],
                            append_batch_size=False, dtype='float32')
          cls_logits = fluid.layers.data(name='cls_logits', shape=[1, 100, 10],
                            append_batch_size=False, dtype='float32')
          anchor_box = fluid.layers.data(name='anchor_box', shape=[100, 4],
                            append_batch_size=False, dtype='float32')
          anchor_var = fluid.layers.data(name='anchor_var', shape=[100, 4],
                            append_batch_size=False, dtype='float32')
          gt_boxes = fluid.layers.data(name='gt_boxes', shape=[10, 4],
                            append_batch_size=False, dtype='float32')
          gt_labels = fluid.layers.data(name='gt_labels', shape=[10, 1],
                            append_batch_size=False, dtype='float32')
          is_crowd = fluid.layers.data(name='is_crowd', shape=[1],
                            append_batch_size=False, dtype='float32')
          im_info = fluid.layers.data(name='im_info', shape=[1, 3],
                            append_batch_size=False, dtype='float32')
          (score_pred, loc_pred, score_target, loc_target,
           bbox_inside_weight, fg_num) = fluid.layers.retinanet_target_assign(
              bbox_pred, cls_logits, anchor_box, anchor_var, gt_boxes,
              gt_labels, is_crowd, im_info, 10)
    """
    # Must stay the first statement: **locals() forwards exactly the call
    # arguments, so no other local may exist yet.
    helper = LayerHelper('retinanet_target_assign', **locals())

    def _new_output(dtype):
        return helper.create_variable_for_type_inference(dtype=dtype)

    # Output variables are created in a fixed order.
    location_index = _new_output('int32')
    score_index = _new_output('int32')
    label_target = _new_output('int32')
    bbox_target = _new_output(anchor_box.dtype)
    inside_weight = _new_output(anchor_box.dtype)
    foreground_num = _new_output('int32')

    op_inputs = {
        'Anchor': anchor_box,
        'GtBoxes': gt_boxes,
        'GtLabels': gt_labels,
        'IsCrowd': is_crowd,
        'ImInfo': im_info
    }
    op_outputs = {
        'LocationIndex': location_index,
        'ScoreIndex': score_index,
        'TargetLabel': label_target,
        'TargetBBox': bbox_target,
        'BBoxInsideWeight': inside_weight,
        'ForegroundNumber': foreground_num
    }
    op_attrs = {
        'positive_overlap': positive_overlap,
        'negative_overlap': negative_overlap
    }
    helper.append_op(
        type="retinanet_target_assign",
        inputs=op_inputs,
        outputs=op_outputs,
        attrs=op_attrs)

    # None of the assigned targets take part in back-propagation.
    for target_var in (location_index, score_index, label_target,
                       bbox_target, inside_weight, foreground_num):
        target_var.stop_gradient = True

    # Flatten the predictions and keep only the rows selected by the op.
    flat_logits = nn.reshape(x=cls_logits, shape=(-1, num_classes))
    flat_bbox = nn.reshape(x=bbox_pred, shape=(-1, 4))
    selected_logits = nn.gather(flat_logits, score_index)
    selected_bbox = nn.gather(flat_bbox, location_index)

    return (selected_logits, selected_bbox, label_target, bbox_target,
            inside_weight, foreground_num)
def rpn_target_assign(bbox_pred, def rpn_target_assign(bbox_pred,
cls_logits, cls_logits,
anchor_box, anchor_box,
...@@ -210,6 +371,74 @@ def rpn_target_assign(bbox_pred, ...@@ -210,6 +371,74 @@ def rpn_target_assign(bbox_pred,
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight
def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25):
    """
    **Sigmoid Focal Loss Operator.**

    Focal loss addresses the foreground-background class imbalance that
    occurs while training one-stage detectors. This operator applies the
    sigmoid function to each element of the input tensor and then measures
    the focal loss.

    The focal loss is given as follows:

    .. math::

        loss_j = (-label_j * alpha * {(1 - \\sigma(x_j))}^{gamma} * \\log(\\sigma(x_j)) -
        (1 - label_j) * (1 - alpha) * {(\\sigma(x_j))}^{gamma} * \\log(1 - \\sigma(x_j)))
        / fg\\_num, j = 1,...,K

    where

    .. math::

        \\sigma(x_j) = \\frac{1}{1 + \\exp(-x_j)}

    Args:
        x(Variable): A 2-D tensor with shape [N, D], where N is the batch
            size and D is the number of classes (excluding background). It
            holds the logits computed by the previous operator.
        label(Variable): A 2-D tensor with shape [N, 1] holding the labels.
        fg_num(Variable): A 1-D tensor with shape [1] holding the number of
            foreground samples.
        gamma(float): Hyper-parameter balancing easy and hard examples.
            Default value is 2.0.
        alpha(float): Hyper-parameter balancing positive and negative
            examples. Default value is 0.25.

    Returns:
        out(Variable): A 2-D tensor with shape [N, D], which is the focal
        loss.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            input = fluid.layers.data(
                name='data', shape=[10,80], append_batch_size=False, dtype='float32')
            label = fluid.layers.data(
                name='label', shape=[10,1], append_batch_size=False, dtype='int32')
            fg_num = fluid.layers.data(
                name='fg_num', shape=[1], append_batch_size=False, dtype='int32')
            loss = fluid.layers.sigmoid_focal_loss(x=input,
                                                   label=label,
                                                   fg_num=fg_num,
                                                   gamma=2.,
                                                   alpha=0.25)
    """
    # Must stay the first statement: **locals() forwards exactly the call
    # arguments, so no other local may exist yet.
    helper = LayerHelper("sigmoid_focal_loss", **locals())

    # The loss has the same dtype as the logits.
    loss = helper.create_variable_for_type_inference(dtype=x.dtype)

    op_inputs = {"X": x, "Label": label, "FgNum": fg_num}
    op_attrs = {"gamma": gamma, 'alpha': alpha}
    helper.append_op(
        type="sigmoid_focal_loss",
        inputs=op_inputs,
        attrs=op_attrs,
        outputs={"Out": loss})
    return loss
def detection_output(loc, def detection_output(loc,
scores, scores,
prior_box, prior_box,
...@@ -2320,6 +2549,113 @@ def box_clip(input, im_info, name=None): ...@@ -2320,6 +2549,113 @@ def box_clip(input, im_info, name=None):
return output return output
def retinanet_detection_output(bboxes,
                               scores,
                               anchors,
                               im_info,
                               score_threshold=0.05,
                               nms_top_k=1000,
                               keep_top_k=100,
                               nms_threshold=0.3,
                               nms_eta=1.):
    """
    **Detection Output Layer for RetinaNet.**

    This operation produces the detection results in two steps:

    1. Decode the top-scoring bounding box predictions of each FPN level
       according to the anchor boxes.

    2. Merge the top predictions from all levels and apply multi-class
       non-maximum suppression (NMS) to them to get the final detections.

    Args:
        bboxes(List): A list of tensors from multiple FPN levels. Each
            element is a 3-D Tensor with shape [N, Mi, 4] holding the
            predicted locations of Mi bounding boxes. N is the batch size,
            Mi is the number of bounding boxes from the i-th FPN level and
            each box has the layout [xmin, ymin, xmax, ymax].
        scores(List): A list of tensors from multiple FPN levels. Each
            element is a 3-D Tensor with shape [N, Mi, C] holding the
            predicted confidences. N is the batch size, C is the class
            number (excluding background) and Mi is the number of bounding
            boxes from the i-th FPN level. Each bounding box has C scores.
        anchors(List): A list of tensors from multiple FPN levels. Each
            element is a 2-D Tensor with shape [Mi, 4] holding the locations
            of the Mi anchor boxes of the i-th FPN level, each with the
            layout [xmin, ymin, xmax, ymax].
        im_info(Variable): A 2-D LoDTensor with shape [N, 3] holding the
            image information. N is the batch size; each row contains
            height, width and scale.
        score_threshold(float): Threshold used to filter out bounding boxes
            by confidence score.
        nms_top_k(int): Maximum number of detections per FPN level to be
            kept according to the confidences before NMS.
        keep_top_k(int): Number of total bounding boxes to be kept per image
            after the NMS step. -1 means keeping all bounding boxes after
            the NMS step.
        nms_threshold(float): The threshold to be used in NMS.
        nms_eta(float): The parameter for adaptive NMS. It is forwarded to
            the operator as the `nms_eta` attribute.

    Returns:
        Variable:
            The detection output is a LoDTensor with shape [No, 6].
            Each row has six values: [label, confidence, xmin, ymin, xmax, ymax].
            `No` is the total number of detections in this mini-batch. For
            each instance, the offsets in the first dimension are called
            LoD; the number of offsets is N + 1, where N is the batch size.
            The i-th image has `LoD[i + 1] - LoD[i]` detected results; if it
            is 0, the i-th image has no detected results. If all images have
            no detected results, LoD will be set to 0, and the output tensor
            is empty (None).

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            bboxes = fluid.layers.data(name='bboxes', shape=[1, 21, 4],
                append_batch_size=False, dtype='float32')
            scores = fluid.layers.data(name='scores', shape=[1, 21, 10],
                append_batch_size=False, dtype='float32')
            anchors = fluid.layers.data(name='anchors', shape=[21, 4],
                append_batch_size=False, dtype='float32')
            im_info = fluid.layers.data(name="im_info", shape=[1, 3],
                append_batch_size=False, dtype='float32')
            nmsed_outs = fluid.layers.retinanet_detection_output(
                                                    bboxes=[bboxes, bboxes],
                                                    scores=[scores, scores],
                                                    anchors=[anchors, anchors],
                                                    im_info=im_info,
                                                    score_threshold=0.05,
                                                    nms_top_k=1000,
                                                    keep_top_k=100,
                                                    nms_threshold=0.3,
                                                    nms_eta=1.)
    """
    # Must stay the first statement: **locals() forwards exactly the call
    # arguments, so no other local may exist yet.
    helper = LayerHelper('retinanet_detection_output', **locals())
    output = helper.create_variable_for_type_inference(
        dtype=helper.input_dtype('scores'))
    helper.append_op(
        type="retinanet_detection_output",
        inputs={
            'BBoxes': bboxes,
            'Scores': scores,
            'Anchors': anchors,
            'ImInfo': im_info
        },
        attrs={
            'score_threshold': score_threshold,
            'nms_top_k': nms_top_k,
            'nms_threshold': nms_threshold,
            'keep_top_k': keep_top_k,
            # Forward the caller's value; it used to be hard-coded to 1.,
            # which silently ignored the `nms_eta` argument.
            'nms_eta': nms_eta,
        },
        outputs={'Out': output})
    # Detections are a post-processing result and carry no gradient.
    output.stop_gradient = True
    return output
def multiclass_nms(bboxes, def multiclass_nms(bboxes,
scores, scores,
score_threshold, score_threshold,
......
...@@ -2018,6 +2018,110 @@ class TestBook(LayerTest): ...@@ -2018,6 +2018,110 @@ class TestBook(LayerTest):
trans_std=0.1) trans_std=0.1)
return (out) return (out)
def test_retinanet_target_assign(self):
    """Build `retinanet_target_assign` in the default programs.

    All inputs are float32 data layers declared without an implicit batch
    dimension; the layer's output tuple is returned.
    """
    with program_guard(fluid.default_main_program(),
                       fluid.default_startup_program()):

        def float_input(name, shape):
            return layers.data(
                name=name,
                shape=shape,
                append_batch_size=False,
                dtype='float32')

        bbox_pred = float_input('bbox_pred', [1, 100, 4])
        cls_logits = float_input('cls_logits', [1, 100, 10])
        anchor_box = float_input('anchor_box', [100, 4])
        anchor_var = float_input('anchor_var', [100, 4])
        gt_boxes = float_input('gt_boxes', [10, 4])
        gt_labels = float_input('gt_labels', [10, 1])
        is_crowd = float_input('is_crowd', [1])
        im_info = float_input('im_info', [1, 3])

        # The trailing 10 is num_classes, matching cls_logits' last dim.
        targets = layers.retinanet_target_assign(
            bbox_pred, cls_logits, anchor_box, anchor_var, gt_boxes,
            gt_labels, is_crowd, im_info, 10)
        return targets
def test_sigmoid_focal_loss(self):
    """Build `sigmoid_focal_loss` in the default programs and return it.

    The logits are float32 with shape [10, 80]; the labels ([10, 1]) and the
    foreground count ([1]) are int32.
    """
    with program_guard(fluid.default_main_program(),
                       fluid.default_startup_program()):
        logits = layers.data(
            name='data',
            shape=[10, 80],
            append_batch_size=False,
            dtype='float32')
        labels = layers.data(
            name='label',
            shape=[10, 1],
            append_batch_size=False,
            dtype='int32')
        foreground_num = layers.data(
            name='fg_num',
            shape=[1],
            append_batch_size=False,
            dtype='int32')
        loss = fluid.layers.sigmoid_focal_loss(
            x=logits,
            label=labels,
            fg_num=foreground_num,
            gamma=2.,
            alpha=0.25)
        return loss
def test_retinanet_detection_output(self):
    """Build `retinanet_detection_output` in the default programs.

    The same box, score and anchor variables are passed twice to stand in
    for two FPN levels; the layer's output is returned.
    """
    with program_guard(fluid.default_main_program(),
                       fluid.default_startup_program()):

        def float_input(name, shape):
            return layers.data(
                name=name,
                shape=shape,
                append_batch_size=False,
                dtype='float32')

        bboxes = float_input('bboxes', [1, 21, 4])
        scores = float_input('scores', [1, 21, 10])
        anchors = float_input('anchors', [21, 4])
        im_info = float_input("im_info", [1, 3])

        detections = layers.retinanet_detection_output(
            bboxes=[bboxes, bboxes],
            scores=[scores, scores],
            anchors=[anchors, anchors],
            im_info=im_info,
            score_threshold=0.05,
            nms_top_k=1000,
            keep_top_k=100,
            nms_threshold=0.3,
            nms_eta=1.)
        return detections
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License")
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import copy
from op_test import OpTest
from test_anchor_generator_op import anchor_generator_in_python
from test_multiclass_nms_op import iou
from test_multiclass_nms_op import nms
def multiclass_nms(prediction, class_num, keep_top_k, nms_threshold):
    """Reference multi-class NMS over decoded per-class detections.

    Args:
        prediction: dict mapping a class index to a list of detections, each
            laid out as [xmin, ymin, xmax, ymax, score].
        class_num: only class indices in ``range(class_num)`` are processed.
        keep_top_k: when greater than -1, at most this many detections are
            kept overall (highest scores first).
        nms_threshold: threshold passed on to ``nms``.

    Returns:
        A pair ``(nmsed_outs, num_det)``. ``nmsed_outs`` is a list of
        ``[class + 1, score, xmin, ymin, xmax, ymax]`` rows ordered by
        descending score and ``num_det`` is the number of kept detections.
    """
    kept_per_class = {}
    num_det = 0
    for label in range(class_num):
        if label not in prediction:
            continue
        dets = prediction[label]
        det_scores = np.zeros(len(dets))
        for pos, det in enumerate(dets):
            det_scores[pos] = det[4]
        kept = nms(dets, det_scores, 0.0, nms_threshold, -1, False, 1.0)
        kept_per_class[label] = kept
        num_det += len(kept)

    # Rank all survivors by score; list.sort is stable, so ties keep the
    # order in which they were collected.
    ranked = [(prediction[label][idx][4], label, idx)
              for label, kept in kept_per_class.items() for idx in kept]
    ranked.sort(key=lambda entry: entry[0], reverse=True)

    if keep_top_k > -1 and num_det > keep_top_k:
        ranked = ranked[:keep_top_k]
        num_det = keep_top_k

    nmsed_outs = []
    for score, label, idx in ranked:
        xmin, ymin, xmax, ymax = prediction[label][idx][:4]
        # Class indices are reported 1-based.
        nmsed_outs.append([label + 1, score, xmin, ymin, xmax, ymax])
    return nmsed_outs, num_det
def retinanet_detection_out(boxes_list, scores_list, anchors_list, im_info,
                            score_threshold, nms_threshold, nms_top_k,
                            keep_top_k):
    """Reference RetinaNet detection output for a single image.

    For every FPN level the scores above the level's threshold are ranked,
    the best ``nms_top_k`` are decoded against their anchors, rescaled by
    the image scale and clipped to the image; the pooled per-class boxes are
    then passed to ``multiclass_nms``.

    Args:
        boxes_list: per-level arrays of box deltas (4 values per anchor).
        scores_list: per-level arrays of scores whose last dimension is the
            class count.
        anchors_list: per-level arrays of anchors (4 values per anchor).
        im_info: ``(im_height, im_width, im_scale)`` of the image.
        score_threshold: score threshold for every level except the last,
            which uses 0.0.
        nms_threshold: threshold passed on to ``multiclass_nms``.
        nms_top_k: per-level cap on ranked candidates; values <= -1 disable
            the cap.
        keep_top_k: passed on to ``multiclass_nms``.

    Returns:
        The ``(nmsed_outs, nmsed_num)`` pair produced by ``multiclass_nms``.
    """
    class_num = scores_list[0].shape[-1]
    im_height, im_width, im_scale = im_info
    num_level = len(scores_list)
    prediction = {}
    for lvl in range(num_level):
        flat_scores = scores_list[lvl].flatten()
        flat_deltas = boxes_list[lvl].flatten()
        flat_anchors = anchors_list[lvl].flatten()

        # The last level is filtered with a threshold of 0.0.
        if lvl < (num_level - 1):
            thresh = score_threshold
        else:
            thresh = 0.0
        candidates = np.argwhere(flat_scores > thresh)
        # Stable descending sort of the candidate scores.
        order = np.argsort(
            -flat_scores[candidates], axis=0, kind='mergesort')
        if nms_top_k > -1 and nms_top_k < order.shape[0]:
            order = order[:nms_top_k]

        for rank in order:
            # Flat score index -> (anchor index, class index).
            idx = candidates[rank][0][0]
            a = int(idx / class_num)
            c = int(idx % class_num)
            off = a * 4

            # Anchor size and centre (sizes include the +1 pixel).
            anchor_w = flat_anchors[off + 2] - flat_anchors[off] + 1
            anchor_h = flat_anchors[off + 3] - flat_anchors[off + 1] + 1
            anchor_cx = flat_anchors[off] + anchor_w / 2
            anchor_cy = flat_anchors[off + 1] + anchor_h / 2

            # Apply the predicted deltas to the anchor.
            box_cx = flat_deltas[off] * anchor_w + anchor_cx
            box_cy = flat_deltas[off + 1] * anchor_h + anchor_cy
            box_w = math.exp(flat_deltas[off + 2]) * anchor_w
            box_h = math.exp(flat_deltas[off + 3]) * anchor_h

            # Corner form, then divide by the image scale.
            xmin = (box_cx - box_w / 2) / im_scale
            ymin = (box_cy - box_h / 2) / im_scale
            xmax = (box_cx + box_w / 2 - 1) / im_scale
            ymax = (box_cy + box_h / 2 - 1) / im_scale

            # Clip to [0, round(size / scale) - 1].
            x_limit = np.round(im_width / im_scale) - 1
            y_limit = np.round(im_height / im_scale) - 1
            xmin = max(min(xmin, x_limit), 0.)
            ymin = max(min(ymin, y_limit), 0.)
            xmax = max(min(xmax, x_limit), 0.)
            ymax = max(min(ymax, y_limit), 0.)

            prediction.setdefault(c, []).append(
                [xmin, ymin, xmax, ymax, flat_scores[idx]])

    nmsed_outs, nmsed_num = multiclass_nms(prediction, class_num, keep_top_k,
                                           nms_threshold)
    return nmsed_outs, nmsed_num
def batched_retinanet_detection_out(boxes, scores, anchors, im_info,
                                    score_threshold, nms_threshold, nms_top_k,
                                    keep_top_k):
    """Run the single-image reference on every image of a batch.

    Args:
        boxes: per-level arrays of box deltas, indexed as ``boxes[lvl][n]``.
        scores: per-level arrays of scores, indexed as ``scores[lvl][n]``;
            the batch size is taken from ``scores[0].shape[0]``.
        anchors: per-level anchors, shared by all images.
        im_info: per-image ``(height, width, scale)`` rows.
        score_threshold, nms_threshold, nms_top_k, keep_top_k: passed on to
            ``retinanet_detection_out`` unchanged.

    Returns:
        A pair ``(det_outs, lod)``: the detections of all images
        concatenated in batch order, and the number of detections per image.
    """
    batch_size = scores[0].shape[0]
    level_ids = range(len(scores))
    det_outs = []
    lod = []
    for n in range(batch_size):
        image_boxes = [boxes[lvl][n] for lvl in level_ids]
        image_scores = [scores[lvl][n] for lvl in level_ids]
        nmsed_outs, nmsed_num = retinanet_detection_out(
            image_boxes, image_scores, anchors, im_info[n], score_threshold,
            nms_threshold, nms_top_k, keep_top_k)
        lod.append(nmsed_num)
        # Images without detections only contribute a 0 to the lod.
        if nmsed_num != 0:
            det_outs.extend(nmsed_outs)
    return det_outs, lod
class TestRetinanetDetectionOutOp1(OpTest):
    """Check `retinanet_detection_output` against the NumPy reference.

    Random scores and box deltas are generated for five FPN levels and the
    expected output is computed with `batched_retinanet_detection_out`.
    """

    def set_argument(self):
        """Set the operator attributes and the FPN/anchor configuration."""
        # Operator attributes.
        self.score_threshold = 0.05
        self.nms_threshold = 0.3
        self.nms_top_k = 1000
        self.keep_top_k = 200

        # Pyramid and anchor configuration.
        self.min_level = 3
        self.max_level = 7
        self.scales_per_octave = 3
        self.aspect_ratios = [1.0, 2.0, 0.5]
        self.anchor_scale = 4
        self.anchor_strides = [8, 16, 32, 64, 128]

        # Input sizes.
        self.box_size = 4
        self.class_num = 80
        self.batch_size = 1
        self.input_channels = 20

        # Feature map sizes: 2**num_levels at the first level, halved at
        # each following level.
        num_levels = self.max_level - self.min_level + 1
        self.layer_h = [2**(num_levels - i) for i in range(num_levels)]
        self.layer_w = [2**(num_levels - i) for i in range(num_levels)]

    def init_test_input(self):
        """Generate random per-level scores and boxes plus their anchors."""
        anchor_num = len(self.aspect_ratios) * self.scales_per_octave
        num_levels = self.max_level - self.min_level + 1
        self.scores_list = []
        self.bboxes_list = []
        self.anchors_list = []
        for lvl in range(num_levels):
            height = self.layer_h[lvl]
            width = self.layer_w[lvl]

            # The feature map is only used by the anchor generator.
            input_feat = np.random.random(
                (self.batch_size, self.input_channels, height,
                 width)).astype('float32')

            # Scores: NCHW -> NHWC -> [batch, anchors, classes].
            score = np.random.random(
                (self.batch_size, self.class_num * anchor_num, height,
                 width)).astype('float32')
            score = np.transpose(score, [0, 2, 3, 1])
            score = score.reshape((self.batch_size, -1, self.class_num))

            # Box deltas: NCHW -> NHWC -> [batch, anchors, 4].
            box = np.random.random(
                (self.batch_size, self.box_size * anchor_num, height,
                 width)).astype('float32')
            box = np.transpose(box, [0, 2, 3, 1])
            box = box.reshape((self.batch_size, -1, self.box_size))

            stride = self.anchor_strides[lvl]
            anchor_sizes = [
                float(stride * (2**octave)) / float(self.scales_per_octave) *
                self.anchor_scale
                for octave in range(self.scales_per_octave)
            ]
            anchor, _ = anchor_generator_in_python(
                input_feat=input_feat,
                anchor_sizes=anchor_sizes,
                aspect_ratios=self.aspect_ratios,
                variances=[1.0, 1.0, 1.0, 1.0],
                stride=[stride, stride],
                offset=0.5)
            anchor = np.reshape(anchor, [-1, 4])

            self.scores_list.append(score.astype('float32'))
            self.bboxes_list.append(box.astype('float32'))
            self.anchors_list.append(anchor.astype('float32'))

        # im_height, im_width, scale
        self.im_info = np.array([[256., 256., 1.5]]).astype('float32')

    def setUp(self):
        """Compute the expected output and describe the operator under test."""
        self.set_argument()
        self.init_test_input()

        expected, lod = batched_retinanet_detection_out(
            self.bboxes_list, self.scores_list, self.anchors_list,
            self.im_info, self.score_threshold, self.nms_threshold,
            self.nms_top_k, self.keep_top_k)
        expected = np.array(expected).astype('float32')

        self.op_type = 'retinanet_detection_output'
        # Exactly five levels are fed to the operator.
        self.inputs = {
            'BBoxes':
            [('b%d' % i, self.bboxes_list[i]) for i in range(5)],
            'Scores':
            [('s%d' % i, self.scores_list[i]) for i in range(5)],
            'Anchors':
            [('a%d' % i, self.anchors_list[i]) for i in range(5)],
            'ImInfo': (self.im_info, [[1, ]])
        }
        self.outputs = {'Out': (expected, [lod])}
        self.attrs = {
            'score_threshold': self.score_threshold,
            'nms_top_k': self.nms_top_k,
            'nms_threshold': self.nms_threshold,
            'keep_top_k': self.keep_top_k,
            'nms_eta': 1.,
        }

    def test_check_output(self):
        self.check_output()
class TestRetinanetDetectionOutOp2(TestRetinanetDetectionOutOp1):
    """Variant of the base case with hand-picked feature map sizes.

    The class derives from TestRetinanetDetectionOutOp1 so that it inherits
    `init_test_input`, `setUp` and `test_check_output`. Deriving directly
    from OpTest left it with only `set_argument`, so the case never ran.
    """

    def set_argument(self):
        """Same configuration as the base case except for the layer sizes."""
        self.score_threshold = 0.05
        self.min_level = 3
        self.max_level = 7
        self.nms_threshold = 0.3
        self.nms_top_k = 1000
        self.keep_top_k = 200
        self.scales_per_octave = 3
        self.aspect_ratios = [1.0, 2.0, 0.5]
        self.anchor_scale = 4
        self.anchor_strides = [8, 16, 32, 64, 128]
        self.box_size = 4
        self.class_num = 80
        self.batch_size = 1
        self.input_channels = 20
        # Test the case where the feature map sizes of the FPN levels are
        # chosen freely instead of being derived from the level index.
        self.layer_h = [1, 4, 8, 8, 16]
        self.layer_w = [1, 4, 8, 8, 16]
class TestRetinanetDetectionOutOpNo3(TestRetinanetDetectionOutOp1):
    """Base case with a score threshold that exceeds the generated scores."""

    def set_argument(self):
        """Same configuration as the base case except for score_threshold."""
        # A threshold of 2.0 is higher than any generated score, so it
        # rejects every candidate on the levels it applies to. In practical
        # use, 0.0 < score_threshold < 1.0.
        self.score_threshold = 2.0
        self.nms_threshold = 0.3
        self.nms_top_k = 1000
        self.keep_top_k = 200

        # Pyramid and anchor configuration.
        self.min_level = 3
        self.max_level = 7
        self.scales_per_octave = 3
        self.aspect_ratios = [1.0, 2.0, 0.5]
        self.anchor_scale = 4
        self.anchor_strides = [8, 16, 32, 64, 128]

        # Input sizes.
        self.box_size = 4
        self.class_num = 80
        self.batch_size = 1
        self.input_channels = 20

        # Feature map sizes: 2**num_levels at the first level, halved at
        # each following level.
        num_levels = self.max_level - self.min_level + 1
        self.layer_h = [2**(num_levels - i) for i in range(num_levels)]
        self.layer_w = [2**(num_levels - i) for i in range(num_levels)]
class TestRetinanetDetectionOutOpNo4(TestRetinanetDetectionOutOp1):
    """Four-level configuration (``min_level=2``, ``max_level=5``).

    ``setUp`` is overridden so that exactly four per-level entries are fed to
    the ``BBoxes``, ``Scores`` and ``Anchors`` inputs.
    """

    def set_argument(self):
        """Store the hyper-parameters of this configuration on ``self``."""
        self.score_threshold = 0.05
        self.nms_threshold = 0.3
        self.nms_top_k = 1000
        self.keep_top_k = 200

        self.min_level = 2
        self.max_level = 5
        self.scales_per_octave = 3
        self.aspect_ratios = [1.0, 2.0, 0.5]
        self.anchor_scale = 4
        self.anchor_strides = [8, 16, 32, 64, 128]

        self.box_size = 4
        self.class_num = 80
        self.batch_size = 1
        self.input_channels = 20

        # Level k gets a square map of side 2**(num_levels - k).
        num_levels = self.max_level - self.min_level + 1
        self.layer_h = [2**(num_levels - k) for k in range(num_levels)]
        self.layer_w = list(self.layer_h)

    def setUp(self):
        """Build inputs, attributes and the reference output for the op."""
        self.set_argument()
        self.init_test_input()

        # Reference result computed by the Python implementation.
        expected, lod = batched_retinanet_detection_out(
            self.bboxes_list, self.scores_list, self.anchors_list,
            self.im_info, self.score_threshold, self.nms_threshold,
            self.nms_top_k, self.keep_top_k)
        expected = np.array(expected).astype('float32')

        self.op_type = 'retinanet_detection_output'

        # One (name, tensor) pair per level; explicit indexing keeps the
        # IndexError if fewer than four levels were generated.
        bbox_names = ('b0', 'b1', 'b2', 'b3')
        score_names = ('s0', 's1', 's2', 's3')
        anchor_names = ('a0', 'a1', 'a2', 'a3')
        self.inputs = {
            'BBoxes': [(name, self.bboxes_list[k])
                       for k, name in enumerate(bbox_names)],
            'Scores': [(name, self.scores_list[k])
                       for k, name in enumerate(score_names)],
            'Anchors': [(name, self.anchors_list[k])
                        for k, name in enumerate(anchor_names)],
            'ImInfo': (self.im_info, [[1, ]])
        }
        self.outputs = {'Out': (expected, [lod])}
        self.attrs = {
            'score_threshold': self.score_threshold,
            'nms_top_k': self.nms_top_k,
            'nms_threshold': self.nms_threshold,
            'keep_top_k': self.keep_top_k,
            'nms_eta': 1.,
        }

    def test_check_output(self):
        """Run ``self.check_output()`` on the data prepared in setUp."""
        self.check_output()
class TestRetinanetDetectionOutOpNo5(TestRetinanetDetectionOutOp1):
    """Configuration with ``nms_top_k=100`` and ``keep_top_k=10``."""

    def set_argument(self):
        """Store the hyper-parameters of this configuration on ``self``."""
        self.score_threshold = 0.05
        self.nms_threshold = 0.3
        self.nms_top_k = 100
        self.keep_top_k = 10

        self.min_level = 3
        self.max_level = 7
        self.scales_per_octave = 3
        self.aspect_ratios = [1.0, 2.0, 0.5]
        self.anchor_scale = 4
        self.anchor_strides = [8, 16, 32, 64, 128]

        self.box_size = 4
        self.class_num = 80
        self.batch_size = 1
        self.input_channels = 20

        # Level k gets a square map of side 2**(num_levels - k).
        num_levels = self.max_level - self.min_level + 1
        self.layer_h = [2**(num_levels - k) for k in range(num_levels)]
        self.layer_w = list(self.layer_h)
# Run every test case of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -167,6 +167,105 @@ def rpn_target_assign_in_python(all_anchors, ...@@ -167,6 +167,105 @@ def rpn_target_assign_in_python(all_anchors,
return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights
def retinanet_target_assign(anchor_by_gt_overlap, gt_labels, positive_overlap,
                            negative_overlap):
    """NumPy reference for assigning anchors to ground-truth boxes.

    An anchor is first marked foreground when it attains the highest overlap
    of some ground truth, or when its own best overlap is at least
    ``positive_overlap``. Every anchor whose best overlap is below
    ``negative_overlap`` is then marked background, even if it had just been
    marked foreground; each such demoted anchor adds one "fake" entry (the
    index of the first foreground candidate) at the front of ``loc_index``
    and leaves one leading all-zero row in ``bbox_inside_weight``. Anchors
    that end up neither foreground nor background are left out of
    ``score_index``.

    Args:
        anchor_by_gt_overlap: array indexed as [anchor, ground truth].
        gt_labels: array of class labels indexed by ground truth.
        positive_overlap: overlap at or above which an anchor is foreground.
        negative_overlap: overlap below which an anchor is background.

    Returns:
        Tuple ``(loc_index, score_index, labels, gt_inds,
        bbox_inside_weight, fg_num)`` where ``fg_num`` is
        ``len(loc_index) + 1``.
    """
    overlaps = anchor_by_gt_overlap
    anchor_ids = np.arange(overlaps.shape[0])
    gt_ids = np.arange(overlaps.shape[1])

    # Best ground truth of every anchor and the overlap reached with it.
    anchor_to_gt_argmax = overlaps.argmax(axis=1)
    anchor_to_gt_max = overlaps[anchor_ids, anchor_to_gt_argmax]

    # Highest overlap of every ground truth and all anchors attaining it.
    best_per_gt = overlaps[overlaps.argmax(axis=0), gt_ids]
    tied_best_anchors = np.where(overlaps == best_per_gt)[0]

    # -1 = undecided, 1 = foreground, 0 = background.
    labels = -np.ones((overlaps.shape[0], ), dtype=np.int32)
    labels[tied_best_anchors] = 1
    labels[anchor_to_gt_max >= positive_overlap] = 1

    candidate_fg = np.where(labels == 1)[0]
    bbox_inside_weight = np.zeros((len(candidate_fg), 4), dtype=np.float32)
    first_fg = np.array([candidate_fg[0]], np.int32)

    low_overlap = np.where(anchor_to_gt_max < negative_overlap)[0]
    # labels is 1 exactly at candidate_fg here, so this counts the
    # low-overlap anchors that were marked foreground above.
    fake_num = int((labels[low_overlap] == 1).sum())
    fg_fake_inds = np.hstack([np.array([], np.int32)] + [first_fg] * fake_num)
    labels[low_overlap] = 0

    # Only the rows after the fake entries carry weight.
    bbox_inside_weight[fake_num:, :] = 1

    fg_inds = np.where(labels == 1)[0]
    bg_inds = np.where(labels == 0)[0]
    loc_index = np.hstack([fg_fake_inds, fg_inds])
    score_index = np.hstack([fg_inds, bg_inds])
    labels = labels[score_index]

    gt_inds = anchor_to_gt_argmax[loc_index]
    # Foreground entries come first in score_index; give them their class.
    labels[0:len(fg_inds)] = np.squeeze(gt_labels[anchor_to_gt_argmax[fg_inds]])
    fg_num = len(fg_fake_inds) + len(fg_inds) + 1
    assert not np.any(labels == -1), "Wrong labels with -1"
    return loc_index, score_index, labels, gt_inds, bbox_inside_weight, fg_num
def retinanet_target_assign_in_python(all_anchors, gt_boxes, gt_labels,
                                      is_crowd, im_info, lod, positive_overlap,
                                      negative_overlap):
    """Run ``retinanet_target_assign`` per image and merge the results.

    For image ``i`` the ground-truth rows ``lod[i]:lod[i + 1]`` are taken,
    the boxes are multiplied by ``im_info[i][2]``, rows whose ``is_crowd``
    entry is non-zero are dropped, and the remaining boxes are matched
    against ``all_anchors``. Index outputs of image ``i`` are shifted by
    ``i * all_anchors.shape[0]`` before being concatenated.

    Args:
        all_anchors: array of anchors, one row per anchor.
        gt_boxes: ground-truth boxes of all images stacked row-wise.
        gt_labels: 2-D array of labels, row-aligned with ``gt_boxes``.
        is_crowd: crowd flags, aligned with ``gt_boxes``.
        im_info: per-image info; element ``[i][2]`` is used as the scale.
        lod: offsets delimiting each image's rows (``len(lod) - 1`` images).
        positive_overlap: forwarded to ``retinanet_target_assign``.
        negative_overlap: forwarded to ``retinanet_target_assign``.

    Returns:
        Tuple ``(loc_indexes, score_indexes, tgt_bboxes, tgt_labels,
        bbox_inside_weights, fg_nums)``. ``fg_nums`` is always an ndarray of
        shape ``(batch_size, 1)``, including for a single image.

    Raises:
        ValueError: if ``lod`` describes no image.
    """
    anchor_num = all_anchors.shape[0]
    batch_size = len(lod) - 1
    if batch_size < 1:
        raise ValueError("lod must describe at least one image")

    # All anchors are used, so the "unmap" step below is an identity lookup.
    inds_inside = np.arange(anchor_num)

    loc_parts, score_parts, label_parts = [], [], []
    delta_parts, weight_parts, fg_nums = [], [], []
    for i in range(batch_size):
        im_scale = im_info[i][2]
        b, e = lod[i], lod[i + 1]

        # Keep only the non-crowd ground truths of this image.
        not_crowd_inds = np.where(is_crowd[b:e] == 0)[0]
        gt_boxes_slice = (gt_boxes[b:e, :] * im_scale)[not_crowd_inds]
        gt_labels_slice = gt_labels[b:e, :][not_crowd_inds]

        iou = _bbox_overlaps(all_anchors, gt_boxes_slice)
        loc_inds, score_inds, labels, gt_inds, bbox_inside_weight, fg_num = \
            retinanet_target_assign(iou, gt_labels_slice,
                                    positive_overlap, negative_overlap)

        # unmap to all anchor
        loc_inds = inds_inside[loc_inds]
        score_inds = inds_inside[score_inds]
        box_deltas = _box_to_delta(all_anchors[loc_inds],
                                   gt_boxes_slice[gt_inds], [1., 1., 1., 1.])

        # Indices become global by offsetting with the image position.
        offset = i * anchor_num
        loc_parts.append(loc_inds + offset)
        score_parts.append(score_inds + offset)
        label_parts.append(labels)
        delta_parts.append(box_deltas)
        weight_parts.append(bbox_inside_weight)
        fg_nums.append([fg_num])

    # Concatenating once keeps every result an ndarray for any batch size.
    return (np.concatenate(loc_parts), np.concatenate(score_parts),
            np.vstack(delta_parts), np.concatenate(label_parts),
            np.vstack(weight_parts), np.array(fg_nums))
class TestRpnTargetAssignOp(OpTest): class TestRpnTargetAssignOp(OpTest):
def setUp(self): def setUp(self):
n, c, h, w = 2, 4, 14, 14 n, c, h, w = 2, 4, 14, 14
...@@ -234,5 +333,65 @@ class TestRpnTargetAssignOp(OpTest): ...@@ -234,5 +333,65 @@ class TestRpnTargetAssignOp(OpTest):
self.check_output() self.check_output()
class TestRetinanetTargetAssignOp(OpTest):
    """Checks the ``retinanet_target_assign`` op against
    ``retinanet_target_assign_in_python``."""

    def setUp(self):
        """Generate anchors and ground truth, then fill inputs/attrs/outputs."""
        n, c, h, w = 2, 4, 14, 14
        all_anchors = get_anchor(n, c, h, w).reshape(-1, 4)

        images_shape = [[64, 64], [64, 64]]
        groundtruth, _ = _generate_groundtruth(images_shape, 3, 4)
        # Two images with four ground-truth rows each.
        lod = [0, 4, 8]

        # One row per image: first shape entry, second shape entry, scale.
        im_info = np.array(
            [[shape[0], shape[1], 0.8] for shape in images_shape],
            dtype=np.float32)

        gt_boxes = np.vstack([gt['boxes'] for gt in groundtruth])
        is_crowd = np.hstack([gt['is_crowd'] for gt in groundtruth])
        gt_labels = np.vstack([
            gt['gt_classes'].reshape(len(gt['gt_classes']), 1)
            for gt in groundtruth
        ])

        all_anchors = all_anchors.astype('float32')
        gt_boxes = gt_boxes.astype('float32')
        gt_labels = gt_labels.astype('int32')

        positive_overlap = 0.5
        negative_overlap = 0.4

        (loc_index, score_index, tgt_bbox, labels, bbox_inside_weights,
         fg_num) = retinanet_target_assign_in_python(
             all_anchors, gt_boxes, gt_labels, is_crowd, im_info, lod,
             positive_overlap, negative_overlap)
        # The expected label output is a column vector.
        labels = labels[:, np.newaxis]

        self.op_type = "retinanet_target_assign"
        self.inputs = {
            'Anchor': all_anchors,
            'GtBoxes': (gt_boxes, [[4, 4]]),
            'GtLabels': (gt_labels, [[4, 4]]),
            'IsCrowd': (is_crowd, [[4, 4]]),
            'ImInfo': (im_info, [[1, 1]])
        }
        self.attrs = {
            'positive_overlap': positive_overlap,
            'negative_overlap': negative_overlap
        }
        self.outputs = {
            'LocationIndex': loc_index.astype('int32'),
            'ScoreIndex': score_index.astype('int32'),
            'TargetBBox': tgt_bbox.astype('float32'),
            'TargetLabel': labels.astype('int32'),
            'BBoxInsideWeight': bbox_inside_weights.astype('float32'),
            'ForegroundNumber': fg_num.astype('int32')
        }

    def test_check_output(self):
        """Run ``self.check_output()`` on the data prepared in setUp."""
        self.check_output()
# Run every test case of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import copy
from op_test import OpTest
from paddle.fluid import core
def sigmoid_focal_loss_forward(x_data, label_data, fg_num_data, gamma, alpha,
                               num_classes):
    """Pure-Python reference forward pass of the sigmoid focal loss.

    ``x_data`` is walked in flattened order. For flat position ``idx`` the
    anchor is ``idx // num_classes`` and the class slot is
    ``idx % num_classes``. Slot ``d`` is a positive when the anchor's label
    equals ``d + 1`` and a negative when the label is neither ``-1`` nor
    ``d + 1``; a label of ``-1`` therefore yields zero loss for the anchor.

    Args:
        x_data: 2-D ndarray of logits.
        label_data: integer label per anchor; any array whose flattened
            form has one entry per anchor (e.g. shape ``(N,)`` or ``(N, 1)``).
        fg_num_data: scalar or one-element array used to normalise the
            loss; values below 1 are replaced by 1.
        gamma: exponent of the modulating factor.
        alpha: weight of positives; negatives are weighted ``1 - alpha``.
        num_classes: number of class slots per anchor.

    Returns:
        ndarray with the shape and dtype of ``x_data``.
    """
    FLT_MIN = 1.175494351e-38

    num_rows = len(x_data)
    num_cols = len(x_data[0, :])
    logits = x_data.flatten()

    # Extract plain Python scalars once instead of relying on implicit
    # conversion of one-element arrays inside the loop.
    labels = np.asarray(label_data).reshape(-1)
    fg_num = max(np.asarray(fg_num_data).reshape(-1)[0].item(), 1)
    z_pos = alpha / fg_num
    z_neg = (1.0 - alpha) / fg_num

    out_data = np.zeros(logits.shape, dtype=logits.dtype)
    for idx, raw_x in enumerate(logits):
        # Work in Python floats so the arithmetic is done in double
        # precision regardless of NumPy's scalar promotion rules.
        x = float(raw_x)
        anchor, slot = divmod(idx, num_classes)
        label = int(labels[anchor])
        target = slot + 1
        c_pos = float(label == target)
        c_neg = float(label != -1 and label != target)

        # Sigmoid evaluated without overflowing exp() for large |x|.
        if x >= 0:
            p = 1. / (1. + math.exp(-x))
        else:
            exp_x = math.exp(x)
            p = exp_x / (1. + exp_x)

        term_pos = math.pow(1. - p, gamma) * math.log(max(FLT_MIN, p))
        # log(1 - p) written as -x*[x>=0] - log(1 + exp(-|x|)).
        non_neg = float(x >= 0)
        term_neg = math.pow(p, gamma) * (
            -1. * x * non_neg - math.log(1. + math.exp(x - 2. * x * non_neg)))

        loss = 0.0
        loss += -c_pos * term_pos * z_pos
        loss += -c_neg * term_neg * z_neg
        out_data[idx] = loss

    return out_data.reshape(num_rows, num_cols)
class TestSigmoidFocalLossOp1(OpTest):
    """Checks the ``sigmoid_focal_loss`` op against
    ``sigmoid_focal_loss_forward`` on random data."""

    def set_argument(self):
        """Store problem size and loss hyper-parameters on ``self``."""
        self.num_anchors = 10
        self.num_classes = 10
        self.gamma = 2.0
        self.alpha = 0.25

    def setUp(self):
        """Draw random inputs and compute the expected loss."""
        self.set_argument()

        shape = (self.num_anchors, self.num_classes)
        logits = np.random.standard_normal(shape).astype("float32")
        # Labels are drawn from 0..num_classes inclusive.
        labels = np.random.randint(0, self.num_classes + 1,
                                   (self.num_anchors, 1)).astype("int32")
        # Number of anchors whose label is greater than zero.
        fg_num = np.array([len(np.where(labels > 0)[0])]).astype("int32")

        self.op_type = "sigmoid_focal_loss"
        self.inputs = {
            'X': logits,
            'Label': labels,
            'FgNum': fg_num,
        }
        self.attrs = {
            'gamma': self.gamma,
            'alpha': self.alpha,
        }
        expected = sigmoid_focal_loss_forward(
            logits, labels, fg_num, self.gamma, self.alpha, self.num_classes)
        self.outputs = {'Out': expected.astype('float32')}

    def test_check_output(self):
        """Run ``self.check_output()`` on the data prepared in setUp."""
        self.check_output()

    def test_check_grad(self):
        """Run ``self.check_grad`` for input ``X`` and output ``Out``."""
        self.check_grad(['X'], 'Out')
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSigmoidFocalLossOp2(TestSigmoidFocalLossOp1):
    """Op1 configuration checked on ``core.CUDAPlace(0)`` with explicit
    tolerances; skipped when ``core`` is not compiled with CUDA."""

    def test_check_output(self):
        """Check the output on the CUDA place with ``atol=2e-3``."""
        self.check_output_with_place(core.CUDAPlace(0), atol=2e-3)

    def test_check_grad(self):
        """Check the gradient w.r.t. ``X`` on the CUDA place."""
        self.check_grad_with_place(
            core.CUDAPlace(0), ['X'], 'Out', max_relative_error=0.002)
class TestSigmoidFocalLossOp3(TestSigmoidFocalLossOp1):
    """Same test as Op1 with 200 anchors, ``gamma=1.0`` and ``alpha=0.5``."""

    def set_argument(self):
        """Store problem size and loss hyper-parameters on ``self``."""
        self.num_anchors = 200
        self.num_classes = 10
        self.gamma = 1.0
        self.alpha = 0.5
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSigmoidFocalLossOp4(TestSigmoidFocalLossOp3):
    """Op3 configuration checked on ``core.CUDAPlace(0)`` with explicit
    tolerances; skipped when ``core`` is not compiled with CUDA."""

    def test_check_output(self):
        """Check the output on the CUDA place with ``atol=2e-3``."""
        self.check_output_with_place(core.CUDAPlace(0), atol=2e-3)

    def test_check_grad(self):
        """Check the gradient w.r.t. ``X`` on the CUDA place."""
        self.check_grad_with_place(
            core.CUDAPlace(0), ['X'], 'Out', max_relative_error=0.002)
# Run every test case of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册