From 2fbfef2ec9683ac18903ca8cf7cb69c5389ba3ba Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 13 Dec 2018 19:15:52 +0800 Subject: [PATCH] fix no box expression. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 46617472618..d0064a81902 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -152,13 +152,10 @@ static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, const T* label_data = label.data(); const T* weight_data = weight.data(); - // LOG(ERROR) << "SCE grad start"; for (int i = 0; i < n; i++) { for (int j = 0; j < stride; j++) { grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) * weight_data[j] * loss_grad[i]; - // if (j == 18) LOG(ERROR) << x_data[j] << " " << label_data[j] << " " << - // weight_data[j] << " " << loss_grad[i]; } grad_data += stride; x_data += stride; @@ -258,8 +255,7 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { - if (isZero(gt_box_t(i, j, 0)) && isZero(gt_box_t(i, j, 1)) && - isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { + if (isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { continue; } @@ -425,12 +421,6 @@ class Yolov3LossKernel : public framework::OpKernel { loss_weight_conf_notarget, loss_data); CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, loss_weight_class, loss_data); - - // loss_data[0] = (loss_weight_xy * (loss_x + loss_y) + - // loss_weight_wh * (loss_w + loss_h) + - // loss_weight_conf_target * loss_conf_target + - // loss_weight_conf_notarget * loss_conf_notarget + - // loss_weight_class * loss_class) / n; } }; @@ -494,8 +484,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto tweight_t = EigenTensor::From(tweight); obj_weight_t = obj_mask_t * tweight_t; - // LOG(ERROR) << obj_mask_t; - Tensor obj_mask_expand; obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); -- GitLab