diff --git a/paddle/operators/target_assign_op.cc b/paddle/operators/target_assign_op.cc
index 9c7d625136be757da3ab6384bcfbbab72e697682..615ca857ceb45d442b75fffc6662cc2bda19562d 100644
--- a/paddle/operators/target_assign_op.cc
+++ b/paddle/operators/target_assign_op.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -61,10 +61,12 @@ class TargetAssignOp : public framework::OperatorWithKernel {
                       "The rank of Input(NegIndices) must be 2.");
 
     PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0],
-                      "The 1st dimension of Input(EncodedGTBBox) and "
-                      "Input(GTScoreLabel) must be the same.");
+                      "The 1st dimension (means the total number of "
+                      "ground-truth bounding boxes) of Input(EncodedGTBBox) "
+                      "and Input(GTScoreLabel) must be the same.");
     PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1],
-                      "The 2nd dimension of Input(EncodedGTBBox) and "
+                      "The 2nd dimension (means the number of prior boxes) "
+                      "of Input(EncodedGTBBox) and "
                       "Input(MatchIndices) must be the same.");
     PADDLE_ENFORCE_EQ(blabel_dims[2], 4,
                       "The 3rd dimension of Input(EncodedGTBBox) must be 4.");
@@ -101,31 +103,31 @@ class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
              "labels with shape [Ng, 1], where the Ng is the same as it in "
              "the input of EncodedGTBBox.");
     AddInput("MatchIndices",
-             "(Tensor, default LoDTensor<int>), The input matched indices "
+             "(Tensor, default Tensor<int>), The input matched indices "
              "with shape [N, Np], where N is the batch size, Np is the same "
              "as it in the input of EncodedGTBBox. If MatchIndices[i][j] "
              "is -1, the j-th prior box is not matched to any ground-truh "
              "box in i-th instance.");
     AddInput("NegIndices",
              "(LoDTensor, default LoDTensor<int>), The input negative example "
-             "indics with shape [Neg, 1], where is the total number of "
+             "indices with shape [Neg, 1], where Neg is the total number of "
              "negative example indices.");
     AddAttr<int>("background_label",
-                 "(int, default 0), Label id for background class.")
+                 "(int, default 0), Label index of background class.")
         .SetDefault(0);
     AddOutput("PredBBoxLabel",
               "(Tensor), The output encoded ground-truth labels "
               "with shape [N, Np, 4], N is the batch size and Np, 4 is the "
               "same as they in input of EncodedGTBBox. If MatchIndices[i][j] "
               "is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth "
-              "box for background_label_id in i-th instance.");
+              "box for background_label in i-th instance.");
     AddOutput("PredBBoxWeight",
               "(Tensor), The weight for PredBBoxLabel with the shape "
               "of [N, Np, 1]");
     AddOutput("PredScoreLabel",
               "(Tensor, default Tensor<int>), The output score labels for "
               "each predictions with shape [N, Np, 1]. If MatchIndices[i][j] "
-              "is -1, PredScoreLabel[i][j] = background_label_id.");
+              "is -1, PredScoreLabel[i][j] = background_label.");
     AddOutput("PredScoreWeight",
               "(Tensor), The weight for PredScoreLabel with the shape "
               "of [N, Np, 1]");
@@ -136,19 +138,47 @@ and regression targets to each prior box as well as weights to each
 prior box. The weights is used to specify which prior box would not contribute
 to training loss.
 
-TODO(dang qingqing) add an example.
+For each instance, the output `PredBBoxLabel`, `PredBBoxWeight`,
+`PredScoreLabel` and `PredScoreWeight` are assigned based on `MatchIndices`.
+Assuming the row offset for each instance in `EncodedGTBBox` is called lod,
+this operator assigns classification/regression targets by performing the
+following steps:
+
+1. Assigning all outputs based on `MatchIndices`:
+
+If id = MatchIndices[i][j] > -1,
+
+    PredBBoxLabel[i][j] = EncodedGTBBox[lod[i] + id][j]
+    PredBBoxWeight[i][j] = 1.
+    PredScoreLabel[i][j] = GTScoreLabel[lod[i] + id]
+    PredScoreWeight[i][j] = 1.
+
+Otherwise,
+
+    PredBBoxLabel[i][j] = [0., 0., 0., 0.]
+    PredBBoxWeight[i][j] = 0.
+    PredScoreLabel[i][j] = background_label
+    PredScoreWeight[i][j] = 0.
+
+2. Assigning PredScoreLabel and PredScoreWeight based on `NegIndices`:
+
+Assuming the row offset for each instance in `NegIndices` is called neg_lod,
+for the i-th instance and every id of NegIndices in this instance:
+
+    PredScoreLabel[i][id] = background_label
+    PredScoreWeight[i][id] = 1.0
 
     )DOC");
   }
 };
 
 template <typename T>
-struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, T> {
+struct NegTargetAssignFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx, const int* neg_indices,
                   const size_t* lod, const int num, const int num_prior_box,
                   const int background_label, int* out_label, T* out_label_wt) {
     for (int i = 0; i < num; ++i) {
-      for (int j = lod[i]; j < lod[i + 1]; ++j) {
+      for (size_t j = lod[i]; j < lod[i + 1]; ++j) {
         int id = neg_indices[j];
         out_label[i * num_prior_box + id] = background_label;
         out_label_wt[i * num_prior_box + id] = static_cast<T>(1.0);
@@ -157,8 +187,8 @@ struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, T> {
   }
 };
 
-template struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, float>;
-template struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, double>;
+template struct NegTargetAssignFunctor<platform::CPUDeviceContext, float>;
+template struct NegTargetAssignFunctor<platform::CPUDeviceContext, double>;
 
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/target_assign_op.cu b/paddle/operators/target_assign_op.cu
index c04de86ec58ccaac33dc0862988e8a35a7d5ed65..fc0a1000a4202adeca3e0d6fbb05e832a79dbaba 100644
--- a/paddle/operators/target_assign_op.cu
+++ b/paddle/operators/target_assign_op.cu
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -18,38 +18,38 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-__global__ void UpdateTargetLabelKernel(const int* neg_indices,
-                                        const size_t* lod, const int num,
-                                        const int num_prior_box,
-                                        const int background_label,
-                                        int* out_label, T* out_label_wt) {
+__global__ void NegTargetAssignKernel(const int* neg_indices, const size_t* lod,
+                                      const int num, const int num_prior_box,
+                                      const int background_label,
+                                      int* out_label, T* out_label_wt) {
   int bidx = blockIdx.x;
   int st = lod[bidx];
   int ed = lod[bidx + 1];
 
+  int row_start = bidx * num_prior_box;
   for (int i = st + threadIdx.x; i < ed; i += blockDim.x) {
-    int id = neg_indices[i];
-    out_label[bidx * num_prior_box + id] = background_label;
-    out_label_wt[bidx * num_prior_box + id] = 1.;
+    int id = row_start + neg_indices[i];
+    out_label[id] = background_label;
+    out_label_wt[id] = 1.;
   }
 }
 
 template <typename T>
-struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, T> {
+struct NegTargetAssignFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& ctx,
                   const int* neg_indices, const size_t* lod, const int num,
                   const int num_prior_box, const int background_label,
                   int* out_label, T* out_label_wt) {
     const int block_size = 256;
     const int grid_size = num;
-    UpdateTargetLabelKernel<T><<<grid_size, block_size, 0, ctx.stream()>>>(
+    NegTargetAssignKernel<T><<<grid_size, block_size, 0, ctx.stream()>>>(
         neg_indices, lod, num, num_prior_box, background_label, out_label,
         out_label_wt);
   }
 };
 
-template struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, float>;
-template struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, double>;
+template struct NegTargetAssignFunctor<platform::CUDADeviceContext, float>;
+template struct NegTargetAssignFunctor<platform::CUDADeviceContext, double>;
 
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/target_assign_op.h b/paddle/operators/target_assign_op.h
index 267bdbf1effa6d11bf587336d67d6fffaded08f6..82fca5724c0bd9fbfb60a98b91944700bfab9cdf 100644
--- a/paddle/operators/target_assign_op.h
+++ b/paddle/operators/target_assign_op.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -56,40 +56,41 @@ struct TargetAssignFunctor {
     int row = i / num_prior_box_;
     int col = i - row * num_prior_box_;
 
-    size_t off = lod_[row];
+    size_t row_off = lod_[row];
+    int offset = row * num_prior_box_ + col;
 
-    int id = match_indices_[row * num_prior_box_ + col];
-    T* obox = out_box_ + (row * num_prior_box_ + col) * 4;
-    int* olabel = out_label_ + row * num_prior_box_ + col;
-    T* obox_wt = out_box_wt_ + row * num_prior_box_ + col;
-    T* olabel_wt = out_label_wt_ + row * num_prior_box_ + col;
+    int id = match_indices_[offset];
+    T* obox = out_box_ + offset * 4;
+    int* olabel = out_label_ + offset;
+    T* obox_wt = out_box_wt_ + offset;
+    T* olabel_wt = out_label_wt_ + offset;
 
     if (id > -1) {
-      const T* gtbox = gt_box_ + ((off + id) * num_prior_box_ + col) * 4;
+      const T* gtbox = gt_box_ + ((row_off + id) * num_prior_box_ + col) * 4;
 
       obox[0] = gtbox[0];
       obox[1] = gtbox[1];
       obox[2] = gtbox[2];
       obox[3] = gtbox[3];
 
-      olabel[0] = gt_label_[off + id];
-      obox_wt[0] = 1.;
-      olabel_wt[0] = 1.;
+      olabel[0] = gt_label_[row_off + id];
+      obox_wt[0] = static_cast<T>(1.);
+      olabel_wt[0] = static_cast<T>(1.);
     } else {
-      obox[0] = 0.;
-      obox[1] = 0.;
-      obox[2] = 0.;
-      obox[3] = 0.;
+      obox[0] = static_cast<T>(0.);
+      obox[1] = static_cast<T>(0.);
+      obox[2] = static_cast<T>(0.);
+      obox[3] = static_cast<T>(0.);
 
       olabel[0] = background_label_;
-      obox_wt[0] = 0.;
-      olabel_wt[0] = 0.;
+      obox_wt[0] = static_cast<T>(0.);
+      olabel_wt[0] = static_cast<T>(0.);
     }
   }
 };
 
 template <typename DeviceContext, typename T>
-struct UpdateTargetLabelFunctor {
+struct NegTargetAssignFunctor {
   void operator()(const platform::DeviceContext& ctx, const int* neg_indices,
                   const size_t* lod, const int num, const int num_prior_box,
                   const int background_label, int* out_label,
@@ -130,7 +131,11 @@ class TargetAssignKernel : public framework::OpKernel<T> {
     int64_t num_prior_box = match_indices->dims()[1];
 
     auto gt_lod = enc_gt_box->lod().back();
+    auto gt_label_lod = gt_label->lod().back();
     auto neg_lod = neg_indices->lod().back();
+    for (size_t i = 0; i < gt_lod.size(); ++i) {
+      PADDLE_ENFORCE_EQ(gt_lod.data()[i], gt_label_lod.data()[i]);
+    }
 
     size_t* gt_lod_data = gt_lod.data(ctx.GetPlace());
     size_t* neg_lod_data = neg_lod.data(ctx.GetPlace());
@@ -145,9 +150,9 @@ class TargetAssignKernel : public framework::OpKernel<T> {
                                                 num * num_prior_box);
     for_range(functor);
 
-    UpdateTargetLabelFunctor<DeviceContext, T> update_functor;
-    update_functor(device_ctx, neg_idx_data, neg_lod_data, num, num_prior_box,
-                   background_label, olabel_data, olabel_wt_data);
+    NegTargetAssignFunctor<DeviceContext, T> neg_trg_functor;
+    neg_trg_functor(device_ctx, neg_idx_data, neg_lod_data, num, num_prior_box,
+                    background_label, olabel_data, olabel_wt_data);
   }
 };
 
diff --git a/python/paddle/v2/fluid/tests/test_target_assign_op.py b/python/paddle/v2/fluid/tests/test_target_assign_op.py
index 49edff5c7fd46e152984857284ddec894ad88fc9..8a1155c6217401b1b85e3c0bdc47f438f482bcbb 100755
--- a/python/paddle/v2/fluid/tests/test_target_assign_op.py
+++ b/python/paddle/v2/fluid/tests/test_target_assign_op.py
@@ -14,8 +14,6 @@
 
 import unittest
 import numpy as np
-import math
-import sys
 import random
 from op_test import OpTest
 
@@ -89,8 +87,6 @@ class TestTargetAssginOp(OpTest):
         num_class = 21
         gt_lod = [0, 5, 11, 23]
         neg_lod = [0, 4, 7, 13]
-        #gt_lod = [0, 2, 5]
-        #neg_lod = [0, 2, 4]
         batch_size = len(gt_lod) - 1
         num_gt = gt_lod[-1]
         background_label = 0