diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.h b/paddle/fluid/operators/detection/yolov3_loss_op.h index a004b022b75174012d10ba38e5ec161830c62640..f8d49960c7c5e718d68e7af2bea3dec825fc35fd 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.h +++ b/paddle/fluid/operators/detection/yolov3_loss_op.h @@ -282,8 +282,9 @@ class Yolov3LossKernel : public framework::OpKernel { T label_pos = 1.0; T label_neg = 0.0; if (use_label_smooth) { - label_pos = 1.0 - 1.0 / static_cast(class_num); - label_neg = 1.0 / static_cast(class_num); + T smooth_weight = std::min(1.0 / static_cast(class_num), 1.0 / 40); + label_pos = 1.0 - smooth_weight; + label_neg = smooth_weight; } const T* input_data = input->data(); @@ -437,8 +438,9 @@ class Yolov3LossGradKernel : public framework::OpKernel { T label_pos = 1.0; T label_neg = 0.0; if (use_label_smooth) { - label_pos = 1.0 - 1.0 / static_cast(class_num); - label_neg = 1.0 / static_cast(class_num); + T smooth_weight = std::min(1.0 / static_cast(class_num), 1.0 / 40); + label_pos = 1.0 - smooth_weight; + label_neg = smooth_weight; } const T* input_data = input->data(); diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index e4d6edc72c0ca888e271101f079cdcc6fb4e8a70..623e2228a4c2865c65277f44ad92a2060c18b49a 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -81,8 +81,9 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') - label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0 - label_neg = 1.0 / class_num if use_label_smooth else 0.0 + smooth_weight = min(1.0 / class_num, 1.0 / 40) + label_pos = 1.0 - smooth_weight if use_label_smooth else 1.0 + label_neg = smooth_weight if use_label_smooth else 0.0 pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))