diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 4f286049d40af988e83b0d0bc24a7b61ee90444f..14996e3d86a0908cbef3aa0b8b378abb4be7bed2 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -299,6 +299,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7597f732430038a4a180297e730340d1bc47b8c --- /dev/null +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -0,0 +1,221 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/yolov3_loss_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class Yolov3LossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("GTBox"), + "Input(GTBox) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("GTLabel"), + "Input(GTLabel) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Loss"), + "Output(Loss) of Yolov3LossOp should not be null."); + + auto dim_x = ctx->GetInputDim("X"); + auto dim_gtbox = ctx->GetInputDim("GTBox"); + auto dim_gtlabel = ctx->GetInputDim("GTLabel"); + auto anchors = ctx->Attrs().Get>("anchors"); + auto class_num = ctx->Attrs().Get("class_num"); + PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); + PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], + "Input(X) dim[3] and dim[4] should be euqal."); + PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num), + "Input(X) dim[1] should be equal to (anchor_number * (5 " + "+ class_num))."); + PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3, + "Input(GTBox) should be a 3-D tensor"); + PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5"); + PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2, + "Input(GTBox) should be a 2-D tensor"); + PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0], + "Input(GTBox) and Input(GTLabel) dim[0] should be same"); + PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1], + "Input(GTBox) and Input(GTLabel) dim[1] should be same"); + PADDLE_ENFORCE_GT(anchors.size(), 0, + "Attr(anchors) length should be greater then 0."); + PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, + "Attr(anchors) length should be even integer."); + PADDLE_ENFORCE_GT(class_num, 0, + "Attr(class_num) should be an integer greater then 0."); + + std::vector dim_out({1}); + ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } +}; + +class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of YOLO v3 loss operator, " + "This is a 4-D tensor with shape of [N, C, H, W]." + "H and W should be same, and the second dimention(C) stores" + "box locations, confidence score and classification one-hot" + "key of each anchor box"); + AddInput("GTBox", + "The input tensor of ground truth boxes, " + "This is a 3-D tensor with shape of [N, max_box_num, 5], " + "max_box_num is the max number of boxes in each image, " + "In the third dimention, stores x, y, w, h coordinates, " + "x, y is the center cordinate of boxes and w, h is the " + "width and height and x, y, w, h should be divided by " + "input image height to scale to [0, 1]."); + AddInput("GTLabel", + "The input tensor of ground truth label, " + "This is a 2-D tensor with shape of [N, max_box_num], " + "and each element shoudl be an integer to indicate the " + "box class id."); + AddOutput("Loss", + "The output yolov3 loss tensor, " + "This is a 1-D tensor with shape of [1]"); + + AddAttr("class_num", "The number of classes to predict."); + AddAttr>("anchors", + "The anchor width and height, " + "it will be parsed pair by pair."); + AddAttr("ignore_thresh", + "The ignore threshold to ignore confidence loss."); + AddAttr("loss_weight_xy", "The weight of x, y location loss.") + .SetDefault(1.0); + AddAttr("loss_weight_wh", "The weight of w, h location loss.") + .SetDefault(1.0); + AddAttr( + "loss_weight_conf_target", + "The weight of confidence score loss in locations with target object.") + .SetDefault(1.0); + AddAttr("loss_weight_conf_notarget", + "The weight of confidence score loss in locations without " + "target object.") + .SetDefault(1.0); + AddAttr("loss_weight_class", "The weight of classification loss.") + .SetDefault(1.0); + AddComment(R"DOC( + This operator generate yolov3 loss by given predict result and ground + truth boxes. + + The output of previous network is in shape [N, C, H, W], while H and W + should be the same, specify the grid size, each grid point predict given + number boxes, this given number is specified by anchors, it should be + half anchors length, which following will be represented as S. In the + second dimention(the channel dimention), C should be S * (class_num + 5), + class_num is the box categoriy number of source dataset(such as coco), + so in the second dimention, stores 4 box location coordinates x, y, w, h + and confidence score of the box and class one-hot key of each anchor box. + + While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions + correspnd to: + + $$ + b_x = \sigma(t_x) + c_x + b_y = \sigma(t_y) + c_y + b_w = p_w e^{t_w} + b_h = p_h e^{t_h} + $$ + + While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$ + is specified by anchors. + + As for confidence score, it is the logistic regression value of IoU between + anchor boxes and ground truth boxes, the score of the anchor box which has + the max IoU should be 1, and if the anchor box has IoU bigger then ignore + thresh, the confidence score loss of this anchor box will be ignored. + + Therefore, the yolov3 loss consist of three major parts, box location loss, + confidence score loss, and classification loss. The MSE loss is used for + box location, and binary cross entropy loss is used for confidence score + loss and classification loss. + + Final loss will be represented as follow. + + $$ + loss = \loss_weight_{xy} * loss_{xy} + \loss_weight_{wh} * loss_{wh} + + \loss_weight_{conf_target} * loss_{conf_target} + + \loss_weight_{conf_notarget} * loss_{conf_notarget} + + \loss_weight_{class} * loss_{class} + $$ + )DOC"); + } +}; + +class Yolov3LossOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), + "Input(Loss@GRAD) should not be null"); + auto dim_x = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), dim_x); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } +}; + +class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("yolov3_loss_grad"); + op->SetInput("X", Input("X")); + op->SetInput("GTBox", Input("GTBox")); + op->SetInput("GTLabel", Input("GTLabel")); + op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("GTBox"), {}); + op->SetOutput(framework::GradVarName("GTLabel"), {}); + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, + ops::Yolov3LossGradMaker); +REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); +REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel, + ops::Yolov3LossKernel); +REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel, + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0bb285722ddedf721d98237760ec9868e2134442 --- /dev/null +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -0,0 +1,483 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenTensor = framework::EigenTensor; +template +using EigenVector = framework::EigenVector; + +using Array5 = Eigen::DSizes; + +template +static inline bool isZero(T x) { + return fabs(x) < 1e-6; +} + +template +static inline T sigmoid(T x) { + return 1.0 / (exp(-1.0 * x) + 1.0); +} + +template +static inline T CalcMaskPointNum(const Tensor& mask) { + auto mask_t = EigenVector::Flatten(mask); + T count = 0.0; + for (int i = 0; i < mask_t.dimensions()[0]; i++) { + if (mask_t(i)) { + count += 1.0; + } + } + return count; +} + +template +static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, + const Tensor& mask) { + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + T error_sum = 0.0; + T points = 0.0; + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + error_sum += pow(x_t(i) - y_t(i), 2); + points += 1; + } + } + return (error_sum / points); +} + +template +static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y, + const Tensor& mask, T mf) { + auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf; + } + } +} + +template +static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, + const Tensor& mask) { + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + T error_sum = 0.0; + T points = 0.0; + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + error_sum += + -1.0 * (y_t(i) * log(x_t(i)) + (1.0 - y_t(i)) * log(1.0 - x_t(i))); + points += 1; + } + } + return (error_sum / points); +} + +template +static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x, + const Tensor& y, const Tensor& mask, + T mf) { + auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf; + } + } +} + +template +static void CalcPredResult(const Tensor& input, Tensor* pred_conf, + Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, + Tensor* pred_w, Tensor* pred_h, const int anchor_num, + const int class_num) { + const int n = input.dims()[0]; + const int h = input.dims()[2]; + const int w = input.dims()[3]; + const int box_attr_num = 5 + class_num; + + auto input_t = EigenTensor::From(input); + auto pred_conf_t = EigenTensor::From(*pred_conf); + auto pred_class_t = EigenTensor::From(*pred_class); + auto pred_x_t = EigenTensor::From(*pred_x); + auto pred_y_t = EigenTensor::From(*pred_y); + auto pred_w_t = EigenTensor::From(*pred_w); + auto pred_h_t = EigenTensor::From(*pred_h); + + for (int i = 0; i < n; i++) { + for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + for (int j = 0; j < h; j++) { + for (int k = 0; k < w; k++) { + pred_x_t(i, an_idx, j, k) = + sigmoid(input_t(i, box_attr_num * an_idx, j, k)); + pred_y_t(i, an_idx, j, k) = + sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k)); + pred_w_t(i, an_idx, j, k) = + input_t(i, box_attr_num * an_idx + 2, j, k); + pred_h_t(i, an_idx, j, k) = + input_t(i, box_attr_num * an_idx + 3, j, k); + + pred_conf_t(i, an_idx, j, k) = + sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k)); + + for (int c = 0; c < class_num; c++) { + pred_class_t(i, an_idx, j, k, c) = + sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); + } + } + } + } + } +} + +template +static T CalcBoxIoU(std::vector box1, std::vector box2) { + T b1_x1 = box1[0] - box1[2] / 2; + T b1_x2 = box1[0] + box1[2] / 2; + T b1_y1 = box1[1] - box1[3] / 2; + T b1_y2 = box1[1] + box1[3] / 2; + T b2_x1 = box2[0] - box2[2] / 2; + T b2_x2 = box2[0] + box2[2] / 2; + T b2_y1 = box2[1] - box2[3] / 2; + T b2_y2 = box2[1] + box2[3] / 2; + + T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1); + T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1); + + T inter_rect_x1 = std::max(b1_x1, b2_x1); + T inter_rect_y1 = std::max(b1_y1, b2_y1); + T inter_rect_x2 = std::min(b1_x2, b2_x2); + T inter_rect_y2 = std::min(b1_y2, b2_y2); + T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast(0.0)) * + std::max(inter_rect_y2 - inter_rect_y1, static_cast(0.0)); + + return inter_area / (b1_area + b2_area - inter_area); +} + +template +static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, + const float ignore_thresh, std::vector anchors, + const int grid_size, Tensor* obj_mask, + Tensor* noobj_mask, Tensor* tx, Tensor* ty, + Tensor* tw, Tensor* th, Tensor* tconf, + Tensor* tclass) { + const int n = gt_box.dims()[0]; + const int b = gt_box.dims()[1]; + const int anchor_num = anchors.size() / 2; + auto gt_box_t = EigenTensor::From(gt_box); + auto gt_label_t = EigenTensor::From(gt_label); + auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); + auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); + auto tx_t = EigenTensor::From(*tx).setConstant(0.0); + auto ty_t = EigenTensor::From(*ty).setConstant(0.0); + auto tw_t = EigenTensor::From(*tw).setConstant(0.0); + auto th_t = EigenTensor::From(*th).setConstant(0.0); + auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); + auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < b; j++) { + if (isZero(gt_box_t(i, j, 0)) && isZero(gt_box_t(i, j, 1)) && + isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { + continue; + } + + int cur_label = gt_label_t(i, j); + T gx = gt_box_t(i, j, 0) * grid_size; + T gy = gt_box_t(i, j, 1) * grid_size; + T gw = gt_box_t(i, j, 2) * grid_size; + T gh = gt_box_t(i, j, 3) * grid_size; + int gi = static_cast(gx); + int gj = static_cast(gy); + + T max_iou = static_cast(0); + T iou; + int best_an_index = -1; + std::vector gt_box_shape({0, 0, gw, gh}); + for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), + static_cast(anchors[2 * an_idx + 1])}); + iou = CalcBoxIoU(gt_box_shape, anchor_shape); + if (iou > max_iou) { + max_iou = iou; + best_an_index = an_idx; + } + if (iou > ignore_thresh) { + noobj_mask_t(i, an_idx, gj, gi) = 0; + } + } + obj_mask_t(i, best_an_index, gj, gi) = 1; + noobj_mask_t(i, best_an_index, gj, gi) = 0; + tx_t(i, best_an_index, gj, gi) = gx - gi; + ty_t(i, best_an_index, gj, gi) = gy - gj; + tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); + th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); + tclass_t(i, best_an_index, gj, gi, cur_label) = 1; + tconf_t(i, best_an_index, gj, gi) = 1; + } + } +} + +static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand, + const Tensor& obj_mask) { + const int n = obj_mask_expand->dims()[0]; + const int an_num = obj_mask_expand->dims()[1]; + const int h = obj_mask_expand->dims()[2]; + const int w = obj_mask_expand->dims()[3]; + const int class_num = obj_mask_expand->dims()[4]; + auto obj_mask_expand_t = EigenTensor::From(*obj_mask_expand); + auto obj_mask_t = EigenTensor::From(obj_mask); + + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); +} + +template +static void AddAllGradToInputGrad( + Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y, + const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x, + const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h, + const Tensor& grad_conf_target, const Tensor& grad_conf_notarget, + const Tensor& grad_class, const int class_num, const float loss_weight_xy, + const float loss_weight_wh, const float loss_weight_conf_target, + const float loss_weight_conf_notarget, const float loss_weight_class) { + const int n = pred_x.dims()[0]; + const int an_num = pred_x.dims()[1]; + const int h = pred_x.dims()[2]; + const int w = pred_x.dims()[3]; + const int attr_num = class_num + 5; + auto grad_t = EigenTensor::From(*grad).setConstant(0.0); + auto pred_x_t = EigenTensor::From(pred_x); + auto pred_y_t = EigenTensor::From(pred_y); + auto pred_conf_t = EigenTensor::From(pred_conf); + auto pred_class_t = EigenTensor::From(pred_class); + auto grad_x_t = EigenTensor::From(grad_x); + auto grad_y_t = EigenTensor::From(grad_y); + auto grad_w_t = EigenTensor::From(grad_w); + auto grad_h_t = EigenTensor::From(grad_h); + auto grad_conf_target_t = EigenTensor::From(grad_conf_target); + auto grad_conf_notarget_t = EigenTensor::From(grad_conf_notarget); + auto grad_class_t = EigenTensor::From(grad_class); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + grad_t(i, j * attr_num, k, l) = + grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) * + (1.0 - pred_x_t(i, j, k, l)) * loss * loss_weight_xy; + grad_t(i, j * attr_num + 1, k, l) = + grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * + (1.0 - pred_y_t(i, j, k, l)) * loss * loss_weight_xy; + grad_t(i, j * attr_num + 2, k, l) = + grad_w_t(i, j, k, l) * loss * loss_weight_wh; + grad_t(i, j * attr_num + 3, k, l) = + grad_h_t(i, j, k, l) * loss * loss_weight_wh; + grad_t(i, j * attr_num + 4, k, l) = + grad_conf_target_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss * loss_weight_conf_target; + grad_t(i, j * attr_num + 4, k, l) += + grad_conf_notarget_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss * + loss_weight_conf_notarget; + + for (int c = 0; c < class_num; c++) { + grad_t(i, j * attr_num + 5 + c, k, l) = + grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * + (1.0 - pred_class_t(i, j, k, l, c)) * loss * loss_weight_class; + } + } + } + } + } +} + +template +class Yolov3LossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* gt_box = ctx.Input("GTBox"); + auto* gt_label = ctx.Input("GTLabel"); + auto* loss = ctx.Output("Loss"); + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float ignore_thresh = ctx.Attr("ignore_thresh"); + float loss_weight_xy = ctx.Attr("loss_weight_xy"); + float loss_weight_wh = ctx.Attr("loss_weight_wh"); + float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); + float loss_weight_conf_notarget = + ctx.Attr("loss_weight_conf_notarget"); + float loss_weight_class = ctx.Attr("loss_weight_class"); + + const int n = input->dims()[0]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int an_num = anchors.size() / 2; + + Tensor pred_x, pred_y, pred_w, pred_h; + Tensor pred_conf, pred_class; + pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); + + Tensor obj_mask, noobj_mask; + Tensor tx, ty, tw, th, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, + &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + + Tensor obj_mask_expand; + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + + T loss_x = CalcMSEWithMask(pred_x, tx, obj_mask); + T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); + T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); + T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); + T loss_conf_target = CalcBCEWithMask(pred_conf, tconf, obj_mask); + T loss_conf_notarget = CalcBCEWithMask(pred_conf, tconf, noobj_mask); + T loss_class = CalcBCEWithMask(pred_class, tclass, obj_mask_expand); + + auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); + loss_data[0] = loss_weight_xy * (loss_x + loss_y) + + loss_weight_wh * (loss_w + loss_h) + + loss_weight_conf_target * loss_conf_target + + loss_weight_conf_notarget * loss_conf_notarget + + loss_weight_class * loss_class; + } +}; + +template +class Yolov3LossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* gt_box = ctx.Input("GTBox"); + auto* gt_label = ctx.Input("GTLabel"); + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float ignore_thresh = ctx.Attr("ignore_thresh"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Loss")); + const T loss = output_grad->data()[0]; + float loss_weight_xy = ctx.Attr("loss_weight_xy"); + float loss_weight_wh = ctx.Attr("loss_weight_wh"); + float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); + float loss_weight_conf_notarget = + ctx.Attr("loss_weight_conf_notarget"); + float loss_weight_class = ctx.Attr("loss_weight_class"); + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int an_num = anchors.size() / 2; + + Tensor pred_x, pred_y, pred_w, pred_h; + Tensor pred_conf, pred_class; + pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); + + Tensor obj_mask, noobj_mask; + Tensor tx, ty, tw, th, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, + &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + + Tensor obj_mask_expand; + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + + Tensor grad_x, grad_y, grad_w, grad_h; + Tensor grad_conf_target, grad_conf_notarget, grad_class; + grad_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + T obj_mf = CalcMaskPointNum(obj_mask); + T noobj_mf = CalcMaskPointNum(noobj_mask); + T obj_expand_mf = CalcMaskPointNum(obj_mask_expand); + CalcMSEGradWithMask(&grad_x, pred_x, tx, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_y, pred_y, ty, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_w, pred_w, tw, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_h, pred_h, th, obj_mask, obj_mf); + CalcBCEGradWithMask(&grad_conf_target, pred_conf, tconf, obj_mask, + obj_mf); + CalcBCEGradWithMask(&grad_conf_notarget, pred_conf, tconf, noobj_mask, + noobj_mf); + CalcBCEGradWithMask(&grad_class, pred_class, tclass, obj_mask_expand, + obj_expand_mf); + + input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + AddAllGradToInputGrad( + input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y, + grad_w, grad_h, grad_conf_target, grad_conf_notarget, grad_class, + class_num, loss_weight_xy, loss_weight_wh, loss_weight_conf_target, + loss_weight_conf_notarget, loss_weight_class); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 4843af8340310e0f47964d41708b13216fcd2161..ce731f39ea099a4d8948812989ad19b3cce119ff 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -20,6 +20,7 @@ from __future__ import print_function from .layer_function_generator import generate_layer_fn from .layer_function_generator import autodoc, templatedoc from ..layer_helper import LayerHelper +from ..framework import Variable from . import tensor from . import nn from . import ops @@ -46,6 +47,7 @@ __all__ = [ 'iou_similarity', 'box_coder', 'polygon_box_transform', + 'yolov3_loss', ] @@ -401,6 +403,113 @@ def polygon_box_transform(input, name=None): return output +@templatedoc(op_type="yolov3_loss") +def yolov3_loss(x, + gtbox, + gtlabel, + anchors, + class_num, + ignore_thresh, + loss_weight_xy=None, + loss_weight_wh=None, + loss_weight_conf_target=None, + loss_weight_conf_notarget=None, + loss_weight_class=None, + name=None): + """ + ${comment} + + Args: + x (Variable): ${x_comment} + gtbox (Variable): groud truth boxes, should be in shape of [N, B, 4], + in the third dimenstion, x, y, w, h should be stored + and x, y, w, h should be relative value of input image. + N is the batch number and B is the max box number in + an image. + gtlabel (Variable): class id of ground truth boxes, shoud be ins shape + of [N, B]. + anchors (list|tuple): ${anchors_comment} + class_num (int): ${class_num_comment} + ignore_thresh (float): ${ignore_thresh_comment} + loss_weight_xy (float|None): ${loss_weight_xy_comment} + loss_weight_wh (float|None): ${loss_weight_wh_comment} + loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment} + loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment} + loss_weight_class (float|None): ${loss_weight_class_comment} + name (string): the name of yolov3 loss + + Returns: + Variable: A 1-D tensor with shape [1], the value of yolov3 loss + + Raises: + TypeError: Input x of yolov3_loss must be Variable + TypeError: Input gtbox of yolov3_loss must be Variable" + TypeError: Input gtlabel of yolov3_loss must be Variable" + TypeError: Attr anchors of yolov3_loss must be list or tuple + TypeError: Attr class_num of yolov3_loss must be an integer + TypeError: Attr ignore_thresh of yolov3_loss must be a float number + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') + gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') + gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') + anchors = [10, 13, 16, 30, 33, 23] + loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 + anchors=anchors, ignore_thresh=0.5) + """ + helper = LayerHelper('yolov3_loss', **locals()) + + if not isinstance(x, Variable): + raise TypeError("Input x of yolov3_loss must be Variable") + if not isinstance(gtbox, Variable): + raise TypeError("Input gtbox of yolov3_loss must be Variable") + if not isinstance(gtlabel, Variable): + raise TypeError("Input gtlabel of yolov3_loss must be Variable") + if not isinstance(anchors, list) and not isinstance(anchors, tuple): + raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + if not isinstance(class_num, int): + raise TypeError("Attr class_num of yolov3_loss must be an integer") + if not isinstance(ignore_thresh, float): + raise TypeError( + "Attr ignore_thresh of yolov3_loss must be a float number") + + if name is None: + loss = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + loss = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + attrs = { + "anchors": anchors, + "class_num": class_num, + "ignore_thresh": ignore_thresh, + } + + if loss_weight_xy is not None and isinstance(loss_weight_xy, float): + self.attrs['loss_weight_xy'] = loss_weight_xy + if loss_weight_wh is not None and isinstance(loss_weight_wh, float): + self.attrs['loss_weight_wh'] = loss_weight_wh + if loss_weight_conf_target is not None and isinstance( + loss_weight_conf_target, float): + self.attrs['loss_weight_conf_target'] = loss_weight_conf_target + if loss_weight_conf_notarget is not None and isinstance( + loss_weight_conf_notarget, float): + self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget + if loss_weight_class is not None and isinstance(loss_weight_class, float): + self.attrs['loss_weight_class'] = loss_weight_class + + helper.append_op( + type='yolov3_loss', + inputs={"X": x, + "GTBox": gtbox, + "GTLabel": gtlabel}, + outputs={'Loss': loss}, + attrs=attrs) + return loss + + @templatedoc() def detection_map(detect_res, label, diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index a2eca5541a152ca99804a7f87c9b0bc3d12d4eee..d99eaa0634f93dcd16dd80ae172f11e8090a2623 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -388,5 +388,18 @@ class TestGenerateProposals(unittest.TestCase): print(rpn_rois.shape) +class TestYoloDetection(unittest.TestCase): + def test_yolov3_loss(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') + gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') + loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, + 0.5) + + self.assertIsNotNone(loss) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..544fe4b4f81909b69a05d9751316e3d3137fdc45 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -0,0 +1,215 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import unittest +import numpy as np +from op_test import OpTest + +from paddle.fluid import core + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-1.0 * x)) + + +def mse(x, y, num): + return ((y - x)**2).sum() / num + + +def bce(x, y, mask): + x = x.reshape((-1)) + y = y.reshape((-1)) + mask = mask.reshape((-1)) + + error_sum = 0.0 + count = 0 + for i in range(x.shape[0]): + if mask[i] > 0: + error_sum += y[i] * np.log(x[i]) + (1 - y[i]) * np.log(1 - x[i]) + count += 1 + return error_sum / (-1.0 * count) + + +def box_iou(box1, box2): + b1_x1 = box1[0] - box1[2] / 2 + b1_x2 = box1[0] + box1[2] / 2 + b1_y1 = box1[1] - box1[3] / 2 + b1_y2 = box1[1] + box1[3] / 2 + b2_x1 = box2[0] - box2[2] / 2 + b2_x2 = box2[0] + box2[2] / 2 + b2_y1 = box2[1] - box2[3] / 2 + b2_y2 = box2[1] + box2[3] / 2 + + b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + + inter_rect_x1 = max(b1_x1, b2_x1) + inter_rect_y1 = max(b1_y1, b2_y1) + inter_rect_x2 = min(b1_x2, b2_x2) + inter_rect_y2 = min(b1_y2, b2_y2) + inter_area = max(inter_rect_x2 - inter_rect_x1, 0) * max( + inter_rect_y2 - inter_rect_y1, 0) + + return inter_area / (b1_area + b2_area + inter_area) + + +def build_target(gtboxs, gtlabel, attrs, grid_size): + n, b, _ = gtboxs.shape + ignore_thresh = attrs["ignore_thresh"] + anchors = attrs["anchors"] + class_num = attrs["class_num"] + an_num = len(anchors) // 2 + obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') + tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tcls = np.zeros( + (n, an_num, grid_size, grid_size, class_num)).astype('float32') + + for i in range(n): + for j in range(b): + if gtboxs[i, j, :].sum() == 0: + continue + + gt_label = gtlabel[i, j] + gx = gtboxs[i, j, 0] * grid_size + gy = gtboxs[i, j, 1] * grid_size + gw = gtboxs[i, j, 2] * grid_size + gh = gtboxs[i, j, 3] * grid_size + + gi = int(gx) + gj = int(gy) + + gtbox = [0, 0, gw, gh] + max_iou = 0 + for k in range(an_num): + anchor_box = [0, 0, anchors[2 * k], anchors[2 * k + 1]] + iou = box_iou(gtbox, anchor_box) + if iou > max_iou: + max_iou = iou + best_an_index = k + if iou > ignore_thresh: + noobj_mask[i, best_an_index, gj, gi] = 0 + + obj_mask[i, best_an_index, gj, gi] = 1 + noobj_mask[i, best_an_index, gj, gi] = 0 + tx[i, best_an_index, gj, gi] = gx - gi + ty[i, best_an_index, gj, gi] = gy - gj + tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * + best_an_index]) + th[i, best_an_index, gj, gi] = np.log( + gh / anchors[2 * best_an_index + 1]) + tconf[i, best_an_index, gj, gi] = 1 + tcls[i, best_an_index, gj, gi, gt_label] = 1 + + return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask) + + +def YoloV3Loss(x, gtbox, gtlabel, attrs): + n, c, h, w = x.shape + an_num = len(attrs['anchors']) // 2 + class_num = attrs["class_num"] + x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + pred_x = sigmoid(x[:, :, :, :, 0]) + pred_y = sigmoid(x[:, :, :, :, 1]) + pred_w = x[:, :, :, :, 2] + pred_h = x[:, :, :, :, 3] + pred_conf = sigmoid(x[:, :, :, :, 4]) + pred_cls = sigmoid(x[:, :, :, :, 5:]) + + tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target( + gtbox, gtlabel, attrs, x.shape[2]) + + obj_mask_expand = np.tile( + np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) + loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum()) + loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum()) + loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum()) + loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum()) + loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) + loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask, + noobj_mask) + loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, + obj_mask_expand) + + return attrs['loss_weight_xy'] * (loss_x + loss_y) \ + + attrs['loss_weight_wh'] * (loss_w + loss_h) \ + + attrs['loss_weight_conf_target'] * loss_conf_target \ + + attrs['loss_weight_conf_notarget'] * loss_conf_notarget \ + + attrs['loss_weight_class'] * loss_class + + +class TestYolov3LossOp(OpTest): + def setUp(self): + self.loss_weight_xy = 1.0 + self.loss_weight_wh = 1.0 + self.loss_weight_conf_target = 1.0 + self.loss_weight_conf_notarget = 1.0 + self.loss_weight_class = 1.0 + self.initTestCase() + self.op_type = 'yolov3_loss' + x = np.random.random(size=self.x_shape).astype('float32') + gtbox = np.random.random(size=self.gtbox_shape).astype('float32') + gtlabel = np.random.randint(0, self.class_num, + self.gtbox_shape[:2]).astype('int32') + + self.attrs = { + "anchors": self.anchors, + "class_num": self.class_num, + "ignore_thresh": self.ignore_thresh, + "loss_weight_xy": self.loss_weight_xy, + "loss_weight_wh": self.loss_weight_wh, + "loss_weight_conf_target": self.loss_weight_conf_target, + "loss_weight_conf_notarget": self.loss_weight_conf_notarget, + "loss_weight_class": self.loss_weight_class, + } + + self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} + self.outputs = { + 'Loss': np.array( + [YoloV3Loss(x, gtbox, gtlabel, self.attrs)]).astype('float32') + } + + def test_check_output(self): + place = core.CPUPlace() + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad_ignore_gtbox(self): + place = core.CPUPlace() + self.check_grad_with_place( + place, ['X'], + 'Loss', + no_grad_set=set(["GTBox", "GTLabel"]), + max_relative_error=0.06) + + def initTestCase(self): + self.anchors = [10, 13, 12, 12] + self.class_num = 10 + self.ignore_thresh = 0.5 + self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) + self.gtbox_shape = (5, 10, 4) + self.loss_weight_xy = 2.5 + self.loss_weight_wh = 0.8 + self.loss_weight_conf_target = 1.5 + self.loss_weight_conf_notarget = 0.5 + self.loss_weight_class = 1.2 + + +if __name__ == "__main__": + unittest.main()