From a0284f6fbcb4888e1653b7f094db615f1437943c Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 12 Nov 2018 21:13:25 +0800 Subject: [PATCH] Add backward CPU kernel. test=develop --- paddle/fluid/API.spec | 1 + paddle/fluid/operators/yolov3_loss_op.cc | 64 ++++- paddle/fluid/operators/yolov3_loss_op.cu | 4 +- paddle/fluid/operators/yolov3_loss_op.h | 256 +++++++++++++----- python/paddle/fluid/layers/nn.py | 49 +++- .../fluid/tests/unittests/test_layers.py | 9 + .../tests/unittests/test_yolov3_loss_op.py | 42 +-- 7 files changed, 327 insertions(+), 98 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index de32a5d5a2..8344a913e9 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -183,6 +183,7 @@ paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', ' paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 7369ce31e8..cf25e99505 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -20,8 +20,6 @@ using framework::Tensor; class Yolov3LossOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of Yolov3LossOp should not be null."); @@ -32,7 +30,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_x = ctx->GetInputDim("X"); auto dim_gt = ctx->GetInputDim("GTBox"); - auto img_height = ctx->Attrs().Get("img_height"); auto anchors = ctx->Attrs().Get>("anchors"); auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); @@ -43,8 +40,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "+ class_num))."); PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor"); PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5"); - PADDLE_ENFORCE_GT(img_height, 0, - "Attr(img_height) value should be greater then 0"); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -87,13 +82,43 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("anchors", "The anchor width and height, " "it will be parsed pair by pair."); - AddAttr("img_height", - "The input image height after crop of yolov3 network."); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. + + The output of previous network is in shape [N, C, H, W], while H and W + should be the same, specify the grid size, each grid point predict given + number boxes, this given number is specified by anchors, it should be + half anchors length, which following will be represented as S. In the + second dimention(the channel dimention), C should be S * (class_num + 5), + class_num is the box categoriy number of source dataset(such as coco), + so in the second dimention, stores 4 box location coordinates x, y, w, h + and confidence score of the box and class one-hot key of each anchor box. + + While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions + correspnd to: + + $$ + b_x = \sigma(t_x) + c_x + b_y = \sigma(t_y) + c_y + b_w = p_w e^{t_w} + b_h = p_h e^{t_h} + $$ + + While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$ + is specified by anchors. + + As for confidence score, it is the logistic regression value of IoU between + anchor boxes and ground truth boxes, the score of the anchor box which has + the max IoU should be 1, and if the anchor box has IoU bigger then ignore + thresh, the confidence score loss of this anchor box will be ignored. + + Therefore, the yolov3 loss consist of three major parts, box location loss, + confidence score loss, and classification loss. The MSE loss is used for + box location, and binary cross entropy loss is used for confidence score + loss and classification loss. )DOC"); } }; @@ -101,8 +126,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { class Yolov3LossOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), @@ -113,6 +136,7 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { } } + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( @@ -120,12 +144,32 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { } }; +class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("yolov3_loss_grad"); + op->SetInput("X", Input("X")); + op->SetInput("GTBox", Input("GTBox")); + op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("GTBox"), {}); + return std::unique_ptr(op); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, - paddle::framework::DefaultGradOpDescMaker); + ops::Yolov3LossGradMaker); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); REGISTER_OP_CPU_KERNEL( yolov3_loss, diff --git a/paddle/fluid/operators/yolov3_loss_op.cu b/paddle/fluid/operators/yolov3_loss_op.cu index 48f997456a..f901b10d38 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cu +++ b/paddle/fluid/operators/yolov3_loss_op.cu @@ -17,7 +17,7 @@ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( yolov3_loss, - ops::Yolov3LossOpKernel); + ops::Yolov3LossKernel); REGISTER_OP_CUDA_KERNEL( yolov3_loss_grad, - ops::Yolov3LossGradOpKernel); + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 426e0688ab..a2ed4440a7 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -33,10 +33,22 @@ static inline bool isZero(T x) { } template -static inline T sigmod(T x) { +static inline T sigmoid(T x) { return 1.0 / (exp(-1.0 * x) + 1.0); } +template +static inline T CalcMaskPointNum(const Tensor& mask) { + auto mask_t = EigenVector::Flatten(mask); + T count = 0.0; + for (int i = 0; i < mask_t.dimensions()[0]; i++) { + if (mask_t(i)) { + count += 1.0; + } + } + return count; +} + template static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, const Tensor& mask) { @@ -55,6 +67,21 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, return (error_sum / points); } +template +static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y, + const Tensor& mask, T mf) { + auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf; + } + } +} + template static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, const Tensor& mask) { @@ -75,21 +102,34 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, } template -static void CalcPredResult(const Tensor& input, Tensor* pred_confs, - Tensor* pred_classes, Tensor* pred_x, Tensor* pred_y, - Tensor* pred_w, Tensor* pred_h, - std::vector anchors, const int class_num, - const int stride) { +static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x, + const Tensor& y, const Tensor& mask, + T mf) { + auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf; + } + } +} + +template +static void CalcPredResult(const Tensor& input, Tensor* pred_conf, + Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, + Tensor* pred_w, Tensor* pred_h, const int anchor_num, + const int class_num) { const int n = input.dims()[0]; - const int c = input.dims()[1]; const int h = input.dims()[2]; const int w = input.dims()[3]; - const int anchor_num = anchors.size() / 2; const int box_attr_num = 5 + class_num; auto input_t = EigenTensor::From(input); - auto pred_confs_t = EigenTensor::From(*pred_confs); - auto pred_classes_t = EigenTensor::From(*pred_classes); + auto pred_conf_t = EigenTensor::From(*pred_conf); + auto pred_class_t = EigenTensor::From(*pred_class); auto pred_x_t = EigenTensor::From(*pred_x); auto pred_y_t = EigenTensor::From(*pred_y); auto pred_w_t = EigenTensor::From(*pred_w); @@ -97,26 +137,23 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_confs, for (int i = 0; i < n; i++) { for (int an_idx = 0; an_idx < anchor_num; an_idx++) { - float an_w = anchors[an_idx * 2] / stride; - float an_h = anchors[an_idx * 2 + 1] / stride; - for (int j = 0; j < h; j++) { for (int k = 0; k < w; k++) { pred_x_t(i, an_idx, j, k) = - sigmod(input_t(i, box_attr_num * an_idx, j, k)); + sigmoid(input_t(i, box_attr_num * an_idx, j, k)); pred_y_t(i, an_idx, j, k) = - sigmod(input_t(i, box_attr_num * an_idx + 1, j, k)); + sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k)); pred_w_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 2, j, k); pred_h_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 3, j, k); - pred_confs_t(i, an_idx, j, k) = - sigmod(input_t(i, box_attr_num * an_idx + 4, j, k)); + pred_conf_t(i, an_idx, j, k) = + sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k)); for (int c = 0; c < class_num; c++) { - pred_classes_t(i, an_idx, j, k, c) = - sigmod(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); + pred_class_t(i, an_idx, j, k, c) = + sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); } } } @@ -148,27 +185,11 @@ static T CalcBoxIoU(std::vector box1, std::vector box2) { return inter_area / (b1_area + b2_area - inter_area); } -template -static inline int GetPredLabel(const Tensor& pred_classes, int n, - int best_an_index, int gj, int gi) { - auto pred_classes_t = EigenTensor::From(pred_classes); - T score = 0.0; - int label = -1; - for (int i = 0; i < pred_classes.dims()[4]; i++) { - if (pred_classes_t(n, best_an_index, gj, gi, i) > score) { - score = pred_classes_t(n, best_an_index, gj, gi, i); - label = i; - } - } - return label; -} - template static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, - std::vector anchors, const int img_height, - const int grid_size, Tensor* obj_mask, - Tensor* noobj_mask, Tensor* tx, Tensor* ty, - Tensor* tw, Tensor* th, Tensor* tconf, + std::vector anchors, const int grid_size, + Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, + Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf, Tensor* tclass) { const int n = gt_boxes.dims()[0]; const int b = gt_boxes.dims()[1]; @@ -240,6 +261,61 @@ static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand, .broadcast(Array5(1, 1, 1, 1, class_num)); } +template +static void AddAllGradToInputGrad( + Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y, + const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x, + const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h, + const Tensor& grad_conf_obj, const Tensor& grad_conf_noobj, + const Tensor& grad_class, const int class_num) { + const int n = pred_x.dims()[0]; + const int an_num = pred_x.dims()[1]; + const int h = pred_x.dims()[2]; + const int w = pred_x.dims()[3]; + const int attr_num = class_num + 5; + auto grad_t = EigenTensor::From(*grad).setConstant(0.0); + auto pred_x_t = EigenTensor::From(pred_x); + auto pred_y_t = EigenTensor::From(pred_y); + auto pred_conf_t = EigenTensor::From(pred_conf); + auto pred_class_t = EigenTensor::From(pred_class); + auto grad_x_t = EigenTensor::From(grad_x); + auto grad_y_t = EigenTensor::From(grad_y); + auto grad_w_t = EigenTensor::From(grad_w); + auto grad_h_t = EigenTensor::From(grad_h); + auto grad_conf_obj_t = EigenTensor::From(grad_conf_obj); + auto grad_conf_noobj_t = EigenTensor::From(grad_conf_noobj); + auto grad_class_t = EigenTensor::From(grad_class); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * + pred_x_t(i, j, k, l) * + (1.0 - pred_x_t(i, j, k, l)) * loss; + grad_t(i, j * attr_num + 1, k, l) = + grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * + (1.0 - pred_y_t(i, j, k, l)) * loss; + grad_t(i, j * attr_num + 2, k, l) = grad_w_t(i, j, k, l) * loss; + grad_t(i, j * attr_num + 3, k, l) = grad_h_t(i, j, k, l) * loss; + grad_t(i, j * attr_num + 4, k, l) = + grad_conf_obj_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss; + grad_t(i, j * attr_num + 4, k, l) += + grad_conf_noobj_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss; + + for (int c = 0; c < class_num; c++) { + grad_t(i, j * attr_num + 5 + c, k, l) = + grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * + (1.0 - pred_class_t(i, j, k, l, c)) * loss; + } + } + } + } + } +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -247,28 +323,25 @@ class Yolov3LossKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_boxes = ctx.Input("GTBox"); auto* loss = ctx.Output("Loss"); - int img_height = ctx.Attr("img_height"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); const int n = input->dims()[0]; - const int c = input->dims()[1]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - const T stride = static_cast(img_height) / h; Tensor pred_x, pred_y, pred_w, pred_h; - Tensor pred_confs, pred_classes; + Tensor pred_conf, pred_class; pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_confs.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_classes.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcPredResult(*input, &pred_confs, &pred_classes, &pred_x, &pred_y, - &pred_w, &pred_h, anchors, class_num, stride); + pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; Tensor tx, ty, tw, th, tconf, tclass; @@ -280,9 +353,8 @@ class Yolov3LossKernel : public framework::OpKernel { th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, img_height, h, - &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, - &tclass); + PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, + &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); Tensor obj_mask_expand; obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, @@ -293,17 +365,9 @@ class Yolov3LossKernel : public framework::OpKernel { T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); - T loss_conf_obj = CalcBCEWithMask(pred_confs, tconf, obj_mask); - T loss_conf_noobj = CalcBCEWithMask(pred_confs, tconf, noobj_mask); - T loss_class = CalcBCEWithMask(pred_classes, tclass, obj_mask_expand); - - // LOG(ERROR) << "loss_x: " << loss_x; - // LOG(ERROR) << "loss_y: " << loss_y; - // LOG(ERROR) << "loss_w: " << loss_w; - // LOG(ERROR) << "loss_h: " << loss_h; - // LOG(ERROR) << "loss_conf_obj: " << loss_conf_obj; - // LOG(ERROR) << "loss_conf_noobj: " << loss_conf_noobj; - // LOG(ERROR) << "loss_class: " << loss_class; + T loss_conf_obj = CalcBCEWithMask(pred_conf, tconf, obj_mask); + T loss_conf_noobj = CalcBCEWithMask(pred_conf, tconf, noobj_mask); + T loss_class = CalcBCEWithMask(pred_class, tclass, obj_mask_expand); auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_obj + @@ -315,8 +379,76 @@ template class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* d_input_t = ctx.Output(framework::GradVarName("X")); - auto* d_output_t = ctx.Input(framework::GradVarName("Out")); + auto* input = ctx.Input("X"); + auto* gt_boxes = ctx.Input("GTBox"); + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float ignore_thresh = ctx.Attr("ignore_thresh"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Loss")); + const T loss = output_grad->data()[0]; + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int an_num = anchors.size() / 2; + + Tensor pred_x, pred_y, pred_w, pred_h; + Tensor pred_conf, pred_class; + pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); + + Tensor obj_mask, noobj_mask; + Tensor tx, ty, tw, th, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, + &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + + Tensor obj_mask_expand; + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + + Tensor grad_x, grad_y, grad_w, grad_h; + Tensor grad_conf_obj, grad_conf_noobj, grad_class; + grad_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_conf_obj.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_conf_noobj.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + T obj_mf = CalcMaskPointNum(obj_mask); + T noobj_mf = CalcMaskPointNum(noobj_mask); + T obj_expand_mf = CalcMaskPointNum(obj_mask_expand); + CalcMSEGradWithMask(&grad_x, pred_x, tx, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_y, pred_y, ty, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_w, pred_w, tw, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_h, pred_h, th, obj_mask, obj_mf); + CalcBCEGradWithMask(&grad_conf_obj, pred_conf, tconf, obj_mask, obj_mf); + CalcBCEGradWithMask(&grad_conf_noobj, pred_conf, tconf, noobj_mask, + noobj_mf); + CalcBCEGradWithMask(&grad_class, pred_class, tclass, obj_mask_expand, + obj_expand_mf); + + input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + AddAllGradToInputGrad( + input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y, + grad_w, grad_h, grad_conf_obj, grad_conf_noobj, grad_class, class_num); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1ee7198f29..a4efb16682 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8244,14 +8244,55 @@ def log_loss(input, label, epsilon=1e-4, name=None): return loss -def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None): +@templatedoc(op_type="yolov3_loss") +def yolov3_loss(x, gtbox, anchors, class_num, ignore_thresh, name=None): """ - **YOLOv3 Loss Layer** + ${comment} + + Args: + x (Variable): ${x_comment} + gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5], + in the third dimenstion, class_id, x, y, w, h should + be stored and x, y, w, h should be relative valud of + input image. + anchors (list|tuple): ${anchors_comment} + class_num (int): ${class_num_comment} + ignore_thresh (float): ${ignore_thresh_comment} + name (string): the name of yolov3 loss - This layer + Returns: + Variable: A 1-D tensor with shape [1], the value of yolov3 loss + + Raises: + TypeError: Input x of yolov3_loss must be Variable + TypeError: Input gtbox of yolov3_loss must be Variable" + TypeError: Attr anchors of yolov3_loss must be list or tuple + TypeError: Attr class_num of yolov3_loss must be an integer + TypeError: Attr ignore_thresh of yolov3_loss must be a float number + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32') + gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32') + anchors = [10, 13, 16, 30, 33, 23] + loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 + anchors=anchors, ignore_thresh=0.5) """ helper = LayerHelper('yolov3_loss', **locals()) + if not isinstance(x, Variable): + raise TypeError("Input x of yolov3_loss must be Variable") + if not isinstance(gtbox, Variable): + raise TypeError("Input gtbox of yolov3_loss must be Variable") + if not isinstance(anchors, list) and not isinstance(anchors, tuple): + raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + if not isinstance(class_num, int): + raise TypeError("Attr class_num of yolov3_loss must be an integer") + if not isinstance(ignore_thresh, float): + raise TypeError( + "Attr ignore_thresh of yolov3_loss must be a float number") + if name is None: loss = helper.create_variable_for_type_inference(dtype=x.dtype) else: @@ -8264,8 +8305,8 @@ def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None): "GTBox": gtbox}, outputs={'Loss': loss}, attrs={ - "img_height": img_height, "anchors": anchors, + "class_num": class_num, "ignore_thresh": ignore_thresh, }) return loss diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index f48d9c84f9..dd02968c30 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -911,6 +911,15 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(data_1) print(str(program)) + def test_yolov3_loss(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + gtbox = layers.data(name='gtbox', shape=[10, 5], dtype='float32') + loss = layers.yolov3_loss(x, gtbox, [10, 13, 30, 13], 10, 0.5) + + self.assertIsNotNone(loss) + def test_bilinear_tensor_product_layer(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index f5b15efb27..4562f8bd49 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import division + import unittest import numpy as np from op_test import OpTest +from paddle.fluid import core + def sigmoid(x): return 1.0 / (1.0 + np.exp(-1.0 * x)) @@ -65,10 +69,9 @@ def box_iou(box1, box2): def build_target(gtboxs, attrs, grid_size): n, b, _ = gtboxs.shape ignore_thresh = attrs["ignore_thresh"] - img_height = attrs["img_height"] anchors = attrs["anchors"] class_num = attrs["class_num"] - an_num = len(anchors) / 2 + an_num = len(anchors) // 2 obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') @@ -120,7 +123,7 @@ def build_target(gtboxs, attrs, grid_size): def YoloV3Loss(x, gtbox, attrs): n, c, h, w = x.shape - an_num = len(attrs['anchors']) / 2 + an_num = len(attrs['anchors']) // 2 class_num = attrs["class_num"] x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) pred_x = sigmoid(x[:, :, :, :, 0]) @@ -144,13 +147,6 @@ def YoloV3Loss(x, gtbox, attrs): noobj_mask) loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, obj_mask_expand) - # print "loss_x: ", loss_x - # print "loss_y: ", loss_y - # print "loss_w: ", loss_w - # print "loss_h: ", loss_h - # print "loss_conf_obj: ", loss_conf_obj - # print "loss_conf_noobj: ", loss_conf_noobj - # print "loss_class: ", loss_class return loss_x + loss_y + loss_w + loss_h + loss_conf_obj + loss_conf_noobj + loss_class @@ -165,29 +161,35 @@ class TestYolov3LossOp(OpTest): self.gtbox_shape[:2]) self.attrs = { - "img_height": self.img_height, "anchors": self.anchors, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, } self.inputs = {'X': x, 'GTBox': gtbox} - self.outputs = {'Loss': np.array([YoloV3Loss(x, gtbox, self.attrs)])} - print self.outputs + self.outputs = { + 'Loss': + np.array([YoloV3Loss(x, gtbox, self.attrs)]).astype('float32') + } def test_check_output(self): - self.check_output(atol=1e-3) + place = core.CPUPlace() + self.check_output_with_place(place, atol=1e-3) - # def test_check_grad_normal(self): - # self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61) + def test_check_grad_ignore_gtbox(self): + place = core.CPUPlace() + self.check_grad_with_place( + place, ['X'], + 'Loss', + no_grad_set=set("GTBox"), + max_relative_error=0.1) def initTestCase(self): - self.img_height = 608 - self.anchors = [10, 13, 16, 30, 33, 23] + self.anchors = [10, 13, 12, 12] self.class_num = 10 self.ignore_thresh = 0.5 - self.x_shape = (5, len(self.anchors) / 2 * (5 + self.class_num), 7, 7) - self.gtbox_shape = (5, 10, 5) + self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) + self.gtbox_shape = (5, 5, 5) if __name__ == "__main__": -- GitLab