Unverified commit 92b9ce34, authored by Xin Pan, committed by GitHub

Merge pull request #16073 from heavengate/yolov3_loss_imporve

Yolov3 loss: add mixup score and label smooth
@@ -330,7 +330,7 @@ paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes',
paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '587845f60c5d97ffdf2dfd21da52eca1'))
paddle.fluid.layers.box_coder (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)), ('document', '032d0f4b7d8f6235ee5d91e473344f0e'))
paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0e5ac2507723a0b5adec473f9556799b'))
-paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '991e934c3e09abf0edec7c9c978b4691'))
+paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gtscore', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '57fa96922e42db8f064c3fb77f2255e8'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
......
@@ -10,6 +10,7 @@
limitations under the License. */
#include "paddle/fluid/operators/detection/yolov3_loss_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
@@ -72,6 +73,18 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GT(class_num, 0,
"Attr(class_num) should be an integer greater then 0.");
if (ctx->HasInput("GTScore")) {
auto dim_gtscore = ctx->GetInputDim("GTScore");
PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2,
"Input(GTScore) should be a 2-D tensor");
PADDLE_ENFORCE_EQ(
dim_gtscore[0], dim_gtbox[0],
"Input(GTBox) and Input(GTScore) dim[0] should be same");
PADDLE_ENFORCE_EQ(
dim_gtscore[1], dim_gtbox[1],
"Input(GTBox) and Input(GTScore) dim[1] should be same");
}
std::vector<int64_t> dim_out({dim_x[0]});
ctx->SetOutputDim("Loss", framework::make_ddim(dim_out));
@@ -112,6 +125,12 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"This is a 2-D tensor with shape of [N, max_box_num], "
"and each element should be an integer to indicate the "
"box class id.");
AddInput("GTScore",
"The score of GTLabel, This is a 2-D tensor in same shape "
"GTLabel, and score values should in range (0, 1). This "
"input is for GTLabel score can be not 1.0 in image mixup "
"augmentation.")
.AsDispensable();
AddOutput("Loss", AddOutput("Loss",
"The output yolov3 loss tensor, " "The output yolov3 loss tensor, "
"This is a 1-D tensor with shape of [N]"); "This is a 1-D tensor with shape of [N]");
...@@ -143,6 +162,9 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -143,6 +162,9 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("ignore_thresh", AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss.") "The ignore threshold to ignore confidence loss.")
.SetDefault(0.7); .SetDefault(0.7);
AddAttr<bool>("use_label_smooth",
"Whether to use label smooth. Default True.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
This operator generates yolov3 loss based on given predict result and ground This operator generates yolov3 loss based on given predict result and ground
truth boxes. truth boxes.
...@@ -204,6 +226,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -204,6 +226,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
loss = (loss_{xy} + loss_{wh}) * weight_{box} loss = (loss_{xy} + loss_{wh}) * weight_{box}
+ loss_{conf} + loss_{class} + loss_{conf} + loss_{class}
$$ $$
When :attr:`use_label_smooth` is set to :attr:`True`, the classification
target will be smoothed when calculating the classification loss: the target of
positive samples will be smoothed to :math:`1.0 - 1.0 / class\_num` and the target of
negative samples will be smoothed to :math:`1.0 / class\_num`.

When :attr:`GTScore` is given, which means the mixup score of ground truth
boxes, all losses incurred by a ground truth box will be multiplied by its
mixup score.
)DOC"); )DOC");
} }
}; };
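The label smoothing and mixup weighting described in the DOC block above can be sketched in a few lines of NumPy. This is only an illustration of the math, not code from this patch; `class_num`, `gt_label`, and `gt_score` are placeholder names:

```python
import numpy as np

def smoothed_targets(class_num, gt_label, use_label_smooth=True):
    # With label smoothing, the positive target becomes 1 - 1/class_num and
    # every negative target becomes 1/class_num; otherwise plain one-hot.
    pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0
    neg = 1.0 / class_num if use_label_smooth else 0.0
    target = np.full((class_num,), neg, dtype=np.float32)
    target[gt_label] = pos
    return target

# class_num=5, gt_label=2 -> [0.2, 0.2, 0.8, 0.2, 0.2]
print(smoothed_targets(5, 2))

# A ground truth box with mixup score gt_score = 0.7 contributes
# 0.7 * (its box + objectness + classification loss) to the total loss.
```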
@@ -240,6 +271,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput("X", Input("X"));
op->SetInput("GTBox", Input("GTBox"));
op->SetInput("GTLabel", Input("GTLabel"));
op->SetInput("GTScore", Input("GTScore"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetInput("ObjectnessMask", Output("ObjectnessMask"));
op->SetInput("GTMatchMask", Output("GTMatchMask"));
@@ -249,6 +281,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("GTBox"), {});
op->SetOutput(framework::GradVarName("GTLabel"), {});
op->SetOutput(framework::GradVarName("GTScore"), {});
return std::unique_ptr<framework::OpDesc>(op);
}
};
......
@@ -37,8 +37,8 @@ static T SigmoidCrossEntropy(T x, T label) {
}
template <typename T>
-static T L2Loss(T x, T y) {
-return 0.5 * (y - x) * (y - x);
+static T L1Loss(T x, T y) {
+return std::abs(y - x);
}
template <typename T>
@@ -47,8 +47,8 @@ static T SigmoidCrossEntropyGrad(T x, T label) {
}
template <typename T>
-static T L2LossGrad(T x, T y) {
-return x - y;
+static T L1LossGrad(T x, T y) {
+return x > y ? 1.0 : -1.0;
}
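A quick NumPy check of the switch from L2 to L1 loss above (illustrative only, mirroring the new helpers): the gradient of the absolute error is the sign of (x - y), and the tie at x == y is resolved to -1.0, matching the ternary in L1LossGrad.

```python
import numpy as np

def l1_loss(x, y):
    # |y - x|, the new width/height regression loss
    return np.abs(y - x)

def l1_loss_grad(x, y):
    # d|y - x|/dx = sign(x - y); pick -1.0 at the tie, like the C++ helper
    return 1.0 if x > y else -1.0

print(l1_loss(0.25, 1.25), l1_loss_grad(0.25, 1.25))  # 1.0 -1.0
```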
static int GetMaskIndex(std::vector<int> mask, int val) {
@@ -121,47 +121,49 @@ template <typename T>
static void CalcBoxLocationLoss(T* loss, const T* input, Box<T> gt,
std::vector<int> anchors, int an_idx,
int box_idx, int gi, int gj, int grid_size,
-int input_size, int stride) {
+int input_size, int stride, T score) {
T tx = gt.x * grid_size - gi;
T ty = gt.y * grid_size - gj;
T tw = std::log(gt.w * input_size / anchors[2 * an_idx]);
T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]);
-T scale = (2.0 - gt.w * gt.h);
+T scale = (2.0 - gt.w * gt.h) * score;
loss[0] += SigmoidCrossEntropy<T>(input[box_idx], tx) * scale;
loss[0] += SigmoidCrossEntropy<T>(input[box_idx + stride], ty) * scale;
-loss[0] += L2Loss<T>(input[box_idx + 2 * stride], tw) * scale;
-loss[0] += L2Loss<T>(input[box_idx + 3 * stride], th) * scale;
+loss[0] += L1Loss<T>(input[box_idx + 2 * stride], tw) * scale;
+loss[0] += L1Loss<T>(input[box_idx + 3 * stride], th) * scale;
}
template <typename T>
static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input,
Box<T> gt, std::vector<int> anchors,
int an_idx, int box_idx, int gi, int gj,
-int grid_size, int input_size, int stride) {
+int grid_size, int input_size, int stride,
+T score) {
T tx = gt.x * grid_size - gi;
T ty = gt.y * grid_size - gj;
T tw = std::log(gt.w * input_size / anchors[2 * an_idx]);
T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]);
-T scale = (2.0 - gt.w * gt.h);
+T scale = (2.0 - gt.w * gt.h) * score;
input_grad[box_idx] =
SigmoidCrossEntropyGrad<T>(input[box_idx], tx) * scale * loss;
input_grad[box_idx + stride] =
SigmoidCrossEntropyGrad<T>(input[box_idx + stride], ty) * scale * loss;
input_grad[box_idx + 2 * stride] =
-L2LossGrad<T>(input[box_idx + 2 * stride], tw) * scale * loss;
+L1LossGrad<T>(input[box_idx + 2 * stride], tw) * scale * loss;
input_grad[box_idx + 3 * stride] =
-L2LossGrad<T>(input[box_idx + 3 * stride], th) * scale * loss;
+L1LossGrad<T>(input[box_idx + 3 * stride], th) * scale * loss;
}
template <typename T>
static inline void CalcLabelLoss(T* loss, const T* input, const int index,
const int label, const int class_num,
-const int stride) {
+const int stride, const T pos, const T neg,
+T score) {
for (int i = 0; i < class_num; i++) {
T pred = input[index + i * stride];
-loss[0] += SigmoidCrossEntropy<T>(pred, (i == label) ? 1.0 : 0.0);
+loss[0] += SigmoidCrossEntropy<T>(pred, (i == label) ? pos : neg) * score;
}
}
@@ -169,11 +171,13 @@ template <typename T>
static inline void CalcLabelLossGrad(T* input_grad, const T loss,
const T* input, const int index,
const int label, const int class_num,
-const int stride) {
+const int stride, const T pos, const T neg,
+T score) {
for (int i = 0; i < class_num; i++) {
T pred = input[index + i * stride];
input_grad[index + i * stride] =
-SigmoidCrossEntropyGrad<T>(pred, (i == label) ? 1.0 : 0.0) * loss;
+SigmoidCrossEntropyGrad<T>(pred, (i == label) ? pos : neg) * score *
+loss;
}
}
@@ -188,8 +192,8 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness,
for (int l = 0; l < w; l++) {
T obj = objness[k * w + l];
if (obj > 1e-5) {
-// positive sample: obj = 1
-loss[i] += SigmoidCrossEntropy<T>(input[k * w + l], 1.0);
+// positive sample: obj = mixup score
+loss[i] += SigmoidCrossEntropy<T>(input[k * w + l], 1.0) * obj;
} else if (obj > -0.5) {
// negetive sample: obj = 0
loss[i] += SigmoidCrossEntropy<T>(input[k * w + l], 0.0);
@@ -215,7 +219,8 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss,
T obj = objness[k * w + l];
if (obj > 1e-5) {
input_grad[k * w + l] =
-SigmoidCrossEntropyGrad<T>(input[k * w + l], 1.0) * loss[i];
+SigmoidCrossEntropyGrad<T>(input[k * w + l], 1.0) * obj *
+loss[i];
} else if (obj > -0.5) {
input_grad[k * w + l] =
SigmoidCrossEntropyGrad<T>(input[k * w + l], 0.0) * loss[i];
@@ -252,6 +257,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* gt_score = ctx.Input<Tensor>("GTScore");
auto* loss = ctx.Output<Tensor>("Loss");
auto* objness_mask = ctx.Output<Tensor>("ObjectnessMask");
auto* gt_match_mask = ctx.Output<Tensor>("GTMatchMask");
@@ -260,6 +266,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int class_num = ctx.Attr<int>("class_num");
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
const int n = input->dims()[0];
const int h = input->dims()[2];
@@ -272,6 +279,13 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
T label_pos = 1.0;
T label_neg = 0.0;
if (use_label_smooth) {
label_pos = 1.0 - 1.0 / static_cast<T>(class_num);
label_neg = 1.0 / static_cast<T>(class_num);
}
const T* input_data = input->data<T>();
const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>();
@@ -283,6 +297,19 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int* gt_match_mask_data =
gt_match_mask->mutable_data<int>({n, b}, ctx.GetPlace());
const T* gt_score_data;
if (!gt_score) {
Tensor gtscore;
gtscore.mutable_data<T>({n, b}, ctx.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T>()(
ctx.template device_context<platform::CPUDeviceContext>(), &gtscore,
static_cast<T>(1.0));
gt_score = &gtscore;
gt_score_data = gtscore.data<T>();
} else {
gt_score_data = gt_score->data<T>();
}
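A rough Python equivalent of the fallback just above (illustrative only; `gt_score`, `n`, and `b` mirror the kernel's variables): when no GTScore is fed, every ground truth box gets score 1.0, so the weighted loss reduces to the old unweighted behavior.

```python
import numpy as np

def resolve_gt_score(gt_score, n, b):
    # No mixup scores supplied: weight every ground truth box by 1.0,
    # which reproduces the pre-mixup loss exactly.
    if gt_score is None:
        return np.ones((n, b), dtype=np.float32)
    return gt_score
```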
// calc valid gt box mask, avoid calc duplicately in following code
Tensor gt_valid_mask;
bool* gt_valid_mask_data =
@@ -355,19 +382,20 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int mask_idx = GetMaskIndex(anchor_mask, best_n);
gt_match_mask_data[i * b + t] = mask_idx;
if (mask_idx >= 0) {
T score = gt_score_data[i * b + t];
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
CalcBoxLocationLoss<T>(loss_data + i, input_data, gt, anchors, best_n,
-box_idx, gi, gj, h, input_size, stride);
+box_idx, gi, gj, h, input_size, stride, score);
int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi;
-obj_mask_data[obj_idx] = 1.0;
+obj_mask_data[obj_idx] = score;
int label = gt_label_data[i * b + t];
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
CalcLabelLoss<T>(loss_data + i, input_data, label_idx, label,
-class_num, stride);
+class_num, stride, label_pos, label_neg, score);
}
}
}
@@ -384,6 +412,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* gt_score = ctx.Input<Tensor>("GTScore");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* objness_mask = ctx.Input<Tensor>("ObjectnessMask");
@@ -392,6 +421,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto anchor_mask = ctx.Attr<std::vector<int>>("anchor_mask");
int class_num = ctx.Attr<int>("class_num");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
const int n = input_grad->dims()[0];
const int c = input_grad->dims()[1];
@@ -404,6 +434,13 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
T label_pos = 1.0;
T label_neg = 0.0;
if (use_label_smooth) {
label_pos = 1.0 - 1.0 / static_cast<T>(class_num);
label_neg = 1.0 / static_cast<T>(class_num);
}
const T* input_data = input->data<T>();
const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>();
@@ -414,25 +451,41 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
const T* gt_score_data;
if (!gt_score) {
Tensor gtscore;
gtscore.mutable_data<T>({n, b}, ctx.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T>()(
ctx.template device_context<platform::CPUDeviceContext>(), &gtscore,
static_cast<T>(1.0));
gt_score = &gtscore;
gt_score_data = gtscore.data<T>();
} else {
gt_score_data = gt_score->data<T>();
}
for (int i = 0; i < n; i++) {
for (int t = 0; t < b; t++) {
int mask_idx = gt_match_mask_data[i * b + t];
if (mask_idx >= 0) {
T score = gt_score_data[i * b + t];
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
int gi = static_cast<int>(gt.x * w);
int gj = static_cast<int>(gt.y * h);
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
-CalcBoxLocationLossGrad<T>(
-input_grad_data, loss_grad_data[i], input_data, gt, anchors,
-anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride);
+CalcBoxLocationLossGrad<T>(input_grad_data, loss_grad_data[i],
+input_data, gt, anchors,
+anchor_mask[mask_idx], box_idx, gi, gj, h,
+input_size, stride, score);
int label = gt_label_data[i * b + t];
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
CalcLabelLossGrad<T>(input_grad_data, loss_grad_data[i], input_data,
-label_idx, label, class_num, stride);
+label_idx, label, class_num, stride, label_pos,
+label_neg, score);
}
}
}
......
@@ -515,6 +515,8 @@
class_num,
ignore_thresh,
downsample_ratio,
gtscore=None,
use_label_smooth=True,
name=None):
"""
${comment}
@@ -533,28 +535,35 @@
class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment}
-name (string): the name of yolov3 loss
+name (string): the name of yolov3 loss. Default None.
gtscore (Variable): mixup score of ground truth boxes, should be in shape
of [N, B]. Default None.
use_label_smooth (bool): ${use_label_smooth_comment}
Returns:
-Variable: A 1-D tensor with shape [1], the value of yolov3 loss
+Variable: A 1-D tensor with shape [N], the value of yolov3 loss
Raises:
TypeError: Input x of yolov3_loss must be Variable
-TypeError: Input gtbox of yolov3_loss must be Variable"
+TypeError: Input gtbox of yolov3_loss must be Variable
-TypeError: Input gtlabel of yolov3_loss must be Variable"
+TypeError: Input gtlabel of yolov3_loss must be Variable
TypeError: Input gtscore of yolov3_loss must be None or Variable
TypeError: Attr anchors of yolov3_loss must be list or tuple
TypeError: Attr class_num of yolov3_loss must be an integer
TypeError: Attr ignore_thresh of yolov3_loss must be a float number
TypeError: Attr use_label_smooth of yolov3_loss must be a bool value
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
-gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32')
+gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')
-gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32')
+gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32')
gtscore = fluid.layers.data(name='gtscore', shape=[6], dtype='float32')
anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
anchor_mask = [0, 1, 2]
-loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel, anchors=anchors,
+loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel,
+gtscore=gtscore, anchors=anchors,
anchor_mask=anchor_mask, class_num=80,
ignore_thresh=0.7, downsample_ratio=32)
"""
@@ -566,6 +575,8 @@
raise TypeError("Input gtbox of yolov3_loss must be Variable")
if not isinstance(gtlabel, Variable):
raise TypeError("Input gtlabel of yolov3_loss must be Variable")
if gtscore is not None and not isinstance(gtscore, Variable):
raise TypeError("Input gtscore of yolov3_loss must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple):
@@ -575,6 +586,9 @@
if not isinstance(ignore_thresh, float):
raise TypeError(
"Attr ignore_thresh of yolov3_loss must be a float number")
if not isinstance(use_label_smooth, bool):
raise TypeError(
"Attr use_label_smooth of yolov3_loss must be a bool value")
if name is None:
loss = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -585,21 +599,26 @@
objectness_mask = helper.create_variable_for_type_inference(dtype='int32')
gt_match_mask = helper.create_variable_for_type_inference(dtype='int32')
inputs = {
"X": x,
"GTBox": gtbox,
"GTLabel": gtlabel,
}
if gtscore:
inputs["GTScore"] = gtscore
attrs = {
"anchors": anchors,
"anchor_mask": anchor_mask,
"class_num": class_num,
"ignore_thresh": ignore_thresh,
"downsample_ratio": downsample_ratio,
"use_label_smooth": use_label_smooth,
}
helper.append_op(
type='yolov3_loss',
-inputs={
-"X": x,
-"GTBox": gtbox,
-"GTLabel": gtlabel,
-},
+inputs=inputs,
outputs={
'Loss': loss,
'ObjectnessMask': objectness_mask,
......
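As background on where the new gtscore input comes from (not part of this patch): image mixup blends two images with a factor lambda and weights each image's ground truth boxes by lambda and 1 - lambda respectively. A minimal sketch, with `alpha` and `max_box_num` as illustrative assumptions:

```python
import numpy as np

def mixup_gtscore(num_boxes_a, num_boxes_b, alpha=1.5, max_box_num=50):
    # Draw the mixup factor, give it to the first image's boxes and
    # (1 - factor) to the second image's boxes; pad the rest with zeros.
    lam = np.random.beta(alpha, alpha)
    score = np.concatenate([
        np.full(num_boxes_a, lam, dtype=np.float32),
        np.full(num_boxes_b, 1.0 - lam, dtype=np.float32),
    ])[:max_box_num]
    gtscore = np.zeros((max_box_num,), dtype=np.float32)
    gtscore[:len(score)] = score
    return gtscore  # shape [max_box_num], fed as the gtscore input of yolov3_loss

print(mixup_gtscore(3, 2, max_box_num=6))
```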
@@ -476,8 +476,16 @@ class TestYoloDetection(unittest.TestCase):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32')
gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32')
-loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13],
-[0, 1], 10, 0.7, 32)
+gtscore = layers.data(name='gtscore', shape=[10], dtype='float32')
+loss = layers.yolov3_loss(
+x,
+gtbox,
+gtlabel, [10, 13, 30, 13], [0, 1],
+10,
+0.7,
+32,
+gtscore=gtscore,
+use_label_smooth=False)
self.assertIsNotNone(loss)
......
@@ -23,8 +23,8 @@ from op_test import OpTest
from paddle.fluid import core
-def l2loss(x, y):
-return 0.5 * (y - x) * (y - x)
+def l1loss(x, y):
+return abs(x - y)
def sce(x, label):
@@ -66,7 +66,7 @@ def batch_xywh_box_iou(box1, box2):
return inter_area / union
-def YOLOv3Loss(x, gtbox, gtlabel, attrs):
+def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
n, c, h, w = x.shape
b = gtbox.shape[1]
anchors = attrs['anchors']
@@ -75,21 +75,21 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
mask_num = len(anchor_mask)
class_num = attrs["class_num"]
ignore_thresh = attrs['ignore_thresh']
-downsample = attrs['downsample']
-input_size = downsample * h
+downsample_ratio = attrs['downsample_ratio']
+use_label_smooth = attrs['use_label_smooth']
+input_size = downsample_ratio * h
x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
loss = np.zeros((n)).astype('float32')
label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0
label_neg = 1.0 / class_num if use_label_smooth else 0.0
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:],
np.ones_like(x[:, :, :, :, 5:]) * 1.0 /
class_num)
mask_anchors = []
for m in anchor_mask:
mask_anchors.append((anchors[2 * m], anchors[2 * m + 1]))
@@ -138,21 +138,22 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
ty = gtbox[i, j, 1] * w - gj
tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0])
th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1])
-scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3])
+scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j]
loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale
loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale
-loss[i] += l2loss(x[i, an_idx, gj, gi, 2], tw) * scale
-loss[i] += l2loss(x[i, an_idx, gj, gi, 3], th) * scale
+loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale
+loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale
-objness[i, an_idx * h * w + gj * w + gi] = 1.0
+objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j]
for label_idx in range(class_num):
-loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx],
-float(label_idx == gtlabel[i, j]))
+loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos
+if label_idx == gtlabel[i, j] else
+label_neg) * gtscore[i, j]
for j in range(mask_num * h * w):
if objness[i, j] > 0:
-loss[i] += sce(pred_obj[i, j], 1.0)
+loss[i] += sce(pred_obj[i, j], 1.0) * objness[i, j]
elif objness[i, j] == 0:
loss[i] += sce(pred_obj[i, j], 0.0)
@@ -176,7 +177,8 @@ class TestYolov3LossOp(OpTest):
"anchor_mask": self.anchor_mask,
"class_num": self.class_num,
"ignore_thresh": self.ignore_thresh,
-"downsample": self.downsample,
+"downsample_ratio": self.downsample_ratio,
+"use_label_smooth": self.use_label_smooth,
}
self.inputs = {
@@ -184,7 +186,14 @@ class TestYolov3LossOp(OpTest):
'GTBox': gtbox.astype('float32'),
'GTLabel': gtlabel.astype('int32'),
}
-loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs)
gtscore = np.ones(self.gtbox_shape[:2]).astype('float32')
if self.gtscore:
gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32')
self.inputs['GTScore'] = gtscore
loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore,
self.attrs)
self.outputs = {
'Loss': loss,
'ObjectnessMask': objness,
@@ -193,24 +202,57 @@ class TestYolov3LossOp(OpTest):
def test_check_output(self):
place = core.CPUPlace()
-self.check_output_with_place(place, atol=1e-3)
+self.check_output_with_place(place, atol=2e-3)
def test_check_grad_ignore_gtbox(self):
place = core.CPUPlace()
-self.check_grad_with_place(
-place, ['X'],
-'Loss',
-no_grad_set=set(["GTBox", "GTLabel"]),
-max_relative_error=0.3)
+self.check_grad_with_place(place, ['X'], 'Loss', max_relative_error=0.2)
+def initTestCase(self):
+self.anchors = [
+10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198,
+373, 326
+]
+self.anchor_mask = [0, 1, 2]
+self.class_num = 5
+self.ignore_thresh = 0.7
+self.downsample_ratio = 32
+self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
+self.gtbox_shape = (3, 5, 4)
+self.gtscore = True
+self.use_label_smooth = True
class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp):
def initTestCase(self):
self.anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198,
373, 326
]
self.anchor_mask = [0, 1, 2]
self.class_num = 5
self.ignore_thresh = 0.7
self.downsample_ratio = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
self.gtscore = True
self.use_label_smooth = False
class TestYolov3LossNoGTScore(TestYolov3LossOp):
def initTestCase(self):
-self.anchors = [10, 13, 16, 30, 33, 23]
-self.anchor_mask = [1, 2]
+self.anchors = [
+10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198,
+373, 326
+]
+self.anchor_mask = [0, 1, 2]
self.class_num = 5
-self.ignore_thresh = 0.5
+self.ignore_thresh = 0.7
-self.downsample = 32
+self.downsample_ratio = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
self.gtscore = False
self.use_label_smooth = True
if __name__ == "__main__":
......