diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 057ce9341b7bd4296806d02ec4349b09b23d2f1f..6b25d6a14f54055f37e968a2a63e81baa536063a 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -330,7 +330,7 @@ paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '587845f60c5d97ffdf2dfd21da52eca1')) paddle.fluid.layers.box_coder (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)), ('document', '032d0f4b7d8f6235ee5d91e473344f0e')) paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0e5ac2507723a0b5adec473f9556799b')) -paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '991e934c3e09abf0edec7c9c978b4691')) +paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gtscore', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '57fa96922e42db8f064c3fb77f2255e8')) paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e')) paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0')) paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d')) diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc index ab01bdf7ca8c5a369bd8838b1acc734364666992..6c37da17f4011d38efcdc5406331f1be173dd0dd 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -10,6 +10,7 @@ limitations under the License. */ #include "paddle/fluid/operators/detection/yolov3_loss_op.h" +#include #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -72,6 +73,18 @@ class Yolov3LossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GT(class_num, 0, "Attr(class_num) should be an integer greater then 0."); + if (ctx->HasInput("GTScore")) { + auto dim_gtscore = ctx->GetInputDim("GTScore"); + PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2, + "Input(GTScore) should be a 2-D tensor"); + PADDLE_ENFORCE_EQ( + dim_gtscore[0], dim_gtbox[0], + "Input(GTBox) and Input(GTScore) dim[0] should be same"); + PADDLE_ENFORCE_EQ( + dim_gtscore[1], dim_gtbox[1], + "Input(GTBox) and Input(GTScore) dim[1] should be same"); + } + std::vector dim_out({dim_x[0]}); ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); @@ -112,6 +125,12 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "This is a 2-D tensor with shape of [N, max_box_num], " "and each element should be an integer to indicate the " "box class id."); + AddInput("GTScore", + "The score of GTLabel, This is a 2-D tensor in same shape " + "GTLabel, and score values should in range (0, 1). This " + "input is for GTLabel score can be not 1.0 in image mixup " + "augmentation.") + .AsDispensable(); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); @@ -143,6 +162,9 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss.") .SetDefault(0.7); + AddAttr("use_label_smooth", + "Whether to use label smooth. Default True.") + .SetDefault(true); AddComment(R"DOC( This operator generates yolov3 loss based on given predict result and ground truth boxes. @@ -204,6 +226,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { loss = (loss_{xy} + loss_{wh}) * weight_{box} + loss_{conf} + loss_{class} $$ + + While :attr:`use_label_smooth` is set to be :attr:`True`, the classification + target will be smoothed when calculating classification loss, target of + positive samples will be smoothed to :math:`1.0 - 1.0 / class\_num` and target of + negetive samples will be smoothed to :math:`1.0 / class\_num`. + + While :attr:`GTScore` is given, which means the mixup score of ground truth + boxes, all losses incured by a ground truth box will be multiplied by its + mixup score. )DOC"); } }; @@ -240,6 +271,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("X", Input("X")); op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); + op->SetInput("GTScore", Input("GTScore")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); op->SetInput("ObjectnessMask", Output("ObjectnessMask")); op->SetInput("GTMatchMask", Output("GTMatchMask")); @@ -249,6 +281,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("GTBox"), {}); op->SetOutput(framework::GradVarName("GTLabel"), {}); + op->SetOutput(framework::GradVarName("GTScore"), {}); return std::unique_ptr(op); } }; diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.h b/paddle/fluid/operators/detection/yolov3_loss_op.h index 8407d4e6e8f87a2e8d073c4fbda5691abe1bba68..a004b022b75174012d10ba38e5ec161830c62640 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.h +++ b/paddle/fluid/operators/detection/yolov3_loss_op.h @@ -37,8 +37,8 @@ static T SigmoidCrossEntropy(T x, T label) { } template -static T L2Loss(T x, T y) { - return 0.5 * (y - x) * (y - x); +static T L1Loss(T x, T y) { + return std::abs(y - x); } template @@ -47,8 +47,8 @@ static T SigmoidCrossEntropyGrad(T x, T label) { } template -static T L2LossGrad(T x, T y) { - return x - y; +static T L1LossGrad(T x, T y) { + return x > y ? 1.0 : -1.0; } static int GetMaskIndex(std::vector mask, int val) { @@ -121,47 +121,49 @@ template static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, int grid_size, - int input_size, int stride) { + int input_size, int stride, T score) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = (2.0 - gt.w * gt.h); + T scale = (2.0 - gt.w * gt.h) * score; loss[0] += SigmoidCrossEntropy(input[box_idx], tx) * scale; loss[0] += SigmoidCrossEntropy(input[box_idx + stride], ty) * scale; - loss[0] += L2Loss(input[box_idx + 2 * stride], tw) * scale; - loss[0] += L2Loss(input[box_idx + 3 * stride], th) * scale; + loss[0] += L1Loss(input[box_idx + 2 * stride], tw) * scale; + loss[0] += L1Loss(input[box_idx + 3 * stride], th) * scale; } template static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, - int grid_size, int input_size, int stride) { + int grid_size, int input_size, int stride, + T score) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = (2.0 - gt.w * gt.h); + T scale = (2.0 - gt.w * gt.h) * score; input_grad[box_idx] = SigmoidCrossEntropyGrad(input[box_idx], tx) * scale * loss; input_grad[box_idx + stride] = SigmoidCrossEntropyGrad(input[box_idx + stride], ty) * scale * loss; input_grad[box_idx + 2 * stride] = - L2LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; + L1LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; input_grad[box_idx + 3 * stride] = - L2LossGrad(input[box_idx + 3 * stride], th) * scale * loss; + L1LossGrad(input[box_idx + 3 * stride], th) * scale * loss; } template static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const int class_num, - const int stride) { + const int stride, const T pos, const T neg, + T score) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; - loss[0] += SigmoidCrossEntropy(pred, (i == label) ? 1.0 : 0.0); + loss[0] += SigmoidCrossEntropy(pred, (i == label) ? pos : neg) * score; } } @@ -169,11 +171,13 @@ template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, const int label, const int class_num, - const int stride) { + const int stride, const T pos, const T neg, + T score) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; input_grad[index + i * stride] = - SigmoidCrossEntropyGrad(pred, (i == label) ? 1.0 : 0.0) * loss; + SigmoidCrossEntropyGrad(pred, (i == label) ? pos : neg) * score * + loss; } } @@ -188,8 +192,8 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; if (obj > 1e-5) { - // positive sample: obj = 1 - loss[i] += SigmoidCrossEntropy(input[k * w + l], 1.0); + // positive sample: obj = mixup score + loss[i] += SigmoidCrossEntropy(input[k * w + l], 1.0) * obj; } else if (obj > -0.5) { // negetive sample: obj = 0 loss[i] += SigmoidCrossEntropy(input[k * w + l], 0.0); @@ -215,7 +219,8 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, T obj = objness[k * w + l]; if (obj > 1e-5) { input_grad[k * w + l] = - SigmoidCrossEntropyGrad(input[k * w + l], 1.0) * loss[i]; + SigmoidCrossEntropyGrad(input[k * w + l], 1.0) * obj * + loss[i]; } else if (obj > -0.5) { input_grad[k * w + l] = SigmoidCrossEntropyGrad(input[k * w + l], 0.0) * loss[i]; @@ -252,6 +257,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* loss = ctx.Output("Loss"); auto* objness_mask = ctx.Output("ObjectnessMask"); auto* gt_match_mask = ctx.Output("GTMatchMask"); @@ -260,6 +266,7 @@ class Yolov3LossKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample_ratio = ctx.Attr("downsample_ratio"); + bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -272,6 +279,13 @@ class Yolov3LossKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; + T label_pos = 1.0; + T label_neg = 0.0; + if (use_label_smooth) { + label_pos = 1.0 - 1.0 / static_cast(class_num); + label_neg = 1.0 / static_cast(class_num); + } + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); @@ -283,6 +297,19 @@ class Yolov3LossKernel : public framework::OpKernel { int* gt_match_mask_data = gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); + const T* gt_score_data; + if (!gt_score) { + Tensor gtscore; + gtscore.mutable_data({n, b}, ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), >score, + static_cast(1.0)); + gt_score = >score; + gt_score_data = gtscore.data(); + } else { + gt_score_data = gt_score->data(); + } + // calc valid gt box mask, avoid calc duplicately in following code Tensor gt_valid_mask; bool* gt_valid_mask_data = @@ -355,19 +382,20 @@ class Yolov3LossKernel : public framework::OpKernel { int mask_idx = GetMaskIndex(anchor_mask, best_n); gt_match_mask_data[i * b + t] = mask_idx; if (mask_idx >= 0) { + T score = gt_score_data[i * b + t]; int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, - box_idx, gi, gj, h, input_size, stride); + box_idx, gi, gj, h, input_size, stride, score); int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - obj_mask_data[obj_idx] = 1.0; + obj_mask_data[obj_idx] = score; int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, - class_num, stride); + class_num, stride, label_pos, label_neg, score); } } } @@ -384,6 +412,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -392,6 +421,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); int downsample_ratio = ctx.Attr("downsample_ratio"); + bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input_grad->dims()[0]; const int c = input_grad->dims()[1]; @@ -404,6 +434,13 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; + T label_pos = 1.0; + T label_neg = 0.0; + if (use_label_smooth) { + label_pos = 1.0 - 1.0 / static_cast(class_num); + label_neg = 1.0 / static_cast(class_num); + } + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); @@ -414,25 +451,41 @@ class Yolov3LossGradKernel : public framework::OpKernel { input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); + const T* gt_score_data; + if (!gt_score) { + Tensor gtscore; + gtscore.mutable_data({n, b}, ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), >score, + static_cast(1.0)); + gt_score = >score; + gt_score_data = gtscore.data(); + } else { + gt_score_data = gt_score->data(); + } + for (int i = 0; i < n; i++) { for (int t = 0; t < b; t++) { int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { + T score = gt_score_data[i * b + t]; Box gt = GetGtBox(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); - CalcBoxLocationLossGrad( - input_grad_data, loss_grad_data[i], input_data, gt, anchors, - anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); + CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], + input_data, gt, anchors, + anchor_mask[mask_idx], box_idx, gi, gj, h, + input_size, stride, score); int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, class_num, stride); + label_idx, label, class_num, stride, label_pos, + label_neg, score); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index cbedd70f857b3f767492826cda08ae1171d72bad..9183bfd43b1d6f903d56fc275355b787a53ef511 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -515,6 +515,8 @@ def yolov3_loss(x, class_num, ignore_thresh, downsample_ratio, + gtscore=None, + use_label_smooth=True, name=None): """ ${comment} @@ -533,28 +535,35 @@ def yolov3_loss(x, class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} downsample_ratio (int): ${downsample_ratio_comment} - name (string): the name of yolov3 loss + name (string): the name of yolov3 loss. Default None. + gtscore (Variable): mixup score of ground truth boxes, shoud be in shape + of [N, B]. Default None. + use_label_smooth (bool): ${use_label_smooth_comment} Returns: - Variable: A 1-D tensor with shape [1], the value of yolov3 loss + Variable: A 1-D tensor with shape [N], the value of yolov3 loss Raises: TypeError: Input x of yolov3_loss must be Variable - TypeError: Input gtbox of yolov3_loss must be Variable" - TypeError: Input gtlabel of yolov3_loss must be Variable" + TypeError: Input gtbox of yolov3_loss must be Variable + TypeError: Input gtlabel of yolov3_loss must be Variable + TypeError: Input gtscore of yolov3_loss must be None or Variable TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number + TypeError: Attr use_label_smooth of yolov3_loss must be a bool value Examples: .. code-block:: python x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') - gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') - gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') + gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32') + gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32') + gtscore = fluid.layers.data(name='gtscore', shape=[6], dtype='float32') anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] anchor_mask = [0, 1, 2] - loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel, anchors=anchors, + loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel, + gtscore=gtscore, anchors=anchors, anchor_mask=anchor_mask, class_num=80, ignore_thresh=0.7, downsample_ratio=32) """ @@ -566,6 +575,8 @@ def yolov3_loss(x, raise TypeError("Input gtbox of yolov3_loss must be Variable") if not isinstance(gtlabel, Variable): raise TypeError("Input gtlabel of yolov3_loss must be Variable") + if gtscore is not None and not isinstance(gtscore, Variable): + raise TypeError("Input gtscore of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): @@ -575,6 +586,9 @@ def yolov3_loss(x, if not isinstance(ignore_thresh, float): raise TypeError( "Attr ignore_thresh of yolov3_loss must be a float number") + if not isinstance(use_label_smooth, bool): + raise TypeError( + "Attr use_label_smooth of yolov3_loss must be a bool value") if name is None: loss = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -585,21 +599,26 @@ def yolov3_loss(x, objectness_mask = helper.create_variable_for_type_inference(dtype='int32') gt_match_mask = helper.create_variable_for_type_inference(dtype='int32') + inputs = { + "X": x, + "GTBox": gtbox, + "GTLabel": gtlabel, + } + if gtscore: + inputs["GTScore"] = gtscore + attrs = { "anchors": anchors, "anchor_mask": anchor_mask, "class_num": class_num, "ignore_thresh": ignore_thresh, "downsample_ratio": downsample_ratio, + "use_label_smooth": use_label_smooth, } helper.append_op( type='yolov3_loss', - inputs={ - "X": x, - "GTBox": gtbox, - "GTLabel": gtlabel, - }, + inputs=inputs, outputs={ 'Loss': loss, 'ObjectnessMask': objectness_mask, diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 6218db73459a2bb55d72545c738f88dbd8cce0f7..b756c532cad00b57c39349c5a6959b38f70f09fa 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -476,8 +476,16 @@ class TestYoloDetection(unittest.TestCase): x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') - loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], - [0, 1], 10, 0.7, 32) + gtscore = layers.data(name='gtscore', shape=[10], dtype='float32') + loss = layers.yolov3_loss( + x, + gtbox, + gtlabel, [10, 13, 30, 13], [0, 1], + 10, + 0.7, + 32, + gtscore=gtscore, + use_label_smooth=False) self.assertIsNotNone(loss) diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 020c1139230a9177c4d7765367359d91839d7d46..e4d6edc72c0ca888e271101f079cdcc6fb4e8a70 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -23,8 +23,8 @@ from op_test import OpTest from paddle.fluid import core -def l2loss(x, y): - return 0.5 * (y - x) * (y - x) +def l1loss(x, y): + return abs(x - y) def sce(x, label): @@ -66,7 +66,7 @@ def batch_xywh_box_iou(box1, box2): return inter_area / union -def YOLOv3Loss(x, gtbox, gtlabel, attrs): +def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): n, c, h, w = x.shape b = gtbox.shape[1] anchors = attrs['anchors'] @@ -75,21 +75,21 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): mask_num = len(anchor_mask) class_num = attrs["class_num"] ignore_thresh = attrs['ignore_thresh'] - downsample = attrs['downsample'] - input_size = downsample * h + downsample_ratio = attrs['downsample_ratio'] + use_label_smooth = attrs['use_label_smooth'] + input_size = downsample_ratio * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') + label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0 + label_neg = 1.0 / class_num if use_label_smooth else 0.0 + pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h - x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:], - np.ones_like(x[:, :, :, :, 5:]) * 1.0 / - class_num) - mask_anchors = [] for m in anchor_mask: mask_anchors.append((anchors[2 * m], anchors[2 * m + 1])) @@ -138,21 +138,22 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): ty = gtbox[i, j, 1] * w - gj tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0]) th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1]) - scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) + scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j] loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale - loss[i] += l2loss(x[i, an_idx, gj, gi, 2], tw) * scale - loss[i] += l2loss(x[i, an_idx, gj, gi, 3], th) * scale + loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale + loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale - objness[i, an_idx * h * w + gj * w + gi] = 1.0 + objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j] for label_idx in range(class_num): - loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], - float(label_idx == gtlabel[i, j])) + loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos + if label_idx == gtlabel[i, j] else + label_neg) * gtscore[i, j] for j in range(mask_num * h * w): if objness[i, j] > 0: - loss[i] += sce(pred_obj[i, j], 1.0) + loss[i] += sce(pred_obj[i, j], 1.0) * objness[i, j] elif objness[i, j] == 0: loss[i] += sce(pred_obj[i, j], 0.0) @@ -176,7 +177,8 @@ class TestYolov3LossOp(OpTest): "anchor_mask": self.anchor_mask, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, - "downsample": self.downsample, + "downsample_ratio": self.downsample_ratio, + "use_label_smooth": self.use_label_smooth, } self.inputs = { @@ -184,7 +186,14 @@ class TestYolov3LossOp(OpTest): 'GTBox': gtbox.astype('float32'), 'GTLabel': gtlabel.astype('int32'), } - loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) + + gtscore = np.ones(self.gtbox_shape[:2]).astype('float32') + if self.gtscore: + gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32') + self.inputs['GTScore'] = gtscore + + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore, + self.attrs) self.outputs = { 'Loss': loss, 'ObjectnessMask': objness, @@ -193,24 +202,57 @@ class TestYolov3LossOp(OpTest): def test_check_output(self): place = core.CPUPlace() - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place, atol=2e-3) def test_check_grad_ignore_gtbox(self): place = core.CPUPlace() - self.check_grad_with_place( - place, ['X'], - 'Loss', - no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.3) + self.check_grad_with_place(place, ['X'], 'Loss', max_relative_error=0.2) + + def initTestCase(self): + self.anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, + 373, 326 + ] + self.anchor_mask = [0, 1, 2] + self.class_num = 5 + self.ignore_thresh = 0.7 + self.downsample_ratio = 32 + self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) + self.gtbox_shape = (3, 5, 4) + self.gtscore = True + self.use_label_smooth = True + + +class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): + def initTestCase(self): + self.anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, + 373, 326 + ] + self.anchor_mask = [0, 1, 2] + self.class_num = 5 + self.ignore_thresh = 0.7 + self.downsample_ratio = 32 + self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) + self.gtbox_shape = (3, 5, 4) + self.gtscore = True + self.use_label_smooth = False + +class TestYolov3LossNoGTScore(TestYolov3LossOp): def initTestCase(self): - self.anchors = [10, 13, 16, 30, 33, 23] - self.anchor_mask = [1, 2] + self.anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, + 373, 326 + ] + self.anchor_mask = [0, 1, 2] self.class_num = 5 - self.ignore_thresh = 0.5 - self.downsample = 32 + self.ignore_thresh = 0.7 + self.downsample_ratio = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) self.gtbox_shape = (3, 5, 4) + self.gtscore = False + self.use_label_smooth = True if __name__ == "__main__":