From 36c46152e140adab7e74eaeee9dbeccb65fc5633 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sun, 11 Nov 2018 23:52:36 +0800 Subject: [PATCH] Add unittest for yolov3_loss. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 25 +-- paddle/fluid/operators/yolov3_loss_op.h | 67 +++--- python/paddle/fluid/layers/nn.py | 28 +++ .../tests/unittests/test_yolov3_loss_op.py | 194 ++++++++++++++++++ 4 files changed, 273 insertions(+), 41 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 9ed7e13dc..7369ce31e 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -34,7 +34,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_gt = ctx->GetInputDim("GTBox"); auto img_height = ctx->Attrs().Get("img_height"); auto anchors = ctx->Attrs().Get>("anchors"); - auto box_num = ctx->Attrs().Get("box_num"); auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], @@ -50,8 +49,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, "Attr(anchors) length should be even integer."); - PADDLE_ENFORCE_GT(box_num, 0, - "Attr(box_num) should be an integer greater then 0."); PADDLE_ENFORCE_GT(class_num, 0, "Attr(class_num) should be an integer greater then 0."); @@ -73,23 +70,19 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input tensor of bilinear interpolation, " "This is a 4-D tensor with shape of [N, C, H, W]"); - AddInput( - "GTBox", - "The input tensor of ground truth boxes, " - "This is a 3-D tensor with shape of [N, max_box_num, 5 + class_num], " - "max_box_num is the max number of boxes in each image, " - "class_num is the number of classes in data set. " - "In the third dimention, stores x, y, w, h, confidence, classes " - "one-hot key. " - "x, y is the center cordinate of boxes and w, h is the width and " - "height, " - "and all of them should be divided by input image height to scale to " - "[0, 1]."); + AddInput("GTBox", + "The input tensor of ground truth boxes, " + "This is a 3-D tensor with shape of [N, max_box_num, 5], " + "max_box_num is the max number of boxes in each image, " + "In the third dimention, stores label, x, y, w, h, " + "label is an integer to specify box class, x, y is the " + "center cordinate of boxes and w, h is the width and height" + "and x, y, w, h should be divided by input image height to " + "scale to [0, 1]."); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [1]"); - AddAttr("box_num", "The number of boxes generated in each grid."); AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", "The anchor width and height, " diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index a796a5780..426e0688a 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -25,8 +25,7 @@ template using EigenVector = framework::EigenVector; -using Array2 = Eigen::DSizes; -using Array4 = Eigen::DSizes; +using Array5 = Eigen::DSizes; template static inline bool isZero(T x) { @@ -43,7 +42,7 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, const Tensor& mask) { auto x_t = EigenVector::Flatten(x); auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); + auto mask_t = EigenVector::Flatten(mask); T error_sum = 0.0; T points = 0.0; @@ -61,7 +60,7 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, const Tensor& mask) { auto x_t = EigenVector::Flatten(x); auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); + auto mask_t = EigenVector::Flatten(mask); T error_sum = 0.0; T points = 0.0; @@ -89,7 +88,6 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_confs, const int box_attr_num = 5 + class_num; auto input_t = EigenTensor::From(input); - // auto pred_boxes_t = EigenTensor::From(*pred_boxes); auto pred_confs_t = EigenTensor::From(*pred_confs); auto pred_classes_t = EigenTensor::From(*pred_classes); auto pred_x_t = EigenTensor::From(*pred_x); @@ -113,13 +111,6 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_confs, pred_h_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 3, j, k); - // pred_boxes_t(i, an_idx, j, k, 0) = pred_x_t(i, an_idx, j, k) + k; - // pred_boxes_t(i, an_idx, j, k, 1) = pred_y_t(i, an_idx, j, k) + j; - // pred_boxes_t(i, an_idx, j, k, 2) = - // exp(pred_w_t(i, an_idx, j, k)) * an_w; - // pred_boxes_t(i, an_idx, j, k, 3) = - // exp(pred_h_t(i, an_idx, j, k)) * an_h; - pred_confs_t(i, an_idx, j, k) = sigmod(input_t(i, box_attr_num * an_idx + 4, j, k)); @@ -199,7 +190,7 @@ static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, continue; } - int gt_label = gt_boxes_t(i, j, 0); + int gt_label = static_cast(gt_boxes_t(i, j, 0)); T gx = gt_boxes_t(i, j, 1) * grid_size; T gy = gt_boxes_t(i, j, 2) * grid_size; T gw = gt_boxes_t(i, j, 3) * grid_size; @@ -207,7 +198,7 @@ static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, int gi = static_cast(gx); int gj = static_cast(gy); - T max_iou = static_cast(-1); + T max_iou = static_cast(0); T iou; int best_an_index = -1; std::vector gt_box({0, 0, gw, gh}); @@ -220,20 +211,33 @@ static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, best_an_index = an_idx; } if (iou > ignore_thresh) { - noobj_mask_t(b, an_idx, gj, gi) = 0; + noobj_mask_t(i, an_idx, gj, gi) = 0; } } - obj_mask_t(b, best_an_index, gj, gi) = 1; - noobj_mask_t(b, best_an_index, gj, gi) = 1; + obj_mask_t(i, best_an_index, gj, gi) = 1; + noobj_mask_t(i, best_an_index, gj, gi) = 0; tx_t(i, best_an_index, gj, gi) = gx - gi; ty_t(i, best_an_index, gj, gi) = gy - gj; tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); - tclass_t(b, best_an_index, gj, gi, gt_label) = 1; - tconf_t(b, best_an_index, gj, gi) = 1; + tclass_t(i, best_an_index, gj, gi, gt_label) = 1; + tconf_t(i, best_an_index, gj, gi) = 1; } } - noobj_mask_t = noobj_mask_t - obj_mask_t; +} + +static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand, + const Tensor& obj_mask) { + const int n = obj_mask_expand->dims()[0]; + const int an_num = obj_mask_expand->dims()[1]; + const int h = obj_mask_expand->dims()[2]; + const int w = obj_mask_expand->dims()[3]; + const int class_num = obj_mask_expand->dims()[4]; + auto obj_mask_expand_t = EigenTensor::From(*obj_mask_expand); + auto obj_mask_t = EigenTensor::From(obj_mask); + + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); } template @@ -280,17 +284,30 @@ class Yolov3LossKernel : public framework::OpKernel { &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + Tensor obj_mask_expand; + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + T loss_x = CalcMSEWithMask(pred_x, tx, obj_mask); T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); - T loss_conf_true = CalcBCEWithMask(pred_confs, tconf, obj_mask); - T loss_conf_false = CalcBCEWithMask(pred_confs, tconf, noobj_mask); - T loss_class = CalcBCEWithMask(pred_classes, tclass, obj_mask); + T loss_conf_obj = CalcBCEWithMask(pred_confs, tconf, obj_mask); + T loss_conf_noobj = CalcBCEWithMask(pred_confs, tconf, noobj_mask); + T loss_class = CalcBCEWithMask(pred_classes, tclass, obj_mask_expand); + + // LOG(ERROR) << "loss_x: " << loss_x; + // LOG(ERROR) << "loss_y: " << loss_y; + // LOG(ERROR) << "loss_w: " << loss_w; + // LOG(ERROR) << "loss_h: " << loss_h; + // LOG(ERROR) << "loss_conf_obj: " << loss_conf_obj; + // LOG(ERROR) << "loss_conf_noobj: " << loss_conf_noobj; + // LOG(ERROR) << "loss_class: " << loss_class; auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); - loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_true + - loss_conf_false + loss_class; + loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_obj + + loss_conf_noobj + loss_class; } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d3623464e..1ee7198f2 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -164,6 +164,7 @@ __all__ = [ 'hash', 'grid_sampler', 'log_loss', + 'yolov3_loss', 'add_position_encoding', 'bilinear_tensor_product', ] @@ -8243,6 +8244,33 @@ def log_loss(input, label, epsilon=1e-4, name=None): return loss +def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None): + """ + **YOLOv3 Loss Layer** + + This layer + """ + helper = LayerHelper('yolov3_loss', **locals()) + + if name is None: + loss = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + loss = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type='yolov3_loss', + inputs={'X': x, + "GTBox": gtbox}, + outputs={'Loss': loss}, + attrs={ + "img_height": img_height, + "anchors": anchors, + "ignore_thresh": ignore_thresh, + }) + return loss + + def add_position_encoding(input, alpha, beta, name=None): """ **Add Position Encoding Layer** diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py new file mode 100644 index 000000000..f5b15efb2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -0,0 +1,194 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-1.0 * x)) + + +def mse(x, y, num): + return ((y - x)**2).sum() / num + + +def bce(x, y, mask): + x = x.reshape((-1)) + y = y.reshape((-1)) + mask = mask.reshape((-1)) + + error_sum = 0.0 + count = 0 + for i in range(x.shape[0]): + if mask[i] > 0: + error_sum += y[i] * np.log(x[i]) + (1 - y[i]) * np.log(1 - x[i]) + count += 1 + return error_sum / (-1.0 * count) + + +def box_iou(box1, box2): + b1_x1 = box1[0] - box1[2] / 2 + b1_x2 = box1[0] + box1[2] / 2 + b1_y1 = box1[1] - box1[3] / 2 + b1_y2 = box1[1] + box1[3] / 2 + b2_x1 = box2[0] - box2[2] / 2 + b2_x2 = box2[0] + box2[2] / 2 + b2_y1 = box2[1] - box2[3] / 2 + b2_y2 = box2[1] + box2[3] / 2 + + b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + + inter_rect_x1 = max(b1_x1, b2_x1) + inter_rect_y1 = max(b1_y1, b2_y1) + inter_rect_x2 = min(b1_x2, b2_x2) + inter_rect_y2 = min(b1_y2, b2_y2) + inter_area = max(inter_rect_x2 - inter_rect_x1, 0) * max( + inter_rect_y2 - inter_rect_y1, 0) + + return inter_area / (b1_area + b2_area + inter_area) + + +def build_target(gtboxs, attrs, grid_size): + n, b, _ = gtboxs.shape + ignore_thresh = attrs["ignore_thresh"] + img_height = attrs["img_height"] + anchors = attrs["anchors"] + class_num = attrs["class_num"] + an_num = len(anchors) / 2 + obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') + tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tcls = np.zeros( + (n, an_num, grid_size, grid_size, class_num)).astype('float32') + + for i in range(n): + for j in range(b): + if gtboxs[i, j, :].sum() == 0: + continue + + gt_label = int(gtboxs[i, j, 0]) + gx = gtboxs[i, j, 1] * grid_size + gy = gtboxs[i, j, 2] * grid_size + gw = gtboxs[i, j, 3] * grid_size + gh = gtboxs[i, j, 4] * grid_size + + gi = int(gx) + gj = int(gy) + + gtbox = [0, 0, gw, gh] + max_iou = 0 + for k in range(an_num): + anchor_box = [0, 0, anchors[2 * k], anchors[2 * k + 1]] + iou = box_iou(gtbox, anchor_box) + if iou > max_iou: + max_iou = iou + best_an_index = k + if iou > ignore_thresh: + noobj_mask[i, best_an_index, gj, gi] = 0 + + obj_mask[i, best_an_index, gj, gi] = 1 + noobj_mask[i, best_an_index, gj, gi] = 0 + tx[i, best_an_index, gj, gi] = gx - gi + ty[i, best_an_index, gj, gi] = gy - gj + tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * + best_an_index]) + th[i, best_an_index, gj, gi] = np.log( + gh / anchors[2 * best_an_index + 1]) + tconf[i, best_an_index, gj, gi] = 1 + tcls[i, best_an_index, gj, gi, gt_label] = 1 + + return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask) + + +def YoloV3Loss(x, gtbox, attrs): + n, c, h, w = x.shape + an_num = len(attrs['anchors']) / 2 + class_num = attrs["class_num"] + x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + pred_x = sigmoid(x[:, :, :, :, 0]) + pred_y = sigmoid(x[:, :, :, :, 1]) + pred_w = x[:, :, :, :, 2] + pred_h = x[:, :, :, :, 3] + pred_conf = sigmoid(x[:, :, :, :, 4]) + pred_cls = sigmoid(x[:, :, :, :, 5:]) + + tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target( + gtbox, attrs, x.shape[2]) + + obj_mask_expand = np.tile( + np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) + loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum()) + loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum()) + loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum()) + loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum()) + loss_conf_obj = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) + loss_conf_noobj = bce(pred_conf * noobj_mask, tconf * noobj_mask, + noobj_mask) + loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, + obj_mask_expand) + # print "loss_x: ", loss_x + # print "loss_y: ", loss_y + # print "loss_w: ", loss_w + # print "loss_h: ", loss_h + # print "loss_conf_obj: ", loss_conf_obj + # print "loss_conf_noobj: ", loss_conf_noobj + # print "loss_class: ", loss_class + + return loss_x + loss_y + loss_w + loss_h + loss_conf_obj + loss_conf_noobj + loss_class + + +class TestYolov3LossOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'yolov3_loss' + x = np.random.random(size=self.x_shape).astype('float32') + gtbox = np.random.random(size=self.gtbox_shape).astype('float32') + gtbox[:, :, 0] = np.random.randint(0, self.class_num, + self.gtbox_shape[:2]) + + self.attrs = { + "img_height": self.img_height, + "anchors": self.anchors, + "class_num": self.class_num, + "ignore_thresh": self.ignore_thresh, + } + + self.inputs = {'X': x, 'GTBox': gtbox} + self.outputs = {'Loss': np.array([YoloV3Loss(x, gtbox, self.attrs)])} + print self.outputs + + def test_check_output(self): + self.check_output(atol=1e-3) + + # def test_check_grad_normal(self): + # self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61) + + def initTestCase(self): + self.img_height = 608 + self.anchors = [10, 13, 16, 30, 33, 23] + self.class_num = 10 + self.ignore_thresh = 0.5 + self.x_shape = (5, len(self.anchors) / 2 * (5 + self.class_num), 7, 7) + self.gtbox_shape = (5, 10, 5) + + +if __name__ == "__main__": + unittest.main() -- GitLab