diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index de32a5d5a297b63a80aa41fc99fcacf60bbf2488..8344a913e9bf818ab7608f181396a3e016b70351 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -183,6 +183,7 @@ paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', '
 paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
+paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc
index 7369ce31e8c9eb72905093760ff6a56c92c5afc2..cf25e995054dce16e6f95540be385c0a09ef9ee2 100644
--- a/paddle/fluid/operators/yolov3_loss_op.cc
+++ b/paddle/fluid/operators/yolov3_loss_op.cc
@@ -20,8 +20,6 @@ using framework::Tensor;
 class Yolov3LossOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of Yolov3LossOp should not be null.");
@@ -32,7 +30,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
 
     auto dim_x = ctx->GetInputDim("X");
     auto dim_gt = ctx->GetInputDim("GTBox");
-    auto img_height = ctx->Attrs().Get<int>("img_height");
     auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
     auto class_num = ctx->Attrs().Get<int>("class_num");
     PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
@@ -43,8 +40,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
                       "+ class_num)).");
     PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor");
     PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5");
-    PADDLE_ENFORCE_GT(img_height, 0,
-                      "Attr(img_height) value should be greater then 0");
     PADDLE_ENFORCE_GT(anchors.size(), 0,
                       "Attr(anchors) length should be greater then 0.");
     PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
@@ -87,13 +82,43 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<int>>("anchors",
                               "The anchor width and height, "
                               "it will be parsed pair by pair.");
-    AddAttr<int>("img_height",
-                 "The input image height after crop of yolov3 network.");
     AddAttr<float>("ignore_thresh",
                    "The ignore threshold to ignore confidence loss.");
     AddComment(R"DOC(
          This operator generate yolov3 loss by given predict result and ground
          truth boxes.
+         
+         The output of previous network is in shape [N, C, H, W], while H and W
+         should be the same, specify the grid size, each grid point predict given
+         number boxes, this given number is specified by anchors, it should be 
+         half anchors length, which following will be represented as S. In the 
+         second dimention(the channel dimention), C should be S * (class_num + 5),
+         class_num is the box categoriy number of source dataset(such as coco), 
+         so in the second dimention, stores 4 box location coordinates x, y, w, h 
+         and confidence score of the box and class one-hot key of each anchor box.
+
+         While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions
+         correspnd to:
+
+         $$
+         b_x = \sigma(t_x) + c_x
+         b_y = \sigma(t_y) + c_y
+         b_w = p_w e^{t_w}
+         b_h = p_h e^{t_h}
+         $$
+
+         While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$
+         is specified by anchors.
+
+         As for confidence score, it is the logistic regression value of IoU between
+         anchor boxes and ground truth boxes, the score of the anchor box which has 
+         the max IoU should be 1, and if the anchor box has IoU bigger then ignore 
+         thresh, the confidence score loss of this anchor box will be ignored.
+
+         Therefore, the yolov3 loss consist of three major parts, box location loss,
+         confidence score loss, and classification loss. The MSE loss is used for 
+         box location, and binary cross entropy loss is used for confidence score 
+         loss and classification loss.
          )DOC");
   }
 };
@@ -101,8 +126,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
 class Yolov3LossOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
@@ -113,6 +136,7 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
     }
   }
 
+ protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
@@ -120,12 +144,32 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
   }
 };
 
+class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType("yolov3_loss_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("GTBox", Input("GTBox"));
+    op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("GTBox"), {});
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::Yolov3LossGradMaker);
 REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
 REGISTER_OP_CPU_KERNEL(
     yolov3_loss,
diff --git a/paddle/fluid/operators/yolov3_loss_op.cu b/paddle/fluid/operators/yolov3_loss_op.cu
index 48f997456ac4885f7bdc42d64ff671d0a9998baa..f901b10d38e486be874f75f39ef4afb4cb6a560a 100644
--- a/paddle/fluid/operators/yolov3_loss_op.cu
+++ b/paddle/fluid/operators/yolov3_loss_op.cu
@@ -17,7 +17,7 @@
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     yolov3_loss,
-    ops::Yolov3LossOpKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::Yolov3LossKernel<paddle::platform::CUDADeviceContext, float>);
 REGISTER_OP_CUDA_KERNEL(
     yolov3_loss_grad,
-    ops::Yolov3LossGradOpKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::Yolov3LossGradKernel<paddle::platform::CUDADeviceContext, float>);
diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h
index 426e0688ab63aaf9bc978b70b35bfe5afd0c8d68..a2ed4440a74fe3dc48d8b82f5dbb4aa65445b0ae 100644
--- a/paddle/fluid/operators/yolov3_loss_op.h
+++ b/paddle/fluid/operators/yolov3_loss_op.h
@@ -33,10 +33,22 @@ static inline bool isZero(T x) {
 }
 
 template <typename T>
-static inline T sigmod(T x) {
+static inline T sigmoid(T x) {
   return 1.0 / (exp(-1.0 * x) + 1.0);
 }
 
+template <typename T>
+static inline T CalcMaskPointNum(const Tensor& mask) {
+  auto mask_t = EigenVector<int>::Flatten(mask);
+  T count = 0.0;
+  for (int i = 0; i < mask_t.dimensions()[0]; i++) {
+    if (mask_t(i)) {
+      count += 1.0;
+    }
+  }
+  return count;
+}
+
 template <typename T>
 static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y,
                                 const Tensor& mask) {
@@ -55,6 +67,21 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y,
   return (error_sum / points);
 }
 
+template <typename T>
+static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y,
+                                const Tensor& mask, T mf) {
+  auto grad_t = EigenVector<T>::Flatten(*grad).setConstant(0.0);
+  auto x_t = EigenVector<T>::Flatten(x);
+  auto y_t = EigenVector<T>::Flatten(y);
+  auto mask_t = EigenVector<int>::Flatten(mask);
+
+  for (int i = 0; i < x_t.dimensions()[0]; i++) {
+    if (mask_t(i)) {
+      grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf;
+    }
+  }
+}
+
 template <typename T>
 static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y,
                                 const Tensor& mask) {
@@ -75,21 +102,34 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y,
 }
 
 template <typename T>
-static void CalcPredResult(const Tensor& input, Tensor* pred_confs,
-                           Tensor* pred_classes, Tensor* pred_x, Tensor* pred_y,
-                           Tensor* pred_w, Tensor* pred_h,
-                           std::vector<int> anchors, const int class_num,
-                           const int stride) {
+static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x,
+                                       const Tensor& y, const Tensor& mask,
+                                       T mf) {
+  auto grad_t = EigenVector<T>::Flatten(*grad).setConstant(0.0);
+  auto x_t = EigenVector<T>::Flatten(x);
+  auto y_t = EigenVector<T>::Flatten(y);
+  auto mask_t = EigenVector<int>::Flatten(mask);
+
+  for (int i = 0; i < x_t.dimensions()[0]; i++) {
+    if (mask_t(i)) {
+      grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf;
+    }
+  }
+}
+
+template <typename T>
+static void CalcPredResult(const Tensor& input, Tensor* pred_conf,
+                           Tensor* pred_class, Tensor* pred_x, Tensor* pred_y,
+                           Tensor* pred_w, Tensor* pred_h, const int anchor_num,
+                           const int class_num) {
   const int n = input.dims()[0];
-  const int c = input.dims()[1];
   const int h = input.dims()[2];
   const int w = input.dims()[3];
-  const int anchor_num = anchors.size() / 2;
   const int box_attr_num = 5 + class_num;
 
   auto input_t = EigenTensor<T, 4>::From(input);
-  auto pred_confs_t = EigenTensor<T, 4>::From(*pred_confs);
-  auto pred_classes_t = EigenTensor<T, 5>::From(*pred_classes);
+  auto pred_conf_t = EigenTensor<T, 4>::From(*pred_conf);
+  auto pred_class_t = EigenTensor<T, 5>::From(*pred_class);
   auto pred_x_t = EigenTensor<T, 4>::From(*pred_x);
   auto pred_y_t = EigenTensor<T, 4>::From(*pred_y);
   auto pred_w_t = EigenTensor<T, 4>::From(*pred_w);
@@ -97,26 +137,23 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_confs,
 
   for (int i = 0; i < n; i++) {
     for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
-      float an_w = anchors[an_idx * 2] / stride;
-      float an_h = anchors[an_idx * 2 + 1] / stride;
-
       for (int j = 0; j < h; j++) {
         for (int k = 0; k < w; k++) {
           pred_x_t(i, an_idx, j, k) =
-              sigmod(input_t(i, box_attr_num * an_idx, j, k));
+              sigmoid(input_t(i, box_attr_num * an_idx, j, k));
           pred_y_t(i, an_idx, j, k) =
-              sigmod(input_t(i, box_attr_num * an_idx + 1, j, k));
+              sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k));
           pred_w_t(i, an_idx, j, k) =
               input_t(i, box_attr_num * an_idx + 2, j, k);
           pred_h_t(i, an_idx, j, k) =
               input_t(i, box_attr_num * an_idx + 3, j, k);
 
-          pred_confs_t(i, an_idx, j, k) =
-              sigmod(input_t(i, box_attr_num * an_idx + 4, j, k));
+          pred_conf_t(i, an_idx, j, k) =
+              sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k));
 
           for (int c = 0; c < class_num; c++) {
-            pred_classes_t(i, an_idx, j, k, c) =
-                sigmod(input_t(i, box_attr_num * an_idx + 5 + c, j, k));
+            pred_class_t(i, an_idx, j, k, c) =
+                sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k));
           }
         }
       }
@@ -148,27 +185,11 @@ static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2) {
   return inter_area / (b1_area + b2_area - inter_area);
 }
 
-template <typename T>
-static inline int GetPredLabel(const Tensor& pred_classes, int n,
-                               int best_an_index, int gj, int gi) {
-  auto pred_classes_t = EigenTensor<T, 5>::From(pred_classes);
-  T score = 0.0;
-  int label = -1;
-  for (int i = 0; i < pred_classes.dims()[4]; i++) {
-    if (pred_classes_t(n, best_an_index, gj, gi, i) > score) {
-      score = pred_classes_t(n, best_an_index, gj, gi, i);
-      label = i;
-    }
-  }
-  return label;
-}
-
 template <typename T>
 static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh,
-                            std::vector<int> anchors, const int img_height,
-                            const int grid_size, Tensor* obj_mask,
-                            Tensor* noobj_mask, Tensor* tx, Tensor* ty,
-                            Tensor* tw, Tensor* th, Tensor* tconf,
+                            std::vector<int> anchors, const int grid_size,
+                            Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx,
+                            Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf,
                             Tensor* tclass) {
   const int n = gt_boxes.dims()[0];
   const int b = gt_boxes.dims()[1];
@@ -240,6 +261,61 @@ static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand,
                           .broadcast(Array5(1, 1, 1, 1, class_num));
 }
 
+template <typename T>
+static void AddAllGradToInputGrad(
+    Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y,
+    const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x,
+    const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h,
+    const Tensor& grad_conf_obj, const Tensor& grad_conf_noobj,
+    const Tensor& grad_class, const int class_num) {
+  const int n = pred_x.dims()[0];
+  const int an_num = pred_x.dims()[1];
+  const int h = pred_x.dims()[2];
+  const int w = pred_x.dims()[3];
+  const int attr_num = class_num + 5;
+  auto grad_t = EigenTensor<T, 4>::From(*grad).setConstant(0.0);
+  auto pred_x_t = EigenTensor<T, 4>::From(pred_x);
+  auto pred_y_t = EigenTensor<T, 4>::From(pred_y);
+  auto pred_conf_t = EigenTensor<T, 4>::From(pred_conf);
+  auto pred_class_t = EigenTensor<T, 5>::From(pred_class);
+  auto grad_x_t = EigenTensor<T, 4>::From(grad_x);
+  auto grad_y_t = EigenTensor<T, 4>::From(grad_y);
+  auto grad_w_t = EigenTensor<T, 4>::From(grad_w);
+  auto grad_h_t = EigenTensor<T, 4>::From(grad_h);
+  auto grad_conf_obj_t = EigenTensor<T, 4>::From(grad_conf_obj);
+  auto grad_conf_noobj_t = EigenTensor<T, 4>::From(grad_conf_noobj);
+  auto grad_class_t = EigenTensor<T, 5>::From(grad_class);
+
+  for (int i = 0; i < n; i++) {
+    for (int j = 0; j < an_num; j++) {
+      for (int k = 0; k < h; k++) {
+        for (int l = 0; l < w; l++) {
+          grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) *
+                                          pred_x_t(i, j, k, l) *
+                                          (1.0 - pred_x_t(i, j, k, l)) * loss;
+          grad_t(i, j * attr_num + 1, k, l) =
+              grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) *
+              (1.0 - pred_y_t(i, j, k, l)) * loss;
+          grad_t(i, j * attr_num + 2, k, l) = grad_w_t(i, j, k, l) * loss;
+          grad_t(i, j * attr_num + 3, k, l) = grad_h_t(i, j, k, l) * loss;
+          grad_t(i, j * attr_num + 4, k, l) =
+              grad_conf_obj_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
+              (1.0 - pred_conf_t(i, j, k, l)) * loss;
+          grad_t(i, j * attr_num + 4, k, l) +=
+              grad_conf_noobj_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
+              (1.0 - pred_conf_t(i, j, k, l)) * loss;
+
+          for (int c = 0; c < class_num; c++) {
+            grad_t(i, j * attr_num + 5 + c, k, l) =
+                grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) *
+                (1.0 - pred_class_t(i, j, k, l, c)) * loss;
+          }
+        }
+      }
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 class Yolov3LossKernel : public framework::OpKernel<T> {
  public:
@@ -247,28 +323,25 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
     auto* input = ctx.Input<Tensor>("X");
     auto* gt_boxes = ctx.Input<Tensor>("GTBox");
     auto* loss = ctx.Output<Tensor>("Loss");
-    int img_height = ctx.Attr<int>("img_height");
     auto anchors = ctx.Attr<std::vector<int>>("anchors");
     int class_num = ctx.Attr<int>("class_num");
     float ignore_thresh = ctx.Attr<float>("ignore_thresh");
 
     const int n = input->dims()[0];
-    const int c = input->dims()[1];
     const int h = input->dims()[2];
     const int w = input->dims()[3];
     const int an_num = anchors.size() / 2;
-    const T stride = static_cast<T>(img_height) / h;
 
     Tensor pred_x, pred_y, pred_w, pred_h;
-    Tensor pred_confs, pred_classes;
+    Tensor pred_conf, pred_class;
     pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
-    pred_confs.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
-    pred_classes.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
-    CalcPredResult<T>(*input, &pred_confs, &pred_classes, &pred_x, &pred_y,
-                      &pred_w, &pred_h, anchors, class_num, stride);
+    pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    CalcPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
+                      &pred_w, &pred_h, an_num, class_num);
 
     Tensor obj_mask, noobj_mask;
     Tensor tx, ty, tw, th, tconf, tclass;
@@ -280,9 +353,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
     th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
     tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
-    PrePorcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, img_height, h,
-                       &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf,
-                       &tclass);
+    PrePorcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, h, &obj_mask,
+                       &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass);
 
     Tensor obj_mask_expand;
     obj_mask_expand.mutable_data<int>({n, an_num, h, w, class_num},
@@ -293,17 +365,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
     T loss_y = CalcMSEWithMask<T>(pred_y, ty, obj_mask);
     T loss_w = CalcMSEWithMask<T>(pred_w, tw, obj_mask);
     T loss_h = CalcMSEWithMask<T>(pred_h, th, obj_mask);
-    T loss_conf_obj = CalcBCEWithMask<T>(pred_confs, tconf, obj_mask);
-    T loss_conf_noobj = CalcBCEWithMask<T>(pred_confs, tconf, noobj_mask);
-    T loss_class = CalcBCEWithMask<T>(pred_classes, tclass, obj_mask_expand);
-
-    // LOG(ERROR) << "loss_x: " << loss_x;
-    // LOG(ERROR) << "loss_y: " << loss_y;
-    // LOG(ERROR) << "loss_w: " << loss_w;
-    // LOG(ERROR) << "loss_h: " << loss_h;
-    // LOG(ERROR) << "loss_conf_obj: " << loss_conf_obj;
-    // LOG(ERROR) << "loss_conf_noobj: " << loss_conf_noobj;
-    // LOG(ERROR) << "loss_class: " << loss_class;
+    T loss_conf_obj = CalcBCEWithMask<T>(pred_conf, tconf, obj_mask);
+    T loss_conf_noobj = CalcBCEWithMask<T>(pred_conf, tconf, noobj_mask);
+    T loss_class = CalcBCEWithMask<T>(pred_class, tclass, obj_mask_expand);
 
     auto* loss_data = loss->mutable_data<T>({1}, ctx.GetPlace());
     loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_obj +
@@ -315,8 +379,76 @@ template <typename DeviceContext, typename T>
 class Yolov3LossGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* input = ctx.Input<Tensor>("X");
+    auto* gt_boxes = ctx.Input<Tensor>("GTBox");
+    auto anchors = ctx.Attr<std::vector<int>>("anchors");
+    int class_num = ctx.Attr<int>("class_num");
+    float ignore_thresh = ctx.Attr<float>("ignore_thresh");
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+    const T loss = output_grad->data<T>()[0];
+
+    const int n = input->dims()[0];
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+    const int an_num = anchors.size() / 2;
+
+    Tensor pred_x, pred_y, pred_w, pred_h;
+    Tensor pred_conf, pred_class;
+    pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    CalcPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
+                      &pred_w, &pred_h, an_num, class_num);
+
+    Tensor obj_mask, noobj_mask;
+    Tensor tx, ty, tw, th, tconf, tclass;
+    obj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
+    noobj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
+    tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    PrePorcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, h, &obj_mask,
+                       &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass);
+
+    Tensor obj_mask_expand;
+    obj_mask_expand.mutable_data<int>({n, an_num, h, w, class_num},
+                                      ctx.GetPlace());
+    ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask);
+
+    Tensor grad_x, grad_y, grad_w, grad_h;
+    Tensor grad_conf_obj, grad_conf_noobj, grad_class;
+    grad_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_conf_obj.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_conf_noobj.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    T obj_mf = CalcMaskPointNum<int>(obj_mask);
+    T noobj_mf = CalcMaskPointNum<int>(noobj_mask);
+    T obj_expand_mf = CalcMaskPointNum<int>(obj_mask_expand);
+    CalcMSEGradWithMask<T>(&grad_x, pred_x, tx, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_y, pred_y, ty, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_w, pred_w, tw, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_h, pred_h, th, obj_mask, obj_mf);
+    CalcBCEGradWithMask<T>(&grad_conf_obj, pred_conf, tconf, obj_mask, obj_mf);
+    CalcBCEGradWithMask<T>(&grad_conf_noobj, pred_conf, tconf, noobj_mask,
+                           noobj_mf);
+    CalcBCEGradWithMask<T>(&grad_class, pred_class, tclass, obj_mask_expand,
+                           obj_expand_mf);
+
+    input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    AddAllGradToInputGrad<T>(
+        input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y,
+        grad_w, grad_h, grad_conf_obj, grad_conf_noobj, grad_class, class_num);
   }
 };
 
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 1ee7198f29260f55b7fe6ad949ee4b5218b4b938..a4efb166826d53e7e4d0a39173ff07d5ae9fcf8d 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -8244,14 +8244,55 @@ def log_loss(input, label, epsilon=1e-4, name=None):
     return loss
 
 
-def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None):
+@templatedoc(op_type="yolov3_loss")
+def yolov3_loss(x, gtbox, anchors, class_num, ignore_thresh, name=None):
     """
-    **YOLOv3 Loss Layer**
+    ${comment}
+
+    Args:
+        x (Variable): ${x_comment}
+        gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5],
+                          in the third dimenstion, class_id, x, y, w, h should
+                          be stored and x, y, w, h should be relative valud of
+                          input image.
+        anchors (list|tuple): ${anchors_comment}
+        class_num (int): ${class_num_comment}
+        ignore_thresh (float): ${ignore_thresh_comment}
+        name (string): the name of yolov3 loss
 
-    This layer 
+    Returns:
+        Variable: A 1-D tensor with shape [1], the value of yolov3 loss
+
+    Raises:
+        TypeError: Input x of yolov3_loss must be Variable
+        TypeError: Input gtbox of yolov3_loss must be Variable"
+        TypeError: Attr anchors of yolov3_loss must be list or tuple
+        TypeError: Attr class_num of yolov3_loss must be an integer
+        TypeError: Attr ignore_thresh of yolov3_loss must be a float number
+
+    Examples:
+    .. code-block:: python
+
+        x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32')
+        gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32')
+        anchors = [10, 13, 16, 30, 33, 23]
+        loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80
+                                        anchors=anchors, ignore_thresh=0.5)
     """
     helper = LayerHelper('yolov3_loss', **locals())
 
+    if not isinstance(x, Variable):
+        raise TypeError("Input x of yolov3_loss must be Variable")
+    if not isinstance(gtbox, Variable):
+        raise TypeError("Input gtbox of yolov3_loss must be Variable")
+    if not isinstance(anchors, list) and not isinstance(anchors, tuple):
+        raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
+    if not isinstance(class_num, int):
+        raise TypeError("Attr class_num of yolov3_loss must be an integer")
+    if not isinstance(ignore_thresh, float):
+        raise TypeError(
+            "Attr ignore_thresh of yolov3_loss must be a float number")
+
     if name is None:
         loss = helper.create_variable_for_type_inference(dtype=x.dtype)
     else:
@@ -8264,8 +8305,8 @@ def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None):
                 "GTBox": gtbox},
         outputs={'Loss': loss},
         attrs={
-            "img_height": img_height,
             "anchors": anchors,
+            "class_num": class_num,
             "ignore_thresh": ignore_thresh,
         })
     return loss
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index f48d9c84f9c10b0eff8e41a510d168543c9795fa..dd02968c30fc3eb289d63c9dba581c7a4966d0d4 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -911,6 +911,15 @@ class TestBook(unittest.TestCase):
             self.assertIsNotNone(data_1)
         print(str(program))
 
+    def test_yolov3_loss(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
+            gtbox = layers.data(name='gtbox', shape=[10, 5], dtype='float32')
+            loss = layers.yolov3_loss(x, gtbox, [10, 13, 30, 13], 10, 0.5)
+
+            self.assertIsNotNone(loss)
+
     def test_bilinear_tensor_product_layer(self):
         program = Program()
         with program_guard(program):
diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
index f5b15efb27ff7a50ce161130cebb1fb854d8b4db..4562f8bd4962e1ab3d3fb30dc7125a1bec66ac12 100644
--- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
@@ -12,10 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import division
+
 import unittest
 import numpy as np
 from op_test import OpTest
 
+from paddle.fluid import core
+
 
 def sigmoid(x):
     return 1.0 / (1.0 + np.exp(-1.0 * x))
@@ -65,10 +69,9 @@ def box_iou(box1, box2):
 def build_target(gtboxs, attrs, grid_size):
     n, b, _ = gtboxs.shape
     ignore_thresh = attrs["ignore_thresh"]
-    img_height = attrs["img_height"]
     anchors = attrs["anchors"]
     class_num = attrs["class_num"]
-    an_num = len(anchors) / 2
+    an_num = len(anchors) // 2
     obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
     noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32')
     tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
@@ -120,7 +123,7 @@ def build_target(gtboxs, attrs, grid_size):
 
 def YoloV3Loss(x, gtbox, attrs):
     n, c, h, w = x.shape
-    an_num = len(attrs['anchors']) / 2
+    an_num = len(attrs['anchors']) // 2
     class_num = attrs["class_num"]
     x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
     pred_x = sigmoid(x[:, :, :, :, 0])
@@ -144,13 +147,6 @@ def YoloV3Loss(x, gtbox, attrs):
                           noobj_mask)
     loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand,
                      obj_mask_expand)
-    # print "loss_x: ", loss_x
-    # print "loss_y: ", loss_y
-    # print "loss_w: ", loss_w
-    # print "loss_h: ", loss_h
-    # print "loss_conf_obj: ", loss_conf_obj
-    # print "loss_conf_noobj: ", loss_conf_noobj
-    # print "loss_class: ", loss_class
 
     return loss_x + loss_y + loss_w + loss_h + loss_conf_obj + loss_conf_noobj + loss_class
 
@@ -165,29 +161,35 @@ class TestYolov3LossOp(OpTest):
                                            self.gtbox_shape[:2])
 
         self.attrs = {
-            "img_height": self.img_height,
             "anchors": self.anchors,
             "class_num": self.class_num,
             "ignore_thresh": self.ignore_thresh,
         }
 
         self.inputs = {'X': x, 'GTBox': gtbox}
-        self.outputs = {'Loss': np.array([YoloV3Loss(x, gtbox, self.attrs)])}
-        print self.outputs
+        self.outputs = {
+            'Loss':
+            np.array([YoloV3Loss(x, gtbox, self.attrs)]).astype('float32')
+        }
 
     def test_check_output(self):
-        self.check_output(atol=1e-3)
+        place = core.CPUPlace()
+        self.check_output_with_place(place, atol=1e-3)
 
-    # def test_check_grad_normal(self):
-    #     self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61)
+    def test_check_grad_ignore_gtbox(self):
+        place = core.CPUPlace()
+        self.check_grad_with_place(
+            place, ['X'],
+            'Loss',
+            no_grad_set=set("GTBox"),
+            max_relative_error=0.1)
 
     def initTestCase(self):
-        self.img_height = 608
-        self.anchors = [10, 13, 16, 30, 33, 23]
+        self.anchors = [10, 13, 12, 12]
         self.class_num = 10
         self.ignore_thresh = 0.5
-        self.x_shape = (5, len(self.anchors) / 2 * (5 + self.class_num), 7, 7)
-        self.gtbox_shape = (5, 10, 5)
+        self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7)
+        self.gtbox_shape = (5, 5, 5)
 
 
 if __name__ == "__main__":