未验证 提交 0aee1f00 编写于 作者: F FlyingQianMM 提交者: GitHub

add sigmoid focal loss operator for supporting retinanet (#17895)

* test=develop
add sigmoid_focal_loss for supporting retinanet

* test=develop
add test_layers

* test=develop
add API.spc

* test=develop
alter sigmoid_focal_loss_op.cc

* test=develop
alter detection.py

* test=develop
alter API.spec

* test=develop
alter round 1

* test=develop
alter simooid_focal_loss

* test=develop
alter sigmoid_focal_loss_op.cc

* test=develop
alter test_layers.py

* test=develop
alter paddle/fluid/API.spec

* test=develop
alter sigmoid_focal_loss_op.cu

* test=develop
alter paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc
上级 9e4b9d97
......@@ -349,6 +349,7 @@ paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box'
paddle.fluid.layers.ssd_loss (ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)), ('document', '6d5028fd09d01ab82d296adc0ea95aee'))
paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '1e164a56fe9376e18a56d22563d9f801'))
paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
paddle.fluid.layers.sigmoid_focal_loss (ArgSpec(args=['x', 'label', 'fg_num', 'gamma', 'alpha'], varargs=None, keywords=None, defaults=(2, 0.25)), ('document', 'aeac6aae100173b3fc7f102cf3023a3d'))
paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a'))
paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc'))
paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'c0d00acf724691ff3480d4207036a722'))
......
......@@ -35,6 +35,7 @@ detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class SigmoidFocalLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label");
auto fg_dims = ctx->GetInputDim("FgNum");
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"The last dimension of input(Label) should be 1.");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
class SigmoidFocalLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label");
auto fg_dims = ctx->GetInputDim("FgNum");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape.");
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"The last dimension of input(Label) should be 1.");
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, 0, rank),
framework::slice_ddim(dout_dims, 0, rank),
"Input(X) and Input(Out@Grad) shall have the same shape.");
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
class SigmoidFocalLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N, D], "
"where N is the batch size and D is the number of classes "
"(excluding background). This input is a tensor of logits "
"computed by the previous operator.");
AddInput("Label",
"(Tensor, default Tensor<int>), a 2-D tensor with shape [N, 1]. "
"This input is a tensor of probabilistic labels.");
AddInput("FgNum",
"(Tensor, default Tensor<int>), a 1-D tensor with shape [1]. "
"This input is the number of foreground.");
AddOutput(
"Out",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N, D]. "
"This output is the focal loss.");
AddAttr<float>(
"gamma",
"Hyper-parameter of sigmoid focal loss op, which is to balance the "
"easy and hard examples. "
"A float scalar with default value 2.0.")
.SetDefault(2.0);
AddAttr<float>(
"alpha",
"Hyper-parameter of sigmoid focal loss op, which is to balance the "
"positive and negative examples. "
"A float scalar with default value 0.5.")
.SetDefault(0.25);
AddComment(R"DOC(
Sigmoid Focal Loss Operator.
Focal loss is used to address the foreground-background class imbalance existed
on the training phase of one-stage detectors. This operator computes the sigmoid
value for each element in the input tensor, after which focal loss is measured.
The focal loss is given as follows:
$$Loss_j = (-Label_j * alpha * \pow(1 - \sigma(X_j), gamma) * \log(\sigma(X_j)) -
(1 - Labels_j) * (1 - alpha) * \pow(\sigma(X_j), gamma) * \log(1 - \sigma(X_j)))
/ FgNum, j = 1,...,K$$
We know that $$\sigma(X_j) = \\frac{1}{1 + \exp(-X_j)}$$.
)DOC");
}
};
class SigmoidFocalLossGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sigmoid_focal_loss_grad");
op->SetInput("X", Input("X"));
op->SetInput("Label", Input("Label"));
op->SetInput("FgNum", Input("FgNum"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sigmoid_focal_loss, ops::SigmoidFocalLossOp,
ops::SigmoidFocalLossOpMaker,
ops::SigmoidFocalLossGradOpDescMaker);
REGISTER_OPERATOR(sigmoid_focal_loss_grad, ops::SigmoidFocalLossGradOp);
REGISTER_OP_CPU_KERNEL(
sigmoid_focal_loss,
ops::SigmoidFocalLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::SigmoidFocalLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
sigmoid_focal_loss_grad,
ops::SigmoidFocalLossGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SigmoidFocalLossGradKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void GPUSigmoidFocalLossForward(const T *x_data,
const int *label_data,
const int *fg_num_data,
const T gamma, const T alpha,
const int num_classes,
const int limit, T *out_data) {
CUDA_1D_KERNEL_LOOP(i, limit) {
T x = x_data[i];
int a = i / num_classes; // current sample
int d = i % num_classes; // current class
int g = label_data[a]; // target
// check whether the input data is positive or negative
// the target classes are in range 1-81
// and the d is in range 0-80
T c_pos = static_cast<T>(g == (d + 1));
T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
T s_neg = (1.0 - alpha) / fg_num;
T s_pos = alpha / fg_num;
// p = 1. / 1. + expf(-x)
T p = 1. / (1. + real_exp(-x));
// (1 - p)**gamma * log(p)
T term_pos =
std::pow((1. - p), gamma) * real_log(p > FLT_MIN ? p : FLT_MIN);
// p**gamma * log(1 - p)
T term_neg =
std::pow(p, gamma) *
(-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0))));
out_data[i] = 0.0;
out_data[i] += -c_pos * term_pos * s_pos;
out_data[i] += -c_neg * term_neg * s_neg;
}
}
template <typename T>
__global__ void GPUSigmoidFocalLossBackward(
const T *x_data, const int *label_data, const int *fg_num_data,
const T gamma, const T alpha, const int num_classes, const T *dout_data,
const int limit, T *dx_data) {
CUDA_1D_KERNEL_LOOP(i, limit) {
T x = x_data[i];
T dout = dout_data[i];
int a = i / num_classes; // current sample
int d = i % num_classes; // current class
T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
T s_neg = (1.0 - alpha) / fg_num;
T s_pos = alpha / fg_num;
int g = label_data[a];
T c_pos = static_cast<T>(g == (d + 1));
T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
T p = 1. / (1. + real_exp(-x));
// (1-p)**g * (1 - p - g*p*log(p))
T term_pos = std::pow((1. - p), gamma) *
(1. - p - (p * gamma * real_log(p > FLT_MIN ? p : FLT_MIN)));
// (p**g) * (g*(1-p)*log(1-p) - p)
T term_neg =
std::pow(p, gamma) *
((-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0)))) *
(1. - p) * gamma -
p);
dx_data[i] = 0.0;
dx_data[i] += -c_pos * s_pos * term_pos;
dx_data[i] += -c_neg * s_neg * term_neg;
dx_data[i] = dx_data[i] * dout;
}
}
template <typename DeviceContext, typename T>
class GPUSigmoidFocalLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *FgNum = context.Input<Tensor>("FgNum");
Tensor *Out = context.Output<Tensor>("Out");
T gamma = static_cast<T>(context.Attr<float>("gamma"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto x_dims = X->dims();
int num_classes = static_cast<int>(x_dims[1]);
auto out_data = Out->mutable_data<T>(context.GetPlace());
auto &dev_ctx = context.cuda_device_context();
int limit = Out->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
GPUSigmoidFocalLossForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
X->data<T>(), Labels->data<int>(), FgNum->data<int>(), gamma, alpha,
num_classes, limit, out_data);
}
};
template <typename DeviceContext, typename T>
class GPUSigmoidFocalLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *FgNum = context.Input<Tensor>("FgNum");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dX->mutable_data<T>(context.GetPlace());
T gamma = static_cast<T>(context.Attr<float>("gamma"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto x_dims = X->dims();
int num_classes = static_cast<int>(x_dims[1]);
auto &dev_ctx = context.cuda_device_context();
int limit = dX->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
GPUSigmoidFocalLossBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
X->data<T>(), Labels->data<int>(), FgNum->data<int>(), gamma, alpha,
num_classes, dOut->data<T>(), limit, dx_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sigmoid_focal_loss,
ops::GPUSigmoidFocalLossKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUSigmoidFocalLossKernel<paddle::platform::CUDADeviceContext,
double>);
REGISTER_OP_CUDA_KERNEL(
sigmoid_focal_loss_grad,
ops::GPUSigmoidFocalLossGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::GPUSigmoidFocalLossGradKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <limits>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class SigmoidFocalLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *FgNum = context.Input<Tensor>("FgNum");
Tensor *Out = context.Output<Tensor>("Out");
T gamma = static_cast<T>(context.Attr<float>("gamma"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto out_data = Out->mutable_data<T>(context.GetPlace());
int limit = Out->numel();
auto x_data = X->data<T>();
auto label_data = Labels->data<int>();
auto fg_num_data = FgNum->data<int>();
auto x_dims = X->dims();
int num_classes = static_cast<int>(x_dims[1]);
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
int a = idx / num_classes; // current sample
int d = idx % num_classes; // current class
int g = label_data[a]; // target
// Check whether the input data is positive or negative
// The target classes are in range 1-81
// and the d is in range 0-80
T c_pos = static_cast<T>(g == (d + 1));
T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
T s_neg = (1.0 - alpha) / fg_num;
T s_pos = alpha / fg_num;
// p = 1. / 1. + expf(-x)
T p = 1. / (1. + std::exp(-x));
// (1 - p)**gamma * log(p) where
T term_pos =
std::pow((1. - p), gamma) * std::log(p > FLT_MIN ? p : FLT_MIN);
// p**gamma * log(1 - p)
float term_neg =
std::pow(p, gamma) *
(-1. * x * (x >= 0) - std::log(1. + std::exp(x - 2. * x * (x >= 0))));
out_data[idx] = 0.0;
out_data[idx] += -c_pos * term_pos * s_pos;
out_data[idx] += -c_neg * term_neg * s_neg;
}
}
};
template <typename DeviceContext, typename T>
class SigmoidFocalLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *FgNum = context.Input<Tensor>("FgNum");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dX->mutable_data<T>(context.GetPlace());
T gamma = static_cast<T>(context.Attr<float>("gamma"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto x_dims = X->dims();
int num_classes = static_cast<int>(x_dims[1]);
int limit = dX->numel();
auto x_data = X->data<T>();
auto label_data = Labels->data<int>();
auto fg_num_data = FgNum->data<int>();
auto dout_data = dOut->data<T>();
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
int a = idx / num_classes; // current sample
int d = idx % num_classes; // current class
T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
T s_neg = static_cast<T>((1.0 - alpha) / fg_num);
T s_pos = alpha / fg_num;
int g = label_data[a];
T c_pos = static_cast<T>(g == (d + 1));
T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
T p = 1. / (1. + std::exp(-x));
// (1-p)**g * (1 - p - g*p*log(p))
T term_pos = std::pow((1. - p), gamma) *
(1. - p - (p * gamma * std::log(p > FLT_MIN ? p : FLT_MIN)));
// (p**g) * (g*(1-p)*log(1-p) - p)
T term_neg = std::pow(p, gamma) *
((-1. * x * (x >= 0) -
std::log(1. + std::exp(x - 2. * x * (x >= 0)))) *
(1. - p) * gamma -
p);
dx_data[idx] = 0.0;
dx_data[idx] += -c_pos * s_pos * term_pos;
dx_data[idx] += -c_neg * s_neg * term_neg;
dx_data[idx] = dx_data[idx] * dout_data[idx];
}
}
};
} // namespace operators
} // namespace paddle
......@@ -40,6 +40,7 @@ __all__ = [
'ssd_loss',
'rpn_target_assign',
'retinanet_target_assign',
'sigmoid_focal_loss',
'anchor_generator',
'roi_perspective_transform',
'generate_proposal_labels',
......@@ -368,6 +369,74 @@ def rpn_target_assign(bbox_pred,
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight
def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25):
"""
**Sigmoid Focal Loss Operator.**
Focal loss is used to address the foreground-background class imbalance existed
on the training phase of one-stage detectors. This operator computes the sigmoid
value for each element in the input tensor, after which focal loss is measured.
The focal loss is given as followed:
.. math::
loss_j = (-label_j * alpha * {(1 - \\sigma(x_j))}^{gamma} * \\log(\\sigma(x_j)) -
(1 - labels_j) * (1 - alpha) * {(\sigma(x_j)}^{ gamma} * \\log(1 - \\sigma(x_j)))
/ fg\_num, j = 1,...,K
We know that
.. math::
\\sigma(x_j) = \\frac{1}{1 + \\exp(-x_j)}
Args:
x(Variable): A 2-D tensor with shape [N, D], where N is the batch size and D is the number
of classes (excluding background). This input is a tensor of logits computed by the
previous operator.
label(Variable): A 2-D tensor with shape [N, 1], which is the probabilistic labels.
fg_num(Variable): A 1-D tensor with shape [1], which is the number of foreground.
gamma(float): Hyper-parameter to balance the easy and hard examples. Default value is
set to 2.0.
alpha(float): Hyper-parameter to balance the positive and negative example. Default value
is set to 0.25.
Returns:
out(Variable): A 2-D tensor with shape [N, D], which is the focal loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(
name='data', shape=[10,80], append_batch_size=False, dtype='float32')
label = fluid.layers.data(
name='label', shape=[10,1], append_batch_size=False, dtype='int32')
fg_num = fluid.layers.data(
name='fg_num', shape=[1], append_batch_size=False, dtype='int32')
loss = fluid.layers.sigmoid_focal_loss(x=input,
label=label,
fg_num=fg_num,
gamma=2.,
alpha=0.25)
"""
helper = LayerHelper("sigmoid_focal_loss", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="sigmoid_focal_loss",
inputs={"X": x,
"Label": label,
"FgNum": fg_num},
attrs={"gamma": gamma,
'alpha': alpha},
outputs={"Out": out})
return out
def detection_output(loc,
scores,
prior_box,
......
......@@ -2071,6 +2071,28 @@ class TestBook(LayerTest):
bbox_pred, cls_logits, anchor_box, anchor_var, gt_boxes,
gt_labels, is_crowd, im_info, 10))
def test_sigmoid_focal_loss(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='data',
shape=[10, 80],
append_batch_size=False,
dtype='float32')
label = layers.data(
name='label',
shape=[10, 1],
append_batch_size=False,
dtype='int32')
fg_num = layers.data(
name='fg_num',
shape=[1],
append_batch_size=False,
dtype='int32')
out = fluid.layers.sigmoid_focal_loss(
x=input, label=label, fg_num=fg_num, gamma=2., alpha=0.25)
return (out)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import copy
from op_test import OpTest
from paddle.fluid import core
def sigmoid_focal_loss_forward(x_data, label_data, fg_num_data, gamma, alpha,
num_classes):
x_data_t = copy.deepcopy(x_data)
out_data = copy.deepcopy(x_data)
x_width = len(x_data)
x_height = len(x_data[0, :])
x_data_t = x_data_t.flatten()
out_data = out_data.flatten()
for idx in range(len(x_data_t)):
x = x_data_t[idx]
a = int(idx / num_classes)
d = int(idx % num_classes)
label = label_data[a]
c_pos = float((int(label) == int(d + 1)))
c_neg = float(((int(label) != -1) & (int(label) != (d + 1))))
fg_num = max(fg_num_data, 1)
z_neg = (1.0 - alpha) / fg_num
z_pos = alpha / fg_num
p = 1. / (1. + math.exp(-x))
FLT_MIN = 1.175494351e-38
term_pos = math.pow((1. - p), gamma) * math.log(max(FLT_MIN, p))
term_neg = math.pow(p, gamma) * (
-1. * x * (x >= 0) - math.log(1. + math.exp(x - 2. * x * (x >= 0))))
out_data[idx] = 0.0
out_data[idx] += -c_pos * term_pos * z_pos
out_data[idx] += -c_neg * term_neg * z_neg
out_data = out_data.reshape(x_width, x_height)
return out_data
class TestSigmoidFocalLossOp1(OpTest):
def set_argument(self):
self.num_anchors = 10
self.num_classes = 10
self.gamma = 2.0
self.alpha = 0.25
def setUp(self):
self.set_argument()
dims = (self.num_anchors, self.num_classes)
X = np.random.standard_normal(dims).astype("float32")
L = np.random.randint(0, self.num_classes + 1,
(dims[0], 1)).astype("int32")
F = np.zeros(1)
F[0] = len(np.where(L > 0)[0])
F = F.astype("int32")
self.op_type = "sigmoid_focal_loss"
self.inputs = {
'X': X,
'Label': L,
'FgNum': F,
}
self.attrs = {
'gamma': self.gamma,
'alpha': self.alpha,
}
loss = sigmoid_focal_loss_forward(
self.inputs['X'], self.inputs['Label'], self.inputs['FgNum'],
self.gamma, self.alpha, self.num_classes)
self.outputs = {'Out': loss.astype('float32')}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSigmoidFocalLossOp2(TestSigmoidFocalLossOp1):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=2e-3)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=0.002)
class TestSigmoidFocalLossOp3(TestSigmoidFocalLossOp1):
def set_argument(self):
self.num_anchors = 200
self.num_classes = 10
self.gamma = 1.0
self.alpha = 0.5
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSigmoidFocalLossOp4(TestSigmoidFocalLossOp3):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=2e-3)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=0.002)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册