From 192d293854b93d86bbb27ed37af199dd6e4ee1c6 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 6 Dec 2018 19:53:41 +0800 Subject: [PATCH] use stable Sigmoid Cross Entropy implement. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 4 + paddle/fluid/operators/yolov3_loss_op.h | 283 ++++++++++-------- python/paddle/fluid/layers/detection.py | 3 + python/paddle/fluid/tests/test_detection.py | 2 +- .../tests/unittests/test_yolov3_loss_op.py | 90 +++--- 5 files changed, 208 insertions(+), 174 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 60508f7ab..66d618de5 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -99,6 +99,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("anchors", "The anchor width and height, " "it will be parsed pair by pair."); + AddAttr("input_size", + "The input size of YOLOv3 net, " + "generally this is set as 320, 416 or 608.") + .SetDefault(406); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); AddAttr("loss_weight_xy", "The weight of x, y location loss.") diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 0bb285722..fac06b420 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -33,87 +33,91 @@ static inline bool isZero(T x) { } template -static inline T sigmoid(T x) { - return 1.0 / (exp(-1.0 * x) + 1.0); -} +static inline T CalcMSEWithWeight(const Tensor& x, const Tensor& y, + const Tensor& weight, const T mf) { + int numel = static_cast(x.numel()); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); -template -static inline T CalcMaskPointNum(const Tensor& mask) { - auto mask_t = EigenVector::Flatten(mask); - T count = 0.0; - for (int i = 0; i < mask_t.dimensions()[0]; i++) { - if (mask_t(i)) { - count += 1.0; - } + T error_sum = 0.0; + for (int i = 0; i < numel; i++) { + T xi = x_data[i]; + T yi = y_data[i]; + T weighti = weight_data[i]; + error_sum += pow(yi - xi, 2) * weighti; } - return count; + + return error_sum / mf; } template -static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, - const Tensor& mask) { - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - T error_sum = 0.0; - T points = 0.0; - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - error_sum += pow(x_t(i) - y_t(i), 2); - points += 1; - } +static void CalcMSEGradWithWeight(Tensor* grad, const Tensor& x, + const Tensor& y, const Tensor& weight, + const T mf) { + int numel = static_cast(grad->numel()); + T* grad_data = grad->data(); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < numel; i++) { + grad_data[i] = 2.0 * weight_data[i] * (x_data[i] - y_data[i]) / mf; } - return (error_sum / points); } template -static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y, - const Tensor& mask, T mf) { - auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf; - } +struct SigmoidCrossEntropyForward { + T operator()(const T& x, const T& label) const { + T term1 = (x > 0) ? x : 0; + T term2 = x * label; + T term3 = std::log(static_cast(1.0) + std::exp(-(std::abs(x)))); + return term1 - term2 + term3; } -} +}; template -static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, - const Tensor& mask) { - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); +struct SigmoidCrossEntropyBackward { + T operator()(const T& x, const T& label) const { + T sigmoid_x = + static_cast(1.0) / (static_cast(1.0) + std::exp(-1.0 * x)); + return sigmoid_x - label; + } +}; - T error_sum = 0.0; - T points = 0.0; - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - error_sum += - -1.0 * (y_t(i) * log(x_t(i)) + (1.0 - y_t(i)) * log(1.0 - x_t(i))); - points += 1; - } +template +static inline T CalcSCEWithWeight(const Tensor& x, const Tensor& labels, + const Tensor& weight, const T mf) { + int numel = x.numel(); + const T* x_data = x.data(); + const T* labels_data = labels.data(); + const T* weight_data = weight.data(); + + T loss = 0.0; + for (int i = 0; i < numel; i++) { + T xi = x_data[i]; + T labeli = labels_data[i]; + T weighti = weight_data[i]; + loss += ((xi > 0.0 ? xi : 0.0) - xi * labeli + + std::log(1.0 + std::exp(-1.0 * std::abs(xi)))) * + weighti; } - return (error_sum / points); + return loss / mf; } template -static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x, - const Tensor& y, const Tensor& mask, - T mf) { - auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf; - } +static inline void CalcSCEGradWithWeight(Tensor* grad, const Tensor& x, + const Tensor& labels, + const Tensor& weight, const T mf) { + int numel = grad->numel(); + T* grad_data = grad->data(); + const T* x_data = x.data(); + const T* labels_data = labels.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < numel; i++) { + grad_data[i] = (1.0 / (1.0 + std::exp(-1.0 * x_data[i])) - labels_data[i]) * + weight_data[i] / mf; } } @@ -139,21 +143,20 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_conf, for (int an_idx = 0; an_idx < anchor_num; an_idx++) { for (int j = 0; j < h; j++) { for (int k = 0; k < w; k++) { - pred_x_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx, j, k)); + pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, k); pred_y_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k)); + input_t(i, box_attr_num * an_idx + 1, j, k); pred_w_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 2, j, k); pred_h_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 3, j, k); pred_conf_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k)); + input_t(i, box_attr_num * an_idx + 4, j, k); for (int c = 0; c < class_num; c++) { pred_class_t(i, an_idx, j, k, c) = - sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); + input_t(i, box_attr_num * an_idx + 5 + c, j, k); } } } @@ -188,21 +191,22 @@ static T CalcBoxIoU(std::vector box1, std::vector box2) { template static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, const float ignore_thresh, std::vector anchors, - const int grid_size, Tensor* obj_mask, - Tensor* noobj_mask, Tensor* tx, Tensor* ty, - Tensor* tw, Tensor* th, Tensor* tconf, - Tensor* tclass) { + const int input_size, const int grid_size, + Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, + Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight, + Tensor* tconf, Tensor* tclass) { const int n = gt_box.dims()[0]; const int b = gt_box.dims()[1]; const int anchor_num = anchors.size() / 2; auto gt_box_t = EigenTensor::From(gt_box); auto gt_label_t = EigenTensor::From(gt_label); - auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); - auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); + auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); + auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); auto tx_t = EigenTensor::From(*tx).setConstant(0.0); auto ty_t = EigenTensor::From(*ty).setConstant(0.0); auto tw_t = EigenTensor::From(*tw).setConstant(0.0); auto th_t = EigenTensor::From(*th).setConstant(0.0); + auto tweight_t = EigenTensor::From(*tweight).setConstant(0.0); auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); @@ -216,8 +220,8 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, int cur_label = gt_label_t(i, j); T gx = gt_box_t(i, j, 0) * grid_size; T gy = gt_box_t(i, j, 1) * grid_size; - T gw = gt_box_t(i, j, 2) * grid_size; - T gh = gt_box_t(i, j, 3) * grid_size; + T gw = gt_box_t(i, j, 2) * input_size; + T gh = gt_box_t(i, j, 3) * input_size; int gi = static_cast(gx); int gj = static_cast(gy); @@ -234,15 +238,17 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, best_an_index = an_idx; } if (iou > ignore_thresh) { - noobj_mask_t(i, an_idx, gj, gi) = 0; + noobj_mask_t(i, an_idx, gj, gi) = static_cast(0.0); } } - obj_mask_t(i, best_an_index, gj, gi) = 1; - noobj_mask_t(i, best_an_index, gj, gi) = 0; + obj_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); + noobj_mask_t(i, best_an_index, gj, gi) = static_cast(0.0); tx_t(i, best_an_index, gj, gi) = gx - gi; ty_t(i, best_an_index, gj, gi) = gy - gj; tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); + tweight_t(i, best_an_index, gj, gi) = + 2.0 - gt_box_t(i, j, 2) * gt_box_t(i, j, 3); tclass_t(i, best_an_index, gj, gi, cur_label) = 1; tconf_t(i, best_an_index, gj, gi) = 1; } @@ -295,27 +301,22 @@ static void AddAllGradToInputGrad( for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { grad_t(i, j * attr_num, k, l) = - grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) * - (1.0 - pred_x_t(i, j, k, l)) * loss * loss_weight_xy; + grad_x_t(i, j, k, l) * loss * loss_weight_xy; grad_t(i, j * attr_num + 1, k, l) = - grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * - (1.0 - pred_y_t(i, j, k, l)) * loss * loss_weight_xy; + grad_y_t(i, j, k, l) * loss * loss_weight_xy; grad_t(i, j * attr_num + 2, k, l) = grad_w_t(i, j, k, l) * loss * loss_weight_wh; grad_t(i, j * attr_num + 3, k, l) = grad_h_t(i, j, k, l) * loss * loss_weight_wh; grad_t(i, j * attr_num + 4, k, l) = - grad_conf_target_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss * loss_weight_conf_target; + grad_conf_target_t(i, j, k, l) * loss * loss_weight_conf_target; grad_t(i, j * attr_num + 4, k, l) += - grad_conf_notarget_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss * + grad_conf_notarget_t(i, j, k, l) * loss * loss_weight_conf_notarget; for (int c = 0; c < class_num; c++) { grad_t(i, j * attr_num + 5 + c, k, l) = - grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * - (1.0 - pred_class_t(i, j, k, l, c)) * loss * loss_weight_class; + grad_class_t(i, j, k, l, c) * loss * loss_weight_class; } } } @@ -333,6 +334,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto* loss = ctx.Output("Loss"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); + int input_size = ctx.Attr("input_size"); float ignore_thresh = ctx.Attr("ignore_thresh"); float loss_weight_xy = ctx.Attr("loss_weight_xy"); float loss_weight_wh = ctx.Attr("loss_weight_wh"); @@ -358,30 +360,46 @@ class Yolov3LossKernel : public framework::OpKernel { &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; - Tensor tx, ty, tw, th, tconf, tclass; - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + Tensor tx, ty, tw, th, tweight, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, - &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, + h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + &tconf, &tclass); + + Tensor obj_weight; + obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + auto obj_weight_t = EigenTensor::From(obj_weight); + auto obj_mask_t = EigenTensor::From(obj_mask); + auto tweight_t = EigenTensor::From(tweight); + obj_weight_t = obj_mask_t * tweight_t; Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); - - T loss_x = CalcMSEWithMask(pred_x, tx, obj_mask); - T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); - T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); - T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); - T loss_conf_target = CalcBCEWithMask(pred_conf, tconf, obj_mask); - T loss_conf_notarget = CalcBCEWithMask(pred_conf, tconf, noobj_mask); - T loss_class = CalcBCEWithMask(pred_class, tclass, obj_mask_expand); + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); + + T box_f = static_cast(an_num * h * w); + T class_f = static_cast(an_num * h * w * class_num); + T loss_x = CalcSCEWithWeight(pred_x, tx, obj_weight, box_f); + T loss_y = CalcSCEWithWeight(pred_y, ty, obj_weight, box_f); + T loss_w = CalcMSEWithWeight(pred_w, tw, obj_weight, box_f); + T loss_h = CalcMSEWithWeight(pred_h, th, obj_weight, box_f); + T loss_conf_target = + CalcSCEWithWeight(pred_conf, tconf, obj_mask, box_f); + T loss_conf_notarget = + CalcSCEWithWeight(pred_conf, tconf, noobj_mask, box_f); + T loss_class = + CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, class_f); auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); loss_data[0] = loss_weight_xy * (loss_x + loss_y) + @@ -405,6 +423,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Loss")); const T loss = output_grad->data()[0]; + int input_size = ctx.Attr("input_size"); float loss_weight_xy = ctx.Attr("loss_weight_xy"); float loss_weight_wh = ctx.Attr("loss_weight_wh"); float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); @@ -430,22 +449,33 @@ class Yolov3LossGradKernel : public framework::OpKernel { &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; - Tensor tx, ty, tw, th, tconf, tclass; - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + Tensor tx, ty, tw, th, tweight, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, - &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, + h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + &tconf, &tclass); + + Tensor obj_weight; + obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + auto obj_weight_t = EigenTensor::From(obj_weight); + auto obj_mask_t = EigenTensor::From(obj_mask); + auto tweight_t = EigenTensor::From(tweight); + obj_weight_t = obj_mask_t * tweight_t; Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); Tensor grad_x, grad_y, grad_w, grad_h; Tensor grad_conf_target, grad_conf_notarget, grad_class; @@ -456,19 +486,18 @@ class Yolov3LossGradKernel : public framework::OpKernel { grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - T obj_mf = CalcMaskPointNum(obj_mask); - T noobj_mf = CalcMaskPointNum(noobj_mask); - T obj_expand_mf = CalcMaskPointNum(obj_mask_expand); - CalcMSEGradWithMask(&grad_x, pred_x, tx, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_y, pred_y, ty, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_w, pred_w, tw, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_h, pred_h, th, obj_mask, obj_mf); - CalcBCEGradWithMask(&grad_conf_target, pred_conf, tconf, obj_mask, - obj_mf); - CalcBCEGradWithMask(&grad_conf_notarget, pred_conf, tconf, noobj_mask, - noobj_mf); - CalcBCEGradWithMask(&grad_class, pred_class, tclass, obj_mask_expand, - obj_expand_mf); + T box_f = static_cast(an_num * h * w); + T class_f = static_cast(an_num * h * w * class_num); + CalcSCEGradWithWeight(&grad_x, pred_x, tx, obj_weight, box_f); + CalcSCEGradWithWeight(&grad_y, pred_y, ty, obj_weight, box_f); + CalcMSEGradWithWeight(&grad_w, pred_w, tw, obj_weight, box_f); + CalcMSEGradWithWeight(&grad_h, pred_h, th, obj_weight, box_f); + CalcSCEGradWithWeight(&grad_conf_target, pred_conf, tconf, obj_mask, + box_f); + CalcSCEGradWithWeight(&grad_conf_notarget, pred_conf, tconf, noobj_mask, + box_f); + CalcSCEGradWithWeight(&grad_class, pred_class, tclass, obj_mask_expand, + class_f); input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); AddAllGradToInputGrad( diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 7cf575d25..5fb4588e0 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -415,6 +415,7 @@ def yolov3_loss(x, anchors, class_num, ignore_thresh, + input_size, loss_weight_xy=None, loss_weight_wh=None, loss_weight_conf_target=None, @@ -436,6 +437,7 @@ def yolov3_loss(x, anchors (list|tuple): ${anchors_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} + input_size (int): ${input_size_comment} loss_weight_xy (float|None): ${loss_weight_xy_comment} loss_weight_wh (float|None): ${loss_weight_wh_comment} loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment} @@ -490,6 +492,7 @@ def yolov3_loss(x, "anchors": anchors, "class_num": class_num, "ignore_thresh": ignore_thresh, + "input_size": input_size, } if loss_weight_xy is not None and isinstance(loss_weight_xy, float): diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 8723d9842..7d7556290 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -464,7 +464,7 @@ class TestYoloDetection(unittest.TestCase): gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, - 0.5) + 0.7, 416) self.assertIsNotNone(loss) diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 544fe4b4f..07e7155bb 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -16,31 +16,22 @@ from __future__ import division import unittest import numpy as np +from scipy.special import logit +from scipy.special import expit from op_test import OpTest from paddle.fluid import core -def sigmoid(x): - return 1.0 / (1.0 + np.exp(-1.0 * x)) +def mse(x, y, weight, num): + return ((y - x)**2 * weight).sum() / num -def mse(x, y, num): - return ((y - x)**2).sum() / num - - -def bce(x, y, mask): - x = x.reshape((-1)) - y = y.reshape((-1)) - mask = mask.reshape((-1)) - - error_sum = 0.0 - count = 0 - for i in range(x.shape[0]): - if mask[i] > 0: - error_sum += y[i] * np.log(x[i]) + (1 - y[i]) * np.log(1 - x[i]) - count += 1 - return error_sum / (-1.0 * count) +def sce(x, label, weight, num): + sigmoid_x = expit(x) + term1 = label * np.log(sigmoid_x) + term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) + return ((-term1 - term2) * weight).sum() / num def box_iou(box1, box2): @@ -66,11 +57,12 @@ def box_iou(box1, box2): return inter_area / (b1_area + b2_area + inter_area) -def build_target(gtboxs, gtlabel, attrs, grid_size): - n, b, _ = gtboxs.shape +def build_target(gtboxes, gtlabel, attrs, grid_size): + n, b, _ = gtboxes.shape ignore_thresh = attrs["ignore_thresh"] anchors = attrs["anchors"] class_num = attrs["class_num"] + input_size = attrs["input_size"] an_num = len(anchors) // 2 obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') @@ -78,20 +70,21 @@ def build_target(gtboxs, gtlabel, attrs, grid_size): ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tweight = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tcls = np.zeros( (n, an_num, grid_size, grid_size, class_num)).astype('float32') for i in range(n): for j in range(b): - if gtboxs[i, j, :].sum() == 0: + if gtboxes[i, j, :].sum() == 0: continue gt_label = gtlabel[i, j] - gx = gtboxs[i, j, 0] * grid_size - gy = gtboxs[i, j, 1] * grid_size - gw = gtboxs[i, j, 2] * grid_size - gh = gtboxs[i, j, 3] * grid_size + gx = gtboxes[i, j, 0] * grid_size + gy = gtboxes[i, j, 1] * grid_size + gw = gtboxes[i, j, 2] * input_size + gh = gtboxes[i, j, 3] * input_size gi = int(gx) gj = int(gy) @@ -115,10 +108,12 @@ def build_target(gtboxs, gtlabel, attrs, grid_size): best_an_index]) th[i, best_an_index, gj, gi] = np.log( gh / anchors[2 * best_an_index + 1]) + tweight[i, best_an_index, gj, gi] = 2.0 - gtboxes[ + i, j, 2] * gtboxes[i, j, 3] tconf[i, best_an_index, gj, gi] = 1 tcls[i, best_an_index, gj, gi, gt_label] = 1 - return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask) + return (tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask) def YoloV3Loss(x, gtbox, gtlabel, attrs): @@ -126,27 +121,28 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): an_num = len(attrs['anchors']) // 2 class_num = attrs["class_num"] x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) - pred_x = sigmoid(x[:, :, :, :, 0]) - pred_y = sigmoid(x[:, :, :, :, 1]) + pred_x = x[:, :, :, :, 0] + pred_y = x[:, :, :, :, 1] pred_w = x[:, :, :, :, 2] pred_h = x[:, :, :, :, 3] - pred_conf = sigmoid(x[:, :, :, :, 4]) - pred_cls = sigmoid(x[:, :, :, :, 5:]) + pred_conf = x[:, :, :, :, 4] + pred_cls = x[:, :, :, :, 5:] - tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target( + tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target( gtbox, gtlabel, attrs, x.shape[2]) + obj_weight = obj_mask * tweight obj_mask_expand = np.tile( np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) - loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum()) - loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum()) - loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum()) - loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum()) - loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) - loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask, - noobj_mask) - loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, - obj_mask_expand) + box_f = an_num * h * w + class_f = an_num * h * w * class_num + loss_x = sce(pred_x, tx, obj_weight, box_f) + loss_y = sce(pred_y, ty, obj_weight, box_f) + loss_w = mse(pred_w, tw, obj_weight, box_f) + loss_h = mse(pred_h, th, obj_weight, box_f) + loss_conf_target = sce(pred_conf, tconf, obj_mask, box_f) + loss_conf_notarget = sce(pred_conf, tconf, noobj_mask, box_f) + loss_class = sce(pred_cls, tcls, obj_mask_expand, class_f) return attrs['loss_weight_xy'] * (loss_x + loss_y) \ + attrs['loss_weight_wh'] * (loss_w + loss_h) \ @@ -164,7 +160,7 @@ class TestYolov3LossOp(OpTest): self.loss_weight_class = 1.0 self.initTestCase() self.op_type = 'yolov3_loss' - x = np.random.random(size=self.x_shape).astype('float32') + x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]).astype('int32') @@ -173,6 +169,7 @@ class TestYolov3LossOp(OpTest): "anchors": self.anchors, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, + "input_size": self.input_size, "loss_weight_xy": self.loss_weight_xy, "loss_weight_wh": self.loss_weight_wh, "loss_weight_conf_target": self.loss_weight_conf_target, @@ -196,18 +193,19 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.06) + max_relative_error=0.3) def initTestCase(self): self.anchors = [10, 13, 12, 12] self.class_num = 10 - self.ignore_thresh = 0.5 + self.ignore_thresh = 0.7 + self.input_size = 416 self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) self.gtbox_shape = (5, 10, 4) - self.loss_weight_xy = 2.5 + self.loss_weight_xy = 1.4 self.loss_weight_wh = 0.8 - self.loss_weight_conf_target = 1.5 - self.loss_weight_conf_notarget = 0.5 + self.loss_weight_conf_target = 1.1 + self.loss_weight_conf_notarget = 0.9 self.loss_weight_class = 1.2 -- GitLab