Unverified commit 3305045c authored by: F FlyingQianMM, committed by: GitHub


Cherry pick retinanet_target_assign_op(#17893), sigmoid_focal_loss_op(#17895) and retinanet_detection_output_op(#17896) for supporting retinanet (#18141)

* test=release/1.5
Fix conflicts in test_layers.py when adding target assign operator for supporting retinanet. Cherry pick #17893

* test=release/1.5
Add sigmoid focal loss operator for supporting retinanet. Cherry pick #17895

* test=release/1.5
Add detection output operator for supporting retinanet. Cherry pick #17896

* test=release/1.5
Fix wrong code style in test_layers.py when cherry-picking retinanet_target_assign. Cherry pick #17893

* test=release/1.5
Fix type error of std::pow in sigmoid_focal_loss. Cherry pick #17895
Parent 7c7afef7
...@@ -348,6 +348,8 @@ paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box'
paddle.fluid.layers.ssd_loss (ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)), ('document', '6d5028fd09d01ab82d296adc0ea95aee'))
paddle.fluid.layers.detection_map (ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')), ('document', '1467d91b50c22cd52103b4aa1ee9d0a1'))
paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '1e164a56fe9376e18a56d22563d9f801'))
paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
paddle.fluid.layers.sigmoid_focal_loss (ArgSpec(args=['x', 'label', 'fg_num', 'gamma', 'alpha'], varargs=None, keywords=None, defaults=(2, 0.25)), ('document', 'aeac6aae100173b3fc7f102cf3023a3d'))
paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a'))
paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc'))
paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'e87c1131e98715d3657a96c44db1b910'))
...@@ -360,6 +362,7 @@ paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gt_box', 'gt_label', 'ancho
paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'f332fb8c5bb581bd1a6b5be450a99990'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '04384378ff00a42ade8fabd52e27cbc5'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.retinanet_detection_output (ArgSpec(args=['bboxes', 'scores', 'anchors', 'im_info', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0.05, 1000, 100, 0.3, 1.0)), ('document', '078d28607ce261a0cba2b965a79f6bb8'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5'))
paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace'))
......
...@@ -35,6 +35,8 @@ detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class SigmoidFocalLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label");
auto fg_dims = ctx->GetInputDim("FgNum");
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"The last dimension of input(Label) should be 1.");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
class SigmoidFocalLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label");
auto fg_dims = ctx->GetInputDim("FgNum");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape.");
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"The last dimension of input(Label) should be 1.");
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, 0, rank),
framework::slice_ddim(dout_dims, 0, rank),
"Input(X) and Input(Out@Grad) shall have the same shape.");
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
class SigmoidFocalLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N, D], "
"where N is the batch size and D is the number of classes "
"(excluding background). This input is a tensor of logits "
"computed by the previous operator.");
AddInput("Label",
"(Tensor, default Tensor<int>), a 2-D tensor with shape [N, 1]. "
"This input is a tensor of probabilistic labels.");
AddInput("FgNum",
"(Tensor, default Tensor<int>), a 1-D tensor with shape [1]. "
"This input is the number of foreground.");
AddOutput(
"Out",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N, D]. "
"This output is the focal loss.");
AddAttr<float>(
"gamma",
"Hyper-parameter of sigmoid focal loss op, which is to balance the "
"easy and hard examples. "
"A float scalar with default value 2.0.")
.SetDefault(2.0);
AddAttr<float>(
"alpha",
"Hyper-parameter of sigmoid focal loss op, which is to balance the "
"positive and negative examples. "
"A float scalar with default value 0.5.")
.SetDefault(0.25);
AddComment(R"DOC(
Sigmoid Focal Loss Operator.
Focal loss is used to address the foreground-background class imbalance existed
on the training phase of one-stage detectors. This operator computes the sigmoid
value for each element in the input tensor, after which focal loss is measured.
The focal loss is given as follows:
$$Loss_j = (-Label_j * alpha * \pow(1 - \sigma(X_j), gamma) * \log(\sigma(X_j)) -
(1 - Labels_j) * (1 - alpha) * \pow(\sigma(X_j), gamma) * \log(1 - \sigma(X_j)))
/ FgNum, j = 1,...,K$$
We know that $$\sigma(X_j) = \\frac{1}{1 + \exp(-X_j)}$$.
)DOC");
}
};
class SigmoidFocalLossGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sigmoid_focal_loss_grad");
op->SetInput("X", Input("X"));
op->SetInput("Label", Input("Label"));
op->SetInput("FgNum", Input("FgNum"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sigmoid_focal_loss, ops::SigmoidFocalLossOp,
ops::SigmoidFocalLossOpMaker,
ops::SigmoidFocalLossGradOpDescMaker);
REGISTER_OPERATOR(sigmoid_focal_loss_grad, ops::SigmoidFocalLossGradOp);
REGISTER_OP_CPU_KERNEL(
sigmoid_focal_loss,
ops::SigmoidFocalLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::SigmoidFocalLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
sigmoid_focal_loss_grad,
ops::SigmoidFocalLossGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SigmoidFocalLossGradKernel<paddle::platform::CPUDeviceContext,
double>);
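For reference, the forward math that both the CPU and GPU kernels implement can be mirrored in a few lines of numpy. This is a minimal sketch under the same labeling convention (g == d + 1 marks class d positive, g == 0 marks all classes negative, g == -1 excludes the sample); the name sigmoid_focal_loss_np and the toy inputs are illustrative, not Paddle APIs, and the kernels' numerically stable evaluation of log(1 - p) is simplified to a clamp here:

import numpy as np

def sigmoid_focal_loss_np(x, label, fg_num, gamma=2.0, alpha=0.25):
    # x: [N, D] logits; label: [N, 1] ints in {-1, 0, 1, ..., D}.
    fg = float(max(fg_num, 1))  # same clamp as the kernels
    p = 1.0 / (1.0 + np.exp(-x))
    out = np.zeros_like(x)
    N, D = x.shape
    for i in range(N):
        g = int(label[i, 0])
        for d in range(D):
            c_pos = float(g == d + 1)
            c_neg = float(g != -1 and g != d + 1)
            pd = p[i, d]
            # (1 - p)**gamma * log(p)
            term_pos = (1.0 - pd) ** gamma * np.log(max(pd, 1e-37))
            # p**gamma * log(1 - p); the kernels use a stable form instead
            term_neg = pd ** gamma * np.log(max(1.0 - pd, 1e-37))
            out[i, d] = -c_pos * (alpha / fg) * term_pos \
                        - c_neg * ((1.0 - alpha) / fg) * term_neg
    return out

x = np.array([[2.0, -1.0], [0.5, 0.5]])
label = np.array([[1], [0]])  # sample 0: class 0 positive; sample 1: all negative
print(sigmoid_focal_loss_np(x, label, fg_num=1))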
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void GPUSigmoidFocalLossForward(const T *x_data,
const int *label_data,
const int *fg_num_data,
const T gamma, const T alpha,
const int num_classes,
const int limit, T *out_data) {
CUDA_1D_KERNEL_LOOP(i, limit) {
T x = x_data[i];
int a = i / num_classes; // current sample
int d = i % num_classes; // current class
int g = label_data[a]; // target
// check whether the input data is positive or negative
// the target classes are in range 1-81
// and the d is in range 0-80
T c_pos = static_cast<T>(g == (d + 1));
T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
T s_neg = (1.0 - alpha) / fg_num;
T s_pos = alpha / fg_num;
// p = 1. / (1. + exp(-x))
T p = 1. / (1. + real_exp(-x));
// (1 - p)**gamma * log(p)
T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
real_log(p > FLT_MIN ? p : FLT_MIN);
// p**gamma * log(1 - p)
T term_neg =
std::pow(p, gamma) *
(-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0))));
out_data[i] = 0.0;
out_data[i] += -c_pos * term_pos * s_pos;
out_data[i] += -c_neg * term_neg * s_neg;
}
}
template <typename T>
__global__ void GPUSigmoidFocalLossBackward(
const T *x_data, const int *label_data, const int *fg_num_data,
const T gamma, const T alpha, const int num_classes, const T *dout_data,
const int limit, T *dx_data) {
CUDA_1D_KERNEL_LOOP(i, limit) {
T x = x_data[i];
T dout = dout_data[i];
int a = i / num_classes; // current sample
int d = i % num_classes; // current class
T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
T s_neg = (1.0 - alpha) / fg_num;
T s_pos = alpha / fg_num;
int g = label_data[a];
T c_pos = static_cast<T>(g == (d + 1));
T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
T p = 1. / (1. + real_exp(-x));
// (1-p)**gamma * (1 - p - gamma*p*log(p))
T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
(1. - p - (p * gamma * real_log(p > FLT_MIN ? p : FLT_MIN)));
// p**gamma * (gamma*(1-p)*log(1-p) - p)
T term_neg =
std::pow(p, gamma) *
((-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0)))) *
(1. - p) * gamma -
p);
dx_data[i] = 0.0;
dx_data[i] += -c_pos * s_pos * term_pos;
dx_data[i] += -c_neg * s_neg * term_neg;
dx_data[i] = dx_data[i] * dout;
}
}
template <typename DeviceContext, typename T>
class GPUSigmoidFocalLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *FgNum = context.Input<Tensor>("FgNum");
Tensor *Out = context.Output<Tensor>("Out");
T gamma = static_cast<T>(context.Attr<float>("gamma"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto x_dims = X->dims();
int num_classes = static_cast<int>(x_dims[1]);
auto out_data = Out->mutable_data<T>(context.GetPlace());
auto &dev_ctx = context.cuda_device_context();
int limit = Out->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
GPUSigmoidFocalLossForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
X->data<T>(), Labels->data<int>(), FgNum->data<int>(), gamma, alpha,
num_classes, limit, out_data);
}
};
template <typename DeviceContext, typename T>
class GPUSigmoidFocalLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *FgNum = context.Input<Tensor>("FgNum");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dX->mutable_data<T>(context.GetPlace());
T gamma = static_cast<T>(context.Attr<float>("gamma"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto x_dims = X->dims();
int num_classes = static_cast<int>(x_dims[1]);
auto &dev_ctx = context.cuda_device_context();
int limit = dX->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
GPUSigmoidFocalLossBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
X->data<T>(), Labels->data<int>(), FgNum->data<int>(), gamma, alpha,
num_classes, dOut->data<T>(), limit, dx_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sigmoid_focal_loss,
ops::GPUSigmoidFocalLossKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUSigmoidFocalLossKernel<paddle::platform::CUDADeviceContext,
double>);
REGISTER_OP_CUDA_KERNEL(
sigmoid_focal_loss_grad,
ops::GPUSigmoidFocalLossGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::GPUSigmoidFocalLossGradKernel<paddle::platform::CUDADeviceContext,
double>);
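Both kernels avoid evaluating log(1 - sigmoid(x)) literally: for large positive x, 1 - sigmoid(x) underflows to zero and the logarithm returns -inf. The expression -x*(x >= 0) - log(1 + exp(x - 2*x*(x >= 0))) used above is algebraically equal to log(1 - sigmoid(x)) but reduces to -max(x, 0) - log(1 + exp(-|x|)), which never overflows. A short numpy check of this identity (a sketch, not part of this patch):

import numpy as np

def stable_log_one_minus_sigmoid(x):
    # -x * [x >= 0] - log(1 + exp(x - 2 * x * [x >= 0]))
    ind = (x >= 0).astype(x.dtype)
    return -x * ind - np.log1p(np.exp(x - 2.0 * x * ind))

print(stable_log_one_minus_sigmoid(np.array([50.0])))  # approx. -50
print(np.log(1.0 - 1.0 / (1.0 + np.exp(-50.0))))       # -inf: 1 - sigmoid(50) underflows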
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <limits>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class SigmoidFocalLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *FgNum = context.Input<Tensor>("FgNum");
Tensor *Out = context.Output<Tensor>("Out");
T gamma = static_cast<T>(context.Attr<float>("gamma"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto out_data = Out->mutable_data<T>(context.GetPlace());
int limit = Out->numel();
auto x_data = X->data<T>();
auto label_data = Labels->data<int>();
auto fg_num_data = FgNum->data<int>();
auto x_dims = X->dims();
int num_classes = static_cast<int>(x_dims[1]);
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
int a = idx / num_classes; // current sample
int d = idx % num_classes; // current class
int g = label_data[a]; // target
// Check whether the input data is positive or negative
// The target classes are in range 1-81
// and the d is in range 0-80
T c_pos = static_cast<T>(g == (d + 1));
T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
T s_neg = (1.0 - alpha) / fg_num;
T s_pos = alpha / fg_num;
// p = 1. / (1. + exp(-x))
T p = 1. / (1. + std::exp(-x));
// (1 - p)**gamma * log(p)
T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
std::log(p > FLT_MIN ? p : FLT_MIN);
// p**gamma * log(1 - p)
T term_neg =
std::pow(p, gamma) *
(-1. * x * (x >= 0) - std::log(1. + std::exp(x - 2. * x * (x >= 0))));
out_data[idx] = 0.0;
out_data[idx] += -c_pos * term_pos * s_pos;
out_data[idx] += -c_neg * term_neg * s_neg;
}
}
};
template <typename DeviceContext, typename T>
class SigmoidFocalLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *FgNum = context.Input<Tensor>("FgNum");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dX->mutable_data<T>(context.GetPlace());
T gamma = static_cast<T>(context.Attr<float>("gamma"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto x_dims = X->dims();
int num_classes = static_cast<int>(x_dims[1]);
int limit = dX->numel();
auto x_data = X->data<T>();
auto label_data = Labels->data<int>();
auto fg_num_data = FgNum->data<int>();
auto dout_data = dOut->data<T>();
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
int a = idx / num_classes; // current sample
int d = idx % num_classes; // current class
T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
T s_neg = static_cast<T>((1.0 - alpha) / fg_num);
T s_pos = alpha / fg_num;
int g = label_data[a];
T c_pos = static_cast<T>(g == (d + 1));
T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
T p = 1. / (1. + std::exp(-x));
// (1-p)**gamma * (1 - p - gamma*p*log(p))
T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
(1. - p - (p * gamma * std::log(p > FLT_MIN ? p : FLT_MIN)));
// p**gamma * (gamma*(1-p)*log(1-p) - p)
T term_neg = std::pow(p, gamma) *
((-1. * x * (x >= 0) -
std::log(1. + std::exp(x - 2. * x * (x >= 0)))) *
(1. - p) * gamma -
p);
dx_data[idx] = 0.0;
dx_data[idx] += -c_pos * s_pos * term_pos;
dx_data[idx] += -c_neg * s_neg * term_neg;
dx_data[idx] = dx_data[idx] * dout_data[idx];
}
}
};
} // namespace operators
} // namespace paddle
...@@ -40,6 +40,8 @@ __all__ = [
'ssd_loss',
'detection_map',
'rpn_target_assign',
'retinanet_target_assign',
'sigmoid_focal_loss',
'anchor_generator',
'roi_perspective_transform',
'generate_proposal_labels',
...@@ -52,12 +54,171 @@ __all__ = [
'yolo_box',
'box_clip',
'multiclass_nms',
'retinanet_detection_output',
'distribute_fpn_proposals',
'box_decoder_and_assign',
'collect_fpn_proposals',
]
def retinanet_target_assign(bbox_pred,
cls_logits,
anchor_box,
anchor_var,
gt_boxes,
gt_labels,
is_crowd,
im_info,
num_classes=1,
positive_overlap=0.5,
negative_overlap=0.4):
"""
**Target Assign Layer for RetinaNet.**
Given the Intersection-over-Union (IoU) overlap between anchors and
ground-truth boxes, this layer assigns classification and regression
targets to each anchor; these targets are used to train RetinaNet.
Every anchor is assigned a length :attr:`num_classes` one-hot vector of
classification targets and a 4-vector of box regression targets.
The assignment rules are as follows (see the numpy sketch after this
function for a toy illustration):
1. An anchor is assigned to a ground-truth box when: (i) it has the highest
IoU overlap with that box, or (ii) its IoU overlap with any ground-truth box
is higher than positive_overlap (0.5).
2. An anchor is assigned to background when its IoU overlap is lower than
negative_overlap (0.4) for all ground-truth boxes.
When an anchor is assigned to a ground-truth box of the i-th category, the
i-th entry in its C-vector of targets is set to 1 and all other entries are
set to 0. When an anchor is assigned to background, all entries are set to 0.
Anchors that are not assigned do not contribute to the training objective.
The regression targets are the encoded ground-truth boxes associated with
the assigned anchors.
Args:
bbox_pred(Variable): A 3-D Tensor with shape [N, M, 4] represents the
predicted locations of M bounding bboxes. N is the batch size,
and each bounding box has four coordinate values and the layout
is [xmin, ymin, xmax, ymax].
cls_logits(Variable): A 3-D Tensor with shape [N, M, C] represents the
predicted confidence predictions. N is the batch size, C is the
number of classes (excluding background), M is number of bounding boxes.
anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
each box is represented as [xmin, ymin, xmax, ymax],
[xmin, ymin] is the left top coordinate of the anchor box,
if the input is image feature map, they are close to the origin
of the coordinate system. [xmax, ymax] is the right bottom
coordinate of the anchor box.
anchor_var(Variable): A 2-D Tensor with shape [M,4] holds expanded
variances of anchors.
gt_boxes(Variable): The ground-truth bounding boxes (bboxes) are a 2D
LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth
bboxes of mini-batch input.
gt_labels(Variable): The ground-truth labels are a 2D LoDTensor with
shape [Ng, 1], where Ng is the total number of ground-truth labels of
mini-batch input.
is_crowd(Variable): A 1-D LoDTensor which indicates whether a ground-truth box is crowd.
im_info(Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch size,
3 is the height, width and scale.
num_classes(int32): The number of classes.
positive_overlap(float): Minimum overlap required between an anchor
and ground-truth box for the (anchor, gt box) pair to be a positive
example.
negative_overlap(float): Maximum overlap allowed between an anchor
and ground-truth box for the (anchor, gt box) pair to be a negative
example.
Returns:
tuple:
A tuple (predicted_scores, predicted_location, target_label,
target_bbox, bbox_inside_weight, fg_num) is returned. The
predicted_scores and predicted_location are the predicted results
of the RetinaNet. The target_label and target_bbox are the ground
truth, respectively. The predicted_location is a 2D Tensor with
shape [F, 4], and the shape of target_bbox is the same as the shape
of the predicted_location, where F is the number of foreground
anchors. The predicted_scores is a 2D Tensor with shape
[F + B, C], and the shape of target_label is [F + B, 1], where B is
the number of background anchors; F and B depend on the input
of this operator. bbox_inside_weight indicates whether a predicted
location is fake foreground, and its shape is [F, 4]. fg_num is
the number of foreground anchors (including fake foreground),
which is needed by the focal loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
bbox_pred = fluid.layers.data(name='bbox_pred', shape=[1, 100, 4],
    append_batch_size=False, dtype='float32')
cls_logits = fluid.layers.data(name='cls_logits', shape=[1, 100, 10],
    append_batch_size=False, dtype='float32')
anchor_box = fluid.layers.data(name='anchor_box', shape=[100, 4],
    append_batch_size=False, dtype='float32')
anchor_var = fluid.layers.data(name='anchor_var', shape=[100, 4],
    append_batch_size=False, dtype='float32')
gt_boxes = fluid.layers.data(name='gt_boxes', shape=[10, 4],
    append_batch_size=False, dtype='float32')
gt_labels = fluid.layers.data(name='gt_labels', shape=[10, 1],
    append_batch_size=False, dtype='float32')
is_crowd = fluid.layers.data(name='is_crowd', shape=[1],
    append_batch_size=False, dtype='float32')
im_info = fluid.layers.data(name='im_info', shape=[1, 3],
    append_batch_size=False, dtype='float32')
loc_pred, score_pred, loc_target, score_target, bbox_inside_weight, fg_num = \
    fluid.layers.retinanet_target_assign(bbox_pred, cls_logits, anchor_box,
        anchor_var, gt_boxes, gt_labels, is_crowd, im_info, 10)
"""
helper = LayerHelper('retinanet_target_assign', **locals())
# Assign target label to anchors
loc_index = helper.create_variable_for_type_inference(dtype='int32')
score_index = helper.create_variable_for_type_inference(dtype='int32')
target_label = helper.create_variable_for_type_inference(dtype='int32')
target_bbox = helper.create_variable_for_type_inference(
dtype=anchor_box.dtype)
bbox_inside_weight = helper.create_variable_for_type_inference(
dtype=anchor_box.dtype)
fg_num = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
type="retinanet_target_assign",
inputs={
'Anchor': anchor_box,
'GtBoxes': gt_boxes,
'GtLabels': gt_labels,
'IsCrowd': is_crowd,
'ImInfo': im_info
},
outputs={
'LocationIndex': loc_index,
'ScoreIndex': score_index,
'TargetLabel': target_label,
'TargetBBox': target_bbox,
'BBoxInsideWeight': bbox_inside_weight,
'ForegroundNumber': fg_num
},
attrs={
'positive_overlap': positive_overlap,
'negative_overlap': negative_overlap
})
loc_index.stop_gradient = True
score_index.stop_gradient = True
target_label.stop_gradient = True
target_bbox.stop_gradient = True
bbox_inside_weight.stop_gradient = True
fg_num.stop_gradient = True
cls_logits = nn.reshape(x=cls_logits, shape=(-1, num_classes))
bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
predicted_cls_logits = nn.gather(cls_logits, score_index)
predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight, fg_num
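To make the two assignment rules concrete, here is a toy numpy sketch that labels three anchors against a single ground-truth box; assign_labels is an illustrative helper, not a Paddle API, and it omits the "fake foreground" bookkeeping done by the real operator:

import numpy as np

def assign_labels(iou, positive_overlap=0.5, negative_overlap=0.4):
    # iou: [num_anchors, num_gt] IoU matrix.
    # Returns 1 (foreground), 0 (background) or -1 (ignored) per anchor.
    labels = np.full(iou.shape[0], -1, dtype=np.int32)
    anchor_to_gt_max = iou.max(axis=1)
    labels[anchor_to_gt_max < negative_overlap] = 0   # rule 2
    labels[anchor_to_gt_max >= positive_overlap] = 1  # rule 1 (ii)
    gt_to_anchor_max = iou.max(axis=0)
    labels[np.where(iou == gt_to_anchor_max)[0]] = 1  # rule 1 (i)
    return labels

iou = np.array([[0.7], [0.45], [0.1]])
print(assign_labels(iou))  # [ 1 -1  0]: foreground, ignored, background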
def rpn_target_assign(bbox_pred,
                      cls_logits,
                      anchor_box,
...@@ -210,6 +371,74 @@ def rpn_target_assign(bbox_pred,
    return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight
def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25):
"""
**Sigmoid Focal Loss Operator.**
Focal loss is used to address the foreground-background class imbalance that
exists during the training phase of one-stage detectors. This operator computes
the sigmoid value for each element in the input tensor, after which focal loss
is measured.
The focal loss is given as follows:
.. math::
loss_j = (-label_j * alpha * {(1 - \\sigma(x_j))}^{gamma} * \\log(\\sigma(x_j)) -
(1 - label_j) * (1 - alpha) * {\\sigma(x_j)}^{gamma} * \\log(1 - \\sigma(x_j)))
/ fg\_num, j = 1,...,K
We know that
.. math::
\\sigma(x_j) = \\frac{1}{1 + \\exp(-x_j)}
Args:
x(Variable): A 2-D tensor with shape [N, D], where N is the batch size and D is the number
of classes (excluding background). This input is a tensor of logits computed by the
previous operator.
label(Variable): A 2-D tensor with shape [N, 1], which is the probabilistic labels.
fg_num(Variable): A 1-D tensor with shape [1], which is the number of foreground examples.
gamma(float): Hyper-parameter to balance the easy and hard examples. Default value is
set to 2.0.
alpha(float): Hyper-parameter to balance the positive and negative examples. Default value
is set to 0.25.
Returns:
out(Variable): A 2-D tensor with shape [N, D], which is the focal loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(
name='data', shape=[10,80], append_batch_size=False, dtype='float32')
label = fluid.layers.data(
name='label', shape=[10,1], append_batch_size=False, dtype='int32')
fg_num = fluid.layers.data(
name='fg_num', shape=[1], append_batch_size=False, dtype='int32')
loss = fluid.layers.sigmoid_focal_loss(x=input,
label=label,
fg_num=fg_num,
gamma=2.,
alpha=0.25)
"""
helper = LayerHelper("sigmoid_focal_loss", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="sigmoid_focal_loss",
inputs={"X": x,
"Label": label,
"FgNum": fg_num},
attrs={"gamma": gamma,
'alpha': alpha},
outputs={"Out": out})
return out
def detection_output(loc,
                     scores,
                     prior_box,
...@@ -2320,6 +2549,113 @@ def box_clip(input, im_info, name=None):
    return output
def retinanet_detection_output(bboxes,
scores,
anchors,
im_info,
score_threshold=0.05,
nms_top_k=1000,
keep_top_k=100,
nms_threshold=0.3,
nms_eta=1.):
"""
**Detection Output Layer for RetinaNet.**
This operation obtains the detection results by performing the following
two steps (a numpy sketch of the box decoding in step 1 follows this
function):
1. Decode the top-scoring bounding box predictions per FPN level according
to the anchor boxes.
2. Merge the top predictions from all levels and apply multi-class
non-maximum suppression (NMS) on them to get the final detections.
Args:
bboxes(List): A list of tensors from multiple FPN levels. Each
element is a 3-D Tensor with shape [N, Mi, 4] representing the
predicted locations of Mi bounding boxes. N is the batch size,
Mi is the number of bounding boxes from i-th FPN level and each
bounding box has four coordinate values and the layout is
[xmin, ymin, xmax, ymax].
scores(List): A list of tensors from multiple FPN levels. Each
element is a 3-D Tensor with shape [N, Mi, C] representing the
predicted confidence predictions. N is the batch size, C is the
class number (excluding background), Mi is the number of bounding
boxes from i-th FPN level. For each bounding box, there are total
C scores.
anchors(List): A list of 2-D Tensors, each with shape [Mi, 4]
representing the locations of Mi anchor boxes from the i-th FPN level.
Each bounding box has four coordinate values and the layout is
[xmin, ymin, xmax, ymax].
im_info(Variable): A 2-D LoDTensor with shape [N, 3] represents the
image information. N is the batch size, each image information
includes height, width and scale.
score_threshold(float): Threshold to filter out bounding boxes
with a confidence score lower than the threshold.
nms_top_k(int): Maximum number of detections per FPN layer to be
kept according to the confidences before NMS.
keep_top_k(int): Number of total bounding boxes to be kept per image after
NMS step. -1 means keeping all bounding boxes after NMS step.
nms_threshold(float): The threshold to be used in NMS.
nms_eta(float): The parameter for adaptive NMS.
Returns:
Variable:
The detection output is a LoDTensor with shape [No, 6].
Each row has six values: [label, confidence, xmin, ymin, xmax, ymax].
`No` is the total number of detections in this mini-batch. For each
instance, the offsets in the first dimension are called LoD; the
number of offsets is N + 1, where N is the batch size. The i-th image
has `LoD[i + 1] - LoD[i]` detected results; if this is 0, the i-th
image has no detected results. If all images have no detected results,
LoD will be set to 0, and the output tensor is empty (None).
Examples:
.. code-block:: python
import paddle.fluid as fluid
bboxes = fluid.layers.data(name='bboxes', shape=[1, 21, 4],
    append_batch_size=False, dtype='float32')
scores = fluid.layers.data(name='scores', shape=[1, 21, 10],
    append_batch_size=False, dtype='float32')
anchors = fluid.layers.data(name='anchors', shape=[21, 4],
    append_batch_size=False, dtype='float32')
im_info = fluid.layers.data(name="im_info", shape=[1, 3],
    append_batch_size=False, dtype='float32')
nmsed_outs = fluid.layers.retinanet_detection_output(
bboxes=[bboxes, bboxes],
scores=[scores, scores],
anchors=[anchors, anchors],
im_info=im_info,
score_threshold=0.05,
nms_top_k=1000,
keep_top_k=100,
nms_threshold=0.3,
nms_eta=1.)
"""
helper = LayerHelper('retinanet_detection_output', **locals())
output = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('scores'))
helper.append_op(
type="retinanet_detection_output",
inputs={
'BBoxes': bboxes,
'Scores': scores,
'Anchors': anchors,
'ImInfo': im_info
},
attrs={
'score_threshold': score_threshold,
'nms_top_k': nms_top_k,
'nms_threshold': nms_threshold,
'keep_top_k': keep_top_k,
'nms_eta': nms_eta,
},
outputs={'Out': output})
output.stop_gradient = True
return output
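Step 1 above decodes each [dx, dy, dw, dh] regression delta against its anchor in center-size space. The numpy sketch below shows the math for a single box; decode_box is an illustrative helper, and the same decoding (with the legacy "+1" pixel width convention) appears in the reference test further down:

import math

def decode_box(anchor, delta):
    # anchor: [xmin, ymin, xmax, ymax]; delta: [dx, dy, dw, dh]
    aw = anchor[2] - anchor[0] + 1.0  # legacy "+1" width convention
    ah = anchor[3] - anchor[1] + 1.0
    acx = anchor[0] + aw / 2.0        # anchor center
    acy = anchor[1] + ah / 2.0
    cx = delta[0] * aw + acx          # shift the center
    cy = delta[1] * ah + acy
    w = math.exp(delta[2]) * aw       # scale the size
    h = math.exp(delta[3]) * ah
    return [cx - w / 2.0, cy - h / 2.0, cx + w / 2.0 - 1.0, cy + h / 2.0 - 1.0]

print(decode_box([0., 0., 7., 7.], [0., 0., 0., 0.]))
# [-0.5, -0.5, 6.5, 6.5]: with a zero delta the "+1" convention shifts
# the corners by half a pixel rather than reproducing the anchor exactly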
def multiclass_nms(bboxes,
                   scores,
                   score_threshold,
......
...@@ -2018,6 +2018,110 @@ class TestBook(LayerTest):
            trans_std=0.1)
        return (out)
def test_retinanet_target_assign(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
bbox_pred = layers.data(
name='bbox_pred',
shape=[1, 100, 4],
append_batch_size=False,
dtype='float32')
cls_logits = layers.data(
name='cls_logits',
shape=[1, 100, 10],
append_batch_size=False,
dtype='float32')
anchor_box = layers.data(
name='anchor_box',
shape=[100, 4],
append_batch_size=False,
dtype='float32')
anchor_var = layers.data(
name='anchor_var',
shape=[100, 4],
append_batch_size=False,
dtype='float32')
gt_boxes = layers.data(
name='gt_boxes',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
gt_labels = layers.data(
name='gt_labels',
shape=[10, 1],
append_batch_size=False,
dtype='float32')
is_crowd = layers.data(
name='is_crowd',
shape=[1],
append_batch_size=False,
dtype='float32')
im_info = layers.data(
name='im_info',
shape=[1, 3],
append_batch_size=False,
dtype='float32')
return (layers.retinanet_target_assign(
bbox_pred, cls_logits, anchor_box, anchor_var, gt_boxes,
gt_labels, is_crowd, im_info, 10))
def test_sigmoid_focal_loss(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='data',
shape=[10, 80],
append_batch_size=False,
dtype='float32')
label = layers.data(
name='label',
shape=[10, 1],
append_batch_size=False,
dtype='int32')
fg_num = layers.data(
name='fg_num',
shape=[1],
append_batch_size=False,
dtype='int32')
out = fluid.layers.sigmoid_focal_loss(
x=input, label=label, fg_num=fg_num, gamma=2., alpha=0.25)
return (out)
def test_retinanet_detection_output(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
bboxes = layers.data(
name='bboxes',
shape=[1, 21, 4],
append_batch_size=False,
dtype='float32')
scores = layers.data(
name='scores',
shape=[1, 21, 10],
append_batch_size=False,
dtype='float32')
anchors = layers.data(
name='anchors',
shape=[21, 4],
append_batch_size=False,
dtype='float32')
im_info = layers.data(
name="im_info",
shape=[1, 3],
append_batch_size=False,
dtype='float32')
nmsed_outs = layers.retinanet_detection_output(
bboxes=[bboxes, bboxes],
scores=[scores, scores],
anchors=[anchors, anchors],
im_info=im_info,
score_threshold=0.05,
nms_top_k=1000,
keep_top_k=100,
nms_threshold=0.3,
nms_eta=1.)
return (nmsed_outs)
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import copy
from op_test import OpTest
from test_anchor_generator_op import anchor_generator_in_python
from test_multiclass_nms_op import iou
from test_multiclass_nms_op import nms
def multiclass_nms(prediction, class_num, keep_top_k, nms_threshold):
selected_indices = {}
num_det = 0
for c in range(class_num):
if c not in prediction.keys():
continue
cls_dets = prediction[c]
all_scores = np.zeros(len(cls_dets))
for i in range(all_scores.shape[0]):
all_scores[i] = cls_dets[i][4]
indices = nms(cls_dets, all_scores, 0.0, nms_threshold, -1, False, 1.0)
selected_indices[c] = indices
num_det += len(indices)
score_index = []
for c, indices in selected_indices.items():
for idx in indices:
score_index.append((prediction[c][idx][4], c, idx))
sorted_score_index = sorted(
score_index, key=lambda tup: tup[0], reverse=True)
if keep_top_k > -1 and num_det > keep_top_k:
sorted_score_index = sorted_score_index[:keep_top_k]
num_det = keep_top_k
nmsed_outs = []
for s, c, idx in sorted_score_index:
xmin = prediction[c][idx][0]
ymin = prediction[c][idx][1]
xmax = prediction[c][idx][2]
ymax = prediction[c][idx][3]
nmsed_outs.append([c + 1, s, xmin, ymin, xmax, ymax])
return nmsed_outs, num_det
def retinanet_detection_out(boxes_list, scores_list, anchors_list, im_info,
score_threshold, nms_threshold, nms_top_k,
keep_top_k):
class_num = scores_list[0].shape[-1]
im_height, im_width, im_scale = im_info
num_level = len(scores_list)
prediction = {}
for lvl in range(num_level):
scores_per_level = scores_list[lvl]
scores_per_level = scores_per_level.flatten()
bboxes_per_level = boxes_list[lvl]
bboxes_per_level = bboxes_per_level.flatten()
anchors_per_level = anchors_list[lvl]
anchors_per_level = anchors_per_level.flatten()
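# The score threshold is applied at every FPN level except the last,
# whose threshold is set to 0.0 so that it always contributes candidates.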
thresh = score_threshold if lvl < (num_level - 1) else 0.0
selected_indices = np.argwhere(scores_per_level > thresh)
scores = scores_per_level[selected_indices]
sorted_indices = np.argsort(-scores, axis=0, kind='mergesort')
if nms_top_k > -1 and nms_top_k < sorted_indices.shape[0]:
sorted_indices = sorted_indices[:nms_top_k]
for i in range(sorted_indices.shape[0]):
idx = selected_indices[sorted_indices[i]]
idx = idx[0][0]
a = int(idx / class_num)
c = int(idx % class_num)
box_offset = a * 4
anchor_box_width = anchors_per_level[
box_offset + 2] - anchors_per_level[box_offset] + 1
anchor_box_height = anchors_per_level[
box_offset + 3] - anchors_per_level[box_offset + 1] + 1
anchor_box_center_x = anchors_per_level[
box_offset] + anchor_box_width / 2
anchor_box_center_y = anchors_per_level[box_offset +
1] + anchor_box_height / 2
target_box_center_x = bboxes_per_level[
box_offset] * anchor_box_width + anchor_box_center_x
target_box_center_y = bboxes_per_level[
box_offset + 1] * anchor_box_height + anchor_box_center_y
target_box_width = math.exp(bboxes_per_level[box_offset +
2]) * anchor_box_width
target_box_height = math.exp(bboxes_per_level[
box_offset + 3]) * anchor_box_height
pred_box_xmin = target_box_center_x - target_box_width / 2
pred_box_ymin = target_box_center_y - target_box_height / 2
pred_box_xmax = target_box_center_x + target_box_width / 2 - 1
pred_box_ymax = target_box_center_y + target_box_height / 2 - 1
pred_box_xmin = pred_box_xmin / im_scale
pred_box_ymin = pred_box_ymin / im_scale
pred_box_xmax = pred_box_xmax / im_scale
pred_box_ymax = pred_box_ymax / im_scale
pred_box_xmin = max(
min(pred_box_xmin, np.round(im_width / im_scale) - 1), 0.)
pred_box_ymin = max(
min(pred_box_ymin, np.round(im_height / im_scale) - 1), 0.)
pred_box_xmax = max(
min(pred_box_xmax, np.round(im_width / im_scale) - 1), 0.)
pred_box_ymax = max(
min(pred_box_ymax, np.round(im_height / im_scale) - 1), 0.)
if c not in prediction.keys():
prediction[c] = []
prediction[c].append([
pred_box_xmin, pred_box_ymin, pred_box_xmax, pred_box_ymax,
scores_per_level[idx]
])
nmsed_outs, nmsed_num = multiclass_nms(prediction, class_num, keep_top_k,
nms_threshold)
return nmsed_outs, nmsed_num
def batched_retinanet_detection_out(boxes, scores, anchors, im_info,
score_threshold, nms_threshold, nms_top_k,
keep_top_k):
batch_size = scores[0].shape[0]
det_outs = []
lod = []
for n in range(batch_size):
boxes_per_batch = []
scores_per_batch = []
num_level = len(scores)
for lvl in range(num_level):
boxes_per_batch.append(boxes[lvl][n])
scores_per_batch.append(scores[lvl][n])
nmsed_outs, nmsed_num = retinanet_detection_out(
boxes_per_batch, scores_per_batch, anchors, im_info[n],
score_threshold, nms_threshold, nms_top_k, keep_top_k)
lod.append(nmsed_num)
if nmsed_num == 0:
continue
det_outs.extend(nmsed_outs)
return det_outs, lod
class TestRetinanetDetectionOutOp1(OpTest):
def set_argument(self):
self.score_threshold = 0.05
self.min_level = 3
self.max_level = 7
self.nms_threshold = 0.3
self.nms_top_k = 1000
self.keep_top_k = 200
self.scales_per_octave = 3
self.aspect_ratios = [1.0, 2.0, 0.5]
self.anchor_scale = 4
self.anchor_strides = [8, 16, 32, 64, 128]
self.box_size = 4
self.class_num = 80
self.batch_size = 1
self.input_channels = 20
self.layer_h = []
self.layer_w = []
num_levels = self.max_level - self.min_level + 1
for i in range(num_levels):
self.layer_h.append(2**(num_levels - i))
self.layer_w.append(2**(num_levels - i))
def init_test_input(self):
anchor_num = len(self.aspect_ratios) * self.scales_per_octave
num_levels = self.max_level - self.min_level + 1
self.scores_list = []
self.bboxes_list = []
self.anchors_list = []
for i in range(num_levels):
layer_h = self.layer_h[i]
layer_w = self.layer_w[i]
input_feat = np.random.random((self.batch_size, self.input_channels,
layer_h, layer_w)).astype('float32')
score = np.random.random(
(self.batch_size, self.class_num * anchor_num, layer_h,
layer_w)).astype('float32')
score = np.transpose(score, [0, 2, 3, 1])
score = score.reshape((self.batch_size, -1, self.class_num))
box = np.random.random((self.batch_size, self.box_size * anchor_num,
layer_h, layer_w)).astype('float32')
box = np.transpose(box, [0, 2, 3, 1])
box = box.reshape((self.batch_size, -1, self.box_size))
anchor_sizes = []
for octave in range(self.scales_per_octave):
anchor_sizes.append(
float(self.anchor_strides[i] * (2**octave)) /
float(self.scales_per_octave) * self.anchor_scale)
anchor, var = anchor_generator_in_python(
input_feat=input_feat,
anchor_sizes=anchor_sizes,
aspect_ratios=self.aspect_ratios,
variances=[1.0, 1.0, 1.0, 1.0],
stride=[self.anchor_strides[i], self.anchor_strides[i]],
offset=0.5)
anchor = np.reshape(anchor, [-1, 4])
self.scores_list.append(score.astype('float32'))
self.bboxes_list.append(box.astype('float32'))
self.anchors_list.append(anchor.astype('float32'))
self.im_info = np.array([[256., 256., 1.5]]).astype(
    'float32')  # im_height, im_width, im_scale
def setUp(self):
self.set_argument()
self.init_test_input()
nmsed_outs, lod = batched_retinanet_detection_out(
self.bboxes_list, self.scores_list, self.anchors_list, self.im_info,
self.score_threshold, self.nms_threshold, self.nms_top_k,
self.keep_top_k)
nmsed_outs = np.array(nmsed_outs).astype('float32')
self.op_type = 'retinanet_detection_output'
self.inputs = {
'BBoxes': [('b0', self.bboxes_list[0]), ('b1', self.bboxes_list[1]),
('b2', self.bboxes_list[2]), ('b3', self.bboxes_list[3]),
('b4', self.bboxes_list[4])],
'Scores': [('s0', self.scores_list[0]), ('s1', self.scores_list[1]),
('s2', self.scores_list[2]), ('s3', self.scores_list[3]),
('s4', self.scores_list[4])],
'Anchors':
[('a0', self.anchors_list[0]), ('a1', self.anchors_list[1]),
('a2', self.anchors_list[2]), ('a3', self.anchors_list[3]),
('a4', self.anchors_list[4])],
'ImInfo': (self.im_info, [[1, ]])
}
self.outputs = {'Out': (nmsed_outs, [lod])}
self.attrs = {
'score_threshold': self.score_threshold,
'nms_top_k': self.nms_top_k,
'nms_threshold': self.nms_threshold,
'keep_top_k': self.keep_top_k,
'nms_eta': 1.,
}
def test_check_output(self):
self.check_output()
class TestRetinanetDetectionOutOp2(OpTest):
def set_argument(self):
self.score_threshold = 0.05
self.min_level = 3
self.max_level = 7
self.nms_threshold = 0.3
self.nms_top_k = 1000
self.keep_top_k = 200
self.scales_per_octave = 3
self.aspect_ratios = [1.0, 2.0, 0.5]
self.anchor_scale = 4
self.anchor_strides = [8, 16, 32, 64, 128]
self.box_size = 4
self.class_num = 80
self.batch_size = 1
self.input_channels = 20
# Here we test the case where the spatial shape of each FPN level
# is irregular.
self.layer_h = [1, 4, 8, 8, 16]
self.layer_w = [1, 4, 8, 8, 16]
class TestRetinanetDetectionOutOpNo3(TestRetinanetDetectionOutOp1):
def set_argument(self):
# Set score_threshold to 2.0 to test the case where there are no outputs.
# In practical use, 0.0 < score_threshold < 1.0.
self.score_threshold = 2.0
self.min_level = 3
self.max_level = 7
self.nms_threshold = 0.3
self.nms_top_k = 1000
self.keep_top_k = 200
self.scales_per_octave = 3
self.aspect_ratios = [1.0, 2.0, 0.5]
self.anchor_scale = 4
self.anchor_strides = [8, 16, 32, 64, 128]
self.box_size = 4
self.class_num = 80
self.batch_size = 1
self.input_channels = 20
self.layer_h = []
self.layer_w = []
num_levels = self.max_level - self.min_level + 1
for i in range(num_levels):
self.layer_h.append(2**(num_levels - i))
self.layer_w.append(2**(num_levels - i))
class TestRetinanetDetectionOutOpNo4(TestRetinanetDetectionOutOp1):
def set_argument(self):
self.score_threshold = 0.05
self.min_level = 2
self.max_level = 5
self.nms_threshold = 0.3
self.nms_top_k = 1000
self.keep_top_k = 200
self.scales_per_octave = 3
self.aspect_ratios = [1.0, 2.0, 0.5]
self.anchor_scale = 4
self.anchor_strides = [8, 16, 32, 64, 128]
self.box_size = 4
self.class_num = 80
self.batch_size = 1
self.input_channels = 20
self.layer_h = []
self.layer_w = []
num_levels = self.max_level - self.min_level + 1
for i in range(num_levels):
self.layer_h.append(2**(num_levels - i))
self.layer_w.append(2**(num_levels - i))
def setUp(self):
self.set_argument()
self.init_test_input()
nmsed_outs, lod = batched_retinanet_detection_out(
self.bboxes_list, self.scores_list, self.anchors_list, self.im_info,
self.score_threshold, self.nms_threshold, self.nms_top_k,
self.keep_top_k)
nmsed_outs = np.array(nmsed_outs).astype('float32')
self.op_type = 'retinanet_detection_output'
self.inputs = {
'BBoxes':
[('b0', self.bboxes_list[0]), ('b1', self.bboxes_list[1]),
('b2', self.bboxes_list[2]), ('b3', self.bboxes_list[3])],
'Scores': [('s0', self.scores_list[0]), ('s1', self.scores_list[1]),
('s2', self.scores_list[2]),
('s3', self.scores_list[3])],
'Anchors':
[('a0', self.anchors_list[0]), ('a1', self.anchors_list[1]),
('a2', self.anchors_list[2]), ('a3', self.anchors_list[3])],
'ImInfo': (self.im_info, [[1, ]])
}
self.outputs = {'Out': (nmsed_outs, [lod])}
self.attrs = {
'score_threshold': self.score_threshold,
'nms_top_k': self.nms_top_k,
'nms_threshold': self.nms_threshold,
'keep_top_k': self.keep_top_k,
'nms_eta': 1.,
}
def test_check_output(self):
self.check_output()
class TestRetinanetDetectionOutOpNo5(TestRetinanetDetectionOutOp1):
def set_argument(self):
self.score_threshold = 0.05
self.min_level = 3
self.max_level = 7
self.nms_threshold = 0.3
self.nms_top_k = 100
self.keep_top_k = 10
self.scales_per_octave = 3
self.aspect_ratios = [1.0, 2.0, 0.5]
self.anchor_scale = 4
self.anchor_strides = [8, 16, 32, 64, 128]
self.box_size = 4
self.class_num = 80
self.batch_size = 1
self.input_channels = 20
self.layer_h = []
self.layer_w = []
num_levels = self.max_level - self.min_level + 1
for i in range(num_levels):
self.layer_h.append(2**(num_levels - i))
self.layer_w.append(2**(num_levels - i))
if __name__ == '__main__':
unittest.main()
...@@ -167,6 +167,105 @@ def rpn_target_assign_in_python(all_anchors,
    return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights
def retinanet_target_assign(anchor_by_gt_overlap, gt_labels, positive_overlap,
negative_overlap):
    anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(axis=1)
    anchor_to_gt_max = anchor_by_gt_overlap[np.arange(
        anchor_by_gt_overlap.shape[0]), anchor_to_gt_argmax]

    gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(axis=0)
    gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, np.arange(
        anchor_by_gt_overlap.shape[1])]
    anchors_with_max_overlap = np.where(
        anchor_by_gt_overlap == gt_to_anchor_max)[0]

    labels = np.ones((anchor_by_gt_overlap.shape[0], ), dtype=np.int32) * -1
    labels[anchors_with_max_overlap] = 1
    labels[anchor_to_gt_max >= positive_overlap] = 1

    fg_inds = np.where(labels == 1)[0]
    bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32)
    bg_inds = np.where(anchor_to_gt_max < negative_overlap)[0]
    enable_inds = bg_inds

    fg_fake_inds = np.array([], np.int32)
    fg_value = np.array([fg_inds[0]], np.int32)
    fake_num = 0
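    # Background anchors that also appear in fg_inds are recorded as "fake"
    # foreground entries; their bbox_inside_weight rows stay zero so they do
    # not contribute to the localization loss.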
    for bg_id in enable_inds:
        if bg_id in fg_inds:
            fake_num += 1
            fg_fake_inds = np.hstack([fg_fake_inds, fg_value])
    labels[enable_inds] = 0

    bbox_inside_weight[fake_num:, :] = 1
    fg_inds = np.where(labels == 1)[0]
    bg_inds = np.where(labels == 0)[0]
    loc_index = np.hstack([fg_fake_inds, fg_inds])
    score_index = np.hstack([fg_inds, bg_inds])
    score_index_tmp = np.hstack([fg_inds])
    labels = labels[score_index]

    gt_inds = anchor_to_gt_argmax[loc_index]
    label_inds = anchor_to_gt_argmax[score_index_tmp]
    labels[0:len(fg_inds)] = np.squeeze(gt_labels[label_inds])
    fg_num = len(fg_fake_inds) + len(fg_inds) + 1
    assert not np.any(labels == -1), "Wrong labels with -1"
    return loc_index, score_index, labels, gt_inds, bbox_inside_weight, fg_num

def retinanet_target_assign_in_python(all_anchors, gt_boxes, gt_labels,
                                      is_crowd, im_info, lod, positive_overlap,
                                      negative_overlap):
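    # Assign targets image by image; `lod` holds the per-image offsets into
    # gt_boxes/gt_labels, and per-image anchor indices are shifted by
    # i * anchor_num before concatenation.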
    anchor_num = all_anchors.shape[0]
    batch_size = len(lod) - 1
    for i in range(batch_size):
        im_scale = im_info[i][2]

        inds_inside = np.arange(all_anchors.shape[0])
        inside_anchors = all_anchors
        b, e = lod[i], lod[i + 1]
        gt_boxes_slice = gt_boxes[b:e, :] * im_scale
        gt_labels_slice = gt_labels[b:e, :]
        is_crowd_slice = is_crowd[b:e]

        not_crowd_inds = np.where(is_crowd_slice == 0)[0]
        gt_boxes_slice = gt_boxes_slice[not_crowd_inds]
        gt_labels_slice = gt_labels_slice[not_crowd_inds]
        iou = _bbox_overlaps(inside_anchors, gt_boxes_slice)

        loc_inds, score_inds, labels, gt_inds, bbox_inside_weight, fg_num = \
            retinanet_target_assign(iou, gt_labels_slice,
                                    positive_overlap, negative_overlap)
        # unmap to all anchors
        loc_inds = inds_inside[loc_inds]
        score_inds = inds_inside[score_inds]

        sampled_gt = gt_boxes_slice[gt_inds]
        sampled_anchor = all_anchors[loc_inds]
        box_deltas = _box_to_delta(sampled_anchor, sampled_gt,
                                   [1., 1., 1., 1.])

        if i == 0:
            loc_indexes = loc_inds
            score_indexes = score_inds
            tgt_labels = labels
            tgt_bboxes = box_deltas
            bbox_inside_weights = bbox_inside_weight
            fg_nums = [[fg_num]]
        else:
            loc_indexes = np.concatenate(
                [loc_indexes, loc_inds + i * anchor_num])
            score_indexes = np.concatenate(
                [score_indexes, score_inds + i * anchor_num])
            tgt_labels = np.concatenate([tgt_labels, labels])
            tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])
            bbox_inside_weights = np.vstack(
                [bbox_inside_weights, bbox_inside_weight])
            fg_nums = np.concatenate([fg_nums, [[fg_num]]])

    return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights, fg_nums

class TestRpnTargetAssignOp(OpTest):
    def setUp(self):
        n, c, h, w = 2, 4, 14, 14
...@@ -234,5 +333,65 @@ class TestRpnTargetAssignOp(OpTest):
        self.check_output()

class TestRetinanetTargetAssignOp(OpTest):
    def setUp(self):
        n, c, h, w = 2, 4, 14, 14
        all_anchors = get_anchor(n, c, h, w)
        gt_num = 10
        all_anchors = all_anchors.reshape(-1, 4)
        anchor_num = all_anchors.shape[0]

        images_shape = [[64, 64], [64, 64]]
        groundtruth, lod = _generate_groundtruth(images_shape, 3, 4)
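        # Fix the LoD offsets: image 0 owns gt rows [0, 4), image 1 owns [4, 8).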
        lod = [0, 4, 8]

        im_info = np.ones((len(images_shape), 3)).astype(np.float32)
        for i in range(len(images_shape)):
            im_info[i, 0] = images_shape[i][0]
            im_info[i, 1] = images_shape[i][1]
            im_info[i, 2] = 0.8  # scale
        gt_boxes = np.vstack([v['boxes'] for v in groundtruth])
        is_crowd = np.hstack([v['is_crowd'] for v in groundtruth])
        gt_labels = np.vstack([
            v['gt_classes'].reshape(len(v['gt_classes']), 1)
            for v in groundtruth
        ])
        gt_labels = gt_labels.reshape(len(gt_labels), 1)
        all_anchors = all_anchors.astype('float32')
        gt_boxes = gt_boxes.astype('float32')
        gt_labels = gt_labels.astype('int32')

        positive_overlap = 0.5
        negative_overlap = 0.4

        loc_index, score_index, tgt_bbox, labels, bbox_inside_weights, fg_num = \
            retinanet_target_assign_in_python(all_anchors, gt_boxes, gt_labels,
                                              is_crowd, im_info, lod,
                                              positive_overlap,
                                              negative_overlap)
        labels = labels[:, np.newaxis]

        self.op_type = "retinanet_target_assign"
        self.inputs = {
            'Anchor': all_anchors,
            'GtBoxes': (gt_boxes, [[4, 4]]),
            'GtLabels': (gt_labels, [[4, 4]]),
            'IsCrowd': (is_crowd, [[4, 4]]),
            'ImInfo': (im_info, [[1, 1]])
        }
        self.attrs = {
            'positive_overlap': positive_overlap,
            'negative_overlap': negative_overlap
        }
        self.outputs = {
            'LocationIndex': loc_index.astype('int32'),
            'ScoreIndex': score_index.astype('int32'),
            'TargetBBox': tgt_bbox.astype('float32'),
            'TargetLabel': labels.astype('int32'),
            'BBoxInsideWeight': bbox_inside_weights.astype('float32'),
            'ForegroundNumber': fg_num.astype('int32')
        }

    def test_check_output(self):
        self.check_output()

if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import copy
from op_test import OpTest
from paddle.fluid import core

def sigmoid_focal_loss_forward(x_data, label_data, fg_num_data, gamma, alpha,
                               num_classes):
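    # Naive reference for sigmoid focal loss (Lin et al., "Focal Loss for
    # Dense Object Detection"): for the target class the loss is
    # -alpha * (1 - p)^gamma * log(p); for each other class it is
    # -(1 - alpha) * p^gamma * log(1 - p); both are normalized by the number
    # of foreground examples. Label 0 means background, label -1 is ignored.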
    x_data_t = copy.deepcopy(x_data)
    out_data = copy.deepcopy(x_data)
    x_width = len(x_data)
    x_height = len(x_data[0, :])
    x_data_t = x_data_t.flatten()
    out_data = out_data.flatten()
    for idx in range(len(x_data_t)):
        x = x_data_t[idx]
        a = int(idx / num_classes)
        d = int(idx % num_classes)
        label = label_data[a]
        c_pos = float((int(label) == int(d + 1)))
        c_neg = float(((int(label) != -1) & (int(label) != (d + 1))))
        fg_num = max(fg_num_data, 1)
        z_neg = (1.0 - alpha) / fg_num
        z_pos = alpha / fg_num

        p = 1. / (1. + math.exp(-x))
        FLT_MIN = 1.175494351e-38
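        # log(p) is clamped at FLT_MIN to avoid log(0); log(1 - p) is computed
        # in the numerically stable form -max(x, 0) - log(1 + exp(-|x|)).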
        term_pos = math.pow((1. - p), gamma) * math.log(max(FLT_MIN, p))
        term_neg = math.pow(p, gamma) * (
            -1. * x * (x >= 0) -
            math.log(1. + math.exp(x - 2. * x * (x >= 0))))
        out_data[idx] = 0.0
        out_data[idx] += -c_pos * term_pos * z_pos
        out_data[idx] += -c_neg * term_neg * z_neg

    out_data = out_data.reshape(x_width, x_height)
    return out_data

class TestSigmoidFocalLossOp1(OpTest):
    def set_argument(self):
        self.num_anchors = 10
        self.num_classes = 10
        self.gamma = 2.0
        self.alpha = 0.25

    def setUp(self):
        self.set_argument()

        dims = (self.num_anchors, self.num_classes)
        X = np.random.standard_normal(dims).astype("float32")
        L = np.random.randint(0, self.num_classes + 1,
                              (dims[0], 1)).astype("int32")
        F = np.zeros(1)
        F[0] = len(np.where(L > 0)[0])
        F = F.astype("int32")

        self.op_type = "sigmoid_focal_loss"
        self.inputs = {
            'X': X,
            'Label': L,
            'FgNum': F,
        }
        self.attrs = {
            'gamma': self.gamma,
            'alpha': self.alpha,
        }
        loss = sigmoid_focal_loss_forward(
            self.inputs['X'], self.inputs['Label'], self.inputs['FgNum'],
            self.gamma, self.alpha, self.num_classes)
        self.outputs = {'Out': loss.astype('float32')}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSigmoidFocalLossOp2(TestSigmoidFocalLossOp1):
    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, atol=2e-3)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(
            place, ['X'], 'Out', max_relative_error=0.002)

class TestSigmoidFocalLossOp3(TestSigmoidFocalLossOp1):
    def set_argument(self):
        self.num_anchors = 200
        self.num_classes = 10
        self.gamma = 1.0
        self.alpha = 0.5

@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSigmoidFocalLossOp4(TestSigmoidFocalLossOp3):
    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, atol=2e-3)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(
            place, ['X'], 'Out', max_relative_error=0.002)

if __name__ == '__main__':
    unittest.main()