From 8218e30176c6bdaccd11cd0141c6f47878233b54 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 4 Jan 2019 11:40:08 +0800 Subject: [PATCH] add gtscore. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 20 +++++++++++++++-- paddle/fluid/operators/yolov3_loss_op.h | 22 ++++++++++++------- python/paddle/fluid/layers/detection.py | 17 ++++++++++---- .../tests/unittests/test_yolov3_loss_op.py | 19 +++++++++------- 5 files changed, 57 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6c6ac9c7ea..bf0916a076 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 5b777f0448..c146035f9d 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -27,6 +27,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("GTLabel"), "Input(GTLabel) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("GTScore"), + "Input(GTScore) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); PADDLE_ENFORCE( @@ -38,6 +40,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_x = ctx->GetInputDim("X"); auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtlabel = ctx->GetInputDim("GTLabel"); + auto dim_gtscore = ctx->GetInputDim("GTScore"); auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); @@ -54,11 +57,17 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) should be a 3-D tensor"); PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5"); PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2, - "Input(GTBox) should be a 2-D tensor"); + "Input(GTLabel) should be a 2-D tensor"); PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0], "Input(GTBox) and Input(GTLabel) dim[0] should be same"); PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1], "Input(GTBox) and Input(GTLabel) dim[1] should be same"); + PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2, + "Input(GTScore) should be a 2-D tensor"); + PADDLE_ENFORCE_EQ(dim_gtscore[0], dim_gtbox[0], + "Input(GTBox) and Input(GTScore) dim[0] should be same"); + PADDLE_ENFORCE_EQ(dim_gtscore[1], dim_gtbox[1], + "Input(GTBox) and Input(GTScore) dim[1] should be same"); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -109,8 +118,13 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("GTLabel", "The input tensor of ground truth label, " "This is a 2-D tensor with shape of [N, max_box_num], " - "and each element shoudl be an integer to indicate the " + "and each element should be an integer to indicate the " "box class id."); + AddInput("GTScore", + "The score of GTLabel, This is a 2-D tensor in same shape " + "GTLabel, and score values should in range (0, 1). This " + "input is for GTLabel score can be not 1.0 in image mixup " + "augmentation."); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); @@ -228,6 +242,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("X", Input("X")); op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); + op->SetInput("GTScore", Input("GTScore")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); op->SetInput("ObjectnessMask", Output("ObjectnessMask")); op->SetInput("GTMatchMask", Output("GTMatchMask")); @@ -237,6 +252,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("GTBox"), {}); op->SetOutput(framework::GradVarName("GTLabel"), {}); + op->SetOutput(framework::GradVarName("GTScore"), {}); return std::unique_ptr(op); } }; diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 34119b1a02..c4095b8ca5 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -156,25 +156,25 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, - const int label, const int class_num, - const int stride) { + const int label, const T score, + const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] : 1.0 / class_num; - loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); + loss[0] += SCE(pred, (i == label) ? score : 0.0); } } template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, - const int label, const int class_num, - const int stride) { + const int label, const T score, + const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] : 1.0 / class_num; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; + SCEGrad(pred, (i == label) ? score : 0.0) * loss; } } @@ -246,6 +246,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* loss = ctx.Output("Loss"); auto* objness_mask = ctx.Output("ObjectnessMask"); auto* gt_match_mask = ctx.Output("GTMatchMask"); @@ -269,6 +270,7 @@ class Yolov3LossKernel : public framework::OpKernel { const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); + const T* gt_score_data = gt_score->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); int* obj_mask_data = @@ -358,9 +360,10 @@ class Yolov3LossKernel : public framework::OpKernel { obj_mask_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; + T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); - CalcLabelLoss(loss_data + i, input_data, label_idx, label, + CalcLabelLoss(loss_data + i, input_data, label_idx, label, score, class_num, stride); } } @@ -378,6 +381,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -401,6 +405,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); + const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); const int* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); @@ -423,10 +428,11 @@ class Yolov3LossGradKernel : public framework::OpKernel { anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; + T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, class_num, stride); + label_idx, label, score, class_num, stride); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 90d112aa01..10573cc4c6 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -412,6 +412,7 @@ def polygon_box_transform(input, name=None): def yolov3_loss(x, gtbox, gtlabel, + gtscore, anchors, anchor_mask, class_num, @@ -428,8 +429,10 @@ def yolov3_loss(x, and x, y, w, h should be relative value of input image. N is the batch number and B is the max box number in an image. - gtlabel (Variable): class id of ground truth boxes, shoud be ins shape + gtlabel (Variable): class id of ground truth boxes, shoud be in shape of [N, B]. + gtscore (Variable): score of gtlabel, should be in same shape with gtlabel + and score value in range (0, 1). anchors (list|tuple): ${anchors_comment} anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} @@ -444,6 +447,7 @@ def yolov3_loss(x, TypeError: Input x of yolov3_loss must be Variable TypeError: Input gtbox of yolov3_loss must be Variable" TypeError: Input gtlabel of yolov3_loss must be Variable" + TypeError: Input gtscore of yolov3_loss must be Variable" TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number @@ -467,6 +471,8 @@ def yolov3_loss(x, raise TypeError("Input gtbox of yolov3_loss must be Variable") if not isinstance(gtlabel, Variable): raise TypeError("Input gtlabel of yolov3_loss must be Variable") + if not isinstance(gtscore, Variable): + raise TypeError("Input gtscore of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): @@ -496,9 +502,12 @@ def yolov3_loss(x, helper.append_op( type='yolov3_loss', - inputs={"X": x, - "GTBox": gtbox, - "GTLabel": gtlabel}, + inputs={ + "X": x, + "GTBox": gtbox, + "GTLabel": gtlabel, + "GTScore": gtscore + }, outputs={ 'Loss': loss, 'ObjectnessMask': objectness_mask, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 27fb92c589..c65570d7c1 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -66,7 +66,7 @@ def batch_xywh_box_iou(box1, box2): return inter_area / union -def YOLOv3Loss(x, gtbox, gtlabel, attrs): +def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): n, c, h, w = x.shape b = gtbox.shape[1] anchors = attrs['anchors'] @@ -148,7 +148,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): for label_idx in range(class_num): loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], - int(label_idx == gtlabel[i, j])) + int(label_idx == gtlabel[i, j]) * gtscore[i, j]) for j in range(mask_num * h * w): if objness[i, j] >= 0: @@ -165,6 +165,7 @@ class TestYolov3LossOp(OpTest): x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]) + gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32') gtmask = np.random.randint(0, 2, self.gtbox_shape[:2]) gtbox = gtbox * gtmask[:, :, np.newaxis] gtlabel = gtlabel * gtmask @@ -180,9 +181,11 @@ class TestYolov3LossOp(OpTest): self.inputs = { 'X': x, 'GTBox': gtbox.astype('float32'), - 'GTLabel': gtlabel.astype('int32') + 'GTLabel': gtlabel.astype('int32'), + 'GTScore': gtscore.astype('float32') } - loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore, + self.attrs) self.outputs = { 'Loss': loss, 'ObjectnessMask': objness, @@ -198,8 +201,8 @@ class TestYolov3LossOp(OpTest): self.check_grad_with_place( place, ['X'], 'Loss', - no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.15) + no_grad_set=set(["GTBox", "GTLabel", "GTScore"]), + max_relative_error=0.2) def initTestCase(self): self.anchors = [ @@ -207,11 +210,11 @@ class TestYolov3LossOp(OpTest): 373, 326 ] self.anchor_mask = [0, 1, 2] - self.class_num = 5 + self.class_num = 10 self.ignore_thresh = 0.7 self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 5, 4) + self.gtbox_shape = (3, 10, 4) if __name__ == "__main__": -- GitLab