diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index bf0916a0767f060adb9033aad1d602c3ca516e31..d773c2518cd769a69f4eeed905a07c45e065a1d9 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes',
 paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
 paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
-paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(True, None))
 paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
 paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
 paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc
index c146035f9dacd1b6581b65f05272acfcf8cecb1b..0c5426728b75a8daa97675e63aa2de4fef871c48 100644
--- a/paddle/fluid/operators/yolov3_loss_op.cc
+++ b/paddle/fluid/operators/yolov3_loss_op.cc
@@ -46,6 +46,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
     auto anchor_mask = ctx->Attrs().Get<std::vector<int>>("anchor_mask");
     int mask_num = anchor_mask.size();
     auto class_num = ctx->Attrs().Get<int>("class_num");
+
     PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
     PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
                       "Input(X) dim[3] and dim[4] should be euqal.");
@@ -156,6 +157,8 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<float>("ignore_thresh",
                    "The ignore threshold to ignore confidence loss.")
         .SetDefault(0.7);
+    AddAttr<bool>("use_label_smooth", "Whether to use label smooth.")
+        .SetDefault(true);
     AddComment(R"DOC(
          This operator generate yolov3 loss by given predict result and ground
          truth boxes.
diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h
index c4095b8ca561d758158e82390d82401e3ada7ab3..f601651f0602fc0f00bacb6c0bb05b85e02ab115 100644
--- a/paddle/fluid/operators/yolov3_loss_op.h
+++ b/paddle/fluid/operators/yolov3_loss_op.h
@@ -157,11 +157,19 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input,
 template <typename T>
 static inline void CalcLabelLoss(T* loss, const T* input, const int index,
                                  const int label, const T score,
-                                 const int class_num, const int stride) {
-  for (int i = 0; i < class_num; i++) {
-    T pred = input[index + i * stride] < -0.5 ? input[index + i * stride]
-                                              : 1.0 / class_num;
-    loss[0] += SCE<T>(pred, (i == label) ? score : 0.0);
+                                 const int class_num, const int stride,
+                                 const bool use_label_smooth) {
+  if (use_label_smooth) {
+    for (int i = 0; i < class_num; i++) {
+      T pred = input[index + i * stride] < -0.5 ? input[index + i * stride]
+                                                : 1.0 / class_num;
+      loss[0] += SCE<T>(pred, (i == label) ? score : 0.0);
+    }
+  } else {
+    for (int i = 0; i < class_num; i++) {
+      T pred = input[index + i * stride];
+      loss[0] += SCE<T>(pred, (i == label) ? score : 0.0);
+    }
   }
 }
 
@@ -169,12 +177,21 @@ template <typename T>
 static inline void CalcLabelLossGrad(T* input_grad, const T loss,
                                      const T* input, const int index,
                                      const int label, const T score,
-                                     const int class_num, const int stride) {
-  for (int i = 0; i < class_num; i++) {
-    T pred = input[index + i * stride] < -0.5 ? input[index + i * stride]
-                                              : 1.0 / class_num;
-    input_grad[index + i * stride] =
-        SCEGrad<T>(pred, (i == label) ? score : 0.0) * loss;
+                                     const int class_num, const int stride,
+                                     const bool use_label_smooth) {
+  if (use_label_smooth) {
+    for (int i = 0; i < class_num; i++) {
+      T pred = input[index + i * stride] < -0.5 ? input[index + i * stride]
+                                                : 1.0 / class_num;
+      input_grad[index + i * stride] =
+          SCEGrad<T>(pred, (i == label) ? score : 0.0) * loss;
+    }
+  } else {
+    for (int i = 0; i < class_num; i++) {
+      T pred = input[index + i * stride];
+      input_grad[index + i * stride] =
+          SCEGrad<T>(pred, (i == label) ? score : 0.0) * loss;
+    }
   }
 }
 
@@ -255,6 +272,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
     int class_num = ctx.Attr<int>("class_num");
     float ignore_thresh = ctx.Attr<float>("ignore_thresh");
     int downsample = ctx.Attr<int>("downsample");
+    bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
 
     const int n = input->dims()[0];
     const int h = input->dims()[2];
@@ -364,7 +382,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
           int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
                                         an_stride, stride, 5);
           CalcLabelLoss<T>(loss_data + i, input_data, label_idx, label, score,
-                           class_num, stride);
+                           class_num, stride, use_label_smooth);
         }
       }
     }
@@ -390,6 +408,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
     auto anchor_mask = ctx.Attr<std::vector<int>>("anchor_mask");
     int class_num = ctx.Attr<int>("class_num");
     int downsample = ctx.Attr<int>("downsample");
+    bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
 
     const int n = input_grad->dims()[0];
     const int c = input_grad->dims()[1];
@@ -432,7 +451,8 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
         int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
                                       an_stride, stride, 5);
         CalcLabelLossGrad<T>(input_grad_data, loss_grad_data[i], input_data,
-                             label_idx, label, score, class_num, stride);
+                             label_idx, label, score, class_num, stride,
+                             use_label_smooth);
       }
     }
   }
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 10573cc4c60b0e51b361f5621bc5a842453e308d..e984576ffe8ce9549925b23ee79b24b9b7773d0a 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -418,6 +418,7 @@ def yolov3_loss(x,
                 class_num,
                 ignore_thresh,
                 downsample,
+                use_label_smooth=True,
                 name=None):
     """
     ${comment}
@@ -438,6 +439,7 @@ def yolov3_loss(x,
         class_num (int): ${class_num_comment}
         ignore_thresh (float): ${ignore_thresh_comment}
         downsample (int): ${downsample_comment}
+        use_label_smooth(bool): ${use_label_smooth_comment}
         name (string): the name of yolov3 loss
 
     Returns:
@@ -451,6 +453,7 @@ def yolov3_loss(x,
         TypeError: Attr anchors of yolov3_loss must be list or tuple
         TypeError: Attr class_num of yolov3_loss must be an integer
         TypeError: Attr ignore_thresh of yolov3_loss must be a float number
+        TypeError: Attr use_label_smooth of yolov3_loss must be a bool value
 
     Examples:
     .. code-block:: python
@@ -479,6 +482,9 @@ def yolov3_loss(x,
         raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple")
     if not isinstance(class_num, int):
         raise TypeError("Attr class_num of yolov3_loss must be an integer")
+    if not isinstance(use_label_smooth, bool):
+        raise TypeError(
+            "Attr use_label_smooth of yolov3_loss must be a bool value")
     if not isinstance(ignore_thresh, float):
         raise TypeError(
             "Attr ignore_thresh of yolov3_loss must be a float number")
@@ -498,6 +504,7 @@ def yolov3_loss(x,
         "class_num": class_num,
         "ignore_thresh": ignore_thresh,
         "downsample": downsample,
+        "use_label_smooth": use_label_smooth,
     }
 
     helper.append_op(
diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
index c65570d7c15334ca42e8ba163542f03d020d4b0e..1746a1da1dc48b322e8266861036751751199586 100644
--- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
@@ -76,6 +76,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
     class_num = attrs["class_num"]
     ignore_thresh = attrs['ignore_thresh']
     downsample = attrs['downsample']
+    #use_label_smooth = attrs['use_label_smooth']
     input_size = downsample * h
     x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
     loss = np.zeros((n)).astype('float32')
@@ -176,6 +177,7 @@ class TestYolov3LossOp(OpTest):
             "class_num": self.class_num,
             "ignore_thresh": self.ignore_thresh,
             "downsample": self.downsample,
+            "use_label_smooth": self.use_label_smooth,
         }
 
         self.inputs = {
@@ -215,6 +217,12 @@ class TestYolov3LossOp(OpTest):
         self.downsample = 32
         self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
         self.gtbox_shape = (3, 10, 4)
+        self.use_label_smooth = True
+
+
+class TestYolov3LossWithLabelSmooth(TestYolov3LossOp):
+    def set_label_smooth(self):
+        self.use_label_smooth = True
 
 
 if __name__ == "__main__":