diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 8c46e341d625098349000ea2aff6af004841fe9d..5b777f0448d3ddf286609fa635371bc6b1955739 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -29,6 +29,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTLabel) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("ObjectnessMask"), + "Output(ObjectnessMask) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("GTMatchMask"), + "Output(GTMatchMask) of Yolov3LossOp should not be null."); auto dim_x = ctx->GetInputDim("X"); auto dim_gtbox = ctx->GetInputDim("GTBox"); @@ -68,6 +73,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel { std::vector dim_out({dim_x[0]}); ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); + + std::vector dim_obj_mask({dim_x[0], mask_num, dim_x[2], dim_x[3]}); + ctx->SetOutputDim("ObjectnessMask", framework::make_ddim(dim_obj_mask)); + + std::vector dim_gt_match_mask({dim_gtbox[0], dim_gtbox[1]}); + ctx->SetOutputDim("GTMatchMask", framework::make_ddim(dim_gt_match_mask)); } protected: @@ -103,6 +114,16 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); + AddOutput("ObjectnessMask", + "This is an intermediate tensor with shape of [N, M, H, W], " + "M is the number of anchor masks. This parameter caches the " + "mask for calculate objectness loss in gradient kernel.") + .AsIntermediate(); + AddOutput("GTMatchMask", + "This is an intermediate tensor with shape if [N, B], " + "B is the max box number of GT boxes. This parameter caches " + "matched mask index of each GT boxes for gradient calculate.") + .AsIntermediate(); AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", @@ -208,6 +229,8 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + op->SetInput("ObjectnessMask", Output("ObjectnessMask")); + op->SetInput("GTMatchMask", Output("GTMatchMask")); op->SetAttrMap(Attrs()); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 12499befcacd0e6f8251192dcd34e96211bfb8b3..85d93cf96f9d4daca482bd7848b0cc8f7280f1de 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -227,6 +227,8 @@ class Yolov3LossKernel : public framework::OpKernel { auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); auto* loss = ctx.Output("Loss"); + auto* objness_mask = ctx.Output("ObjectnessMask"); + auto* gt_match_mask = ctx.Output("GTMatchMask"); auto anchors = ctx.Attr>("anchors"); auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); @@ -241,19 +243,19 @@ class Yolov3LossKernel : public framework::OpKernel { const int b = gt_box->dims()[1]; int input_size = downsample * h; + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); - - Tensor objness; - int* objness_data = - objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); - memset(objness_data, 0, objness.numel() * sizeof(int)); - - const int stride = h * w; - const int an_stride = (class_num + 5) * stride; + int* obj_mask_data = + objness_mask->mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(obj_mask_data, 0, objness_mask->numel() * sizeof(int)); + int* gt_match_mask_data = + gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); for (int i = 0; i < n; i++) { for (int j = 0; j < mask_num; j++) { @@ -277,7 +279,7 @@ class Yolov3LossKernel : public framework::OpKernel { if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; - objness_data[obj_idx] = -1; + obj_mask_data[obj_idx] = -1; } } } @@ -285,6 +287,7 @@ class Yolov3LossKernel : public framework::OpKernel { for (int t = 0; t < b; t++) { Box gt = GetGtBox(gt_box_data, i, b, t); if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + gt_match_mask_data[i * b + t] = -1; continue; } int gi = static_cast(gt.x * w); @@ -309,6 +312,7 @@ class Yolov3LossKernel : public framework::OpKernel { } int mask_idx = GetMaskIndex(anchor_mask, best_n); + gt_match_mask_data[i * b + t] = mask_idx; if (mask_idx >= 0) { int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); @@ -316,7 +320,7 @@ class Yolov3LossKernel : public framework::OpKernel { box_idx, gi, gj, h, input_size, stride); int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - objness_data[obj_idx] = 1; + obj_mask_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, @@ -327,7 +331,7 @@ class Yolov3LossKernel : public framework::OpKernel { } } - CalcObjnessLoss(loss_data, input_data + 4 * stride, objness_data, n, + CalcObjnessLoss(loss_data, input_data + 4 * stride, obj_mask_data, n, mask_num, h, w, stride, an_stride); } }; @@ -341,64 +345,35 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* gt_label = ctx.Input("GTLabel"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); + auto* objness_mask = ctx.Input("ObjectnessMask"); + auto* gt_match_mask = ctx.Input("GTMatchMask"); auto anchors = ctx.Attr>("anchors"); auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); - float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample = ctx.Attr("downsample"); - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - const int an_num = anchors.size() / 2; + const int n = input_grad->dims()[0]; + const int c = input_grad->dims()[1]; + const int h = input_grad->dims()[2]; + const int w = input_grad->dims()[3]; const int mask_num = anchor_mask.size(); - const int b = gt_box->dims()[1]; + const int b = gt_match_mask->dims()[1]; int input_size = downsample * h; + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); const T* loss_grad_data = loss_grad->data(); + const int* obj_mask_data = objness_mask->data(); + const int* gt_match_mask_data = gt_match_mask->data(); T* input_grad_data = input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); - Tensor objness; - int* objness_data = - objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); - memset(objness_data, 0, objness.numel() * sizeof(int)); - - const int stride = h * w; - const int an_stride = (class_num + 5) * stride; - for (int i = 0; i < n; i++) { - for (int j = 0; j < mask_num; j++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - int box_idx = - GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); - Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], - h, input_size, box_idx, stride); - T best_iou = 0; - for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { - continue; - } - T iou = CalcBoxIoU(pred, gt); - if (iou > best_iou) { - best_iou = iou; - } - } - - if (best_iou > ignore_thresh) { - int obj_idx = (i * mask_num + j) * stride + k * w + l; - objness_data[obj_idx] = -1; - } - } - } - } for (int t = 0; t < b; t++) { Box gt = GetGtBox(gt_box_data, i, b, t); if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { @@ -406,35 +381,14 @@ class Yolov3LossGradKernel : public framework::OpKernel { } int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); - Box gt_shift = gt; - gt_shift.x = 0.0; - gt_shift.y = 0.0; - T best_iou = 0.0; - int best_n = 0; - for (int an_idx = 0; an_idx < an_num; an_idx++) { - Box an_box; - an_box.x = 0.0; - an_box.y = 0.0; - an_box.w = anchors[2 * an_idx] / static_cast(input_size); - an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); - float iou = CalcBoxIoU(an_box, gt_shift); - // TO DO: iou > 0.5 ? - if (iou > best_iou) { - best_iou = iou; - best_n = an_idx; - } - } - int mask_idx = GetMaskIndex(anchor_mask, best_n); + int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); - CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], - input_data, gt, anchors, best_n, box_idx, - gi, gj, h, input_size, stride); - - int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - objness_data[obj_idx] = 1; + CalcBoxLocationLossGrad( + input_grad_data, loss_grad_data[i], input_data, gt, anchors, + anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, @@ -446,7 +400,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { } CalcObjnessLossGrad(input_grad_data + 4 * stride, loss_grad_data, - input_data + 4 * stride, objness_data, n, mask_num, + input_data + 4 * stride, obj_mask_data, n, mask_num, h, w, stride, an_stride); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 542162b7f41b8c116625e7956c2a64d7711f85ea..90d112aa014c23d3c2dc1d98c80c96118628137a 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -483,6 +483,9 @@ def yolov3_loss(x, loss = helper.create_variable( name=name, dtype=x.dtype, persistable=False) + objectness_mask = helper.create_variable_for_type_inference(dtype='int32') + gt_match_mask = helper.create_variable_for_type_inference(dtype='int32') + attrs = { "anchors": anchors, "anchor_mask": anchor_mask, @@ -496,7 +499,11 @@ def yolov3_loss(x, inputs={"X": x, "GTBox": gtbox, "GTLabel": gtlabel}, - outputs={'Loss': loss}, + outputs={ + 'Loss': loss, + 'ObjectnessMask': objectness_mask, + 'GTMatchMask': gt_match_mask + }, attrs=attrs) return loss diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 188acea2b93fd8b2cc28d52389555c2b17f7e36d..904bee00c1284245b848855538d66b5cfb5dd0c2 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -116,13 +116,17 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1)) ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes) iou_matches = np.argmax(ious, axis=-1) + gt_matches = iou_matches.copy() for i in range(n): for j in range(b): if gtbox[i, j, 2:].sum() == 0: + gt_matches[i, j] = -1 continue if iou_matches[i, j] not in anchor_mask: + gt_matches[i, j] = -1 continue an_idx = anchor_mask.index(iou_matches[i, j]) + gt_matches[i, j] = an_idx gi = int(gtbox[i, j, 0] * w) gj = int(gtbox[i, j, 1] * h) @@ -146,7 +150,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): if objness[i, j] >= 0: loss[i] += sce(pred_obj[i, j], objness[i, j]) - return loss + return (loss, objness.reshape((n, mask_num, h, w)).astype('int32'), \ + gt_matches.astype('int32')) class TestYolov3LossOp(OpTest): @@ -173,11 +178,16 @@ class TestYolov3LossOp(OpTest): 'GTBox': gtbox.astype('float32'), 'GTLabel': gtlabel.astype('int32') } - self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) + self.outputs = { + 'Loss': loss, + 'ObjectnessMask': objness, + "GTMatchMask": gt_matches + } def test_check_output(self): place = core.CPUPlace() - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place, atol=2e-3) def test_check_grad_ignore_gtbox(self): place = core.CPUPlace()