From cc01db6029c84b5e059d355b95dd73d18894594f Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 28 Dec 2018 20:06:52 +0800 Subject: [PATCH] calc valid gt before loss calc. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 41 ++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 85d93cf96f..301e2f4033 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -219,6 +219,22 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, } } +template +static void inline GtValid(bool* valid, const T* gtbox, const int n, + const int b) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < b; j++) { + if (LessEqualZero(gtbox[j * 4 + 2]) || LessEqualZero(gtbox[j * 4 + 3])) { + valid[j] = false; + } else { + valid[j] = true; + } + } + valid += b; + gtbox += b * 4; + } +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -257,20 +273,28 @@ class Yolov3LossKernel : public framework::OpKernel { int* gt_match_mask_data = gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); + // calc valid gt box mask, avoid calc duplicately in following code + Tensor gt_valid_mask; + bool* gt_valid_mask_data = + gt_valid_mask.mutable_data({n, b}, ctx.GetPlace()); + GtValid(gt_valid_mask_data, gt_box_data, n, b); + for (int i = 0; i < n; i++) { for (int j = 0; j < mask_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { + // each predict box find a best match gt box, if overlap is bigger + // then ignore_thresh, ignore the objectness loss. int box_idx = GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], h, input_size, box_idx, stride); T best_iou = 0; for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + if (!gt_valid_mask_data[i * b + t]) { continue; } + Box gt = GetGtBox(gt_box_data, i, b, t); T iou = CalcBoxIoU(pred, gt); if (iou > best_iou) { best_iou = iou; @@ -281,15 +305,18 @@ class Yolov3LossKernel : public framework::OpKernel { int obj_idx = (i * mask_num + j) * stride + k * w + l; obj_mask_data[obj_idx] = -1; } + // TODO(dengkaipeng): all losses should be calculated if best IoU + // is bigger then truth thresh should be calculated here, but + // currently, truth thresh is an unreachable value as 1.0. } } } for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + if (!gt_valid_mask_data[i * b + t]) { gt_match_mask_data[i * b + t] = -1; continue; } + Box gt = GetGtBox(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); Box gt_shift = gt; @@ -297,6 +324,9 @@ class Yolov3LossKernel : public framework::OpKernel { gt_shift.y = 0.0; T best_iou = 0.0; int best_n = 0; + // each gt box find a best match anchor box as positive sample, + // for positive sample, all losses should be calculated, and for + // other samples, only objectness loss is required. for (int an_idx = 0; an_idx < an_num; an_idx++) { Box an_box; an_box.x = 0.0; @@ -304,7 +334,8 @@ class Yolov3LossKernel : public framework::OpKernel { an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); float iou = CalcBoxIoU(an_box, gt_shift); - // TO DO: iou > 0.5 ? + // TODO(dengkaipeng): In paper, objectness loss is ignore when + // best IoU > 0.5, but darknet code didn't implement this. if (iou > best_iou) { best_iou = iou; best_n = an_idx; -- GitLab