diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc
index 0af3ba621aa5c3ee0599249ce868cfda72eddd18..880bfe3b04394f20079af98d07dc2a12db920b97 100644
--- a/paddle/fluid/operators/detection_map_op.cc
+++ b/paddle/fluid/operators/detection_map_op.cc
@@ -47,11 +47,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
                       "The shape is of Input(DetectRes) [N, 6].");
     auto label_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
+    PADDLE_ENFORCE_EQ(label_dims.size(), 2,
                       "The rank of Input(Label) must be 2, "
                       "the shape is [N, 6].");
-    PADDLE_ENFORCE_EQ(label_dims[1], 6UL,
-                      "The shape is of Input(Label) [N, 6].");
+    PADDLE_ENFORCE_EQ(label_dims[1], 6, "The shape is of Input(Label) [N, 6].");
 
     if (ctx->HasInput("PosCount")) {
       PADDLE_ENFORCE(ctx->HasInput("TruePos"),
@@ -96,6 +95,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
              "instance, the offsets in first dimension are called LoD, "
              "the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
              "means there is no ground-truth data.");
+    AddInput("HasState",
+             "(Tensor<int>) A tensor with shape [1], 0 means ignoring input "
+             "states, which including PosCount, TruePos, FalsePos.")
+        .AsDispensable();
     AddInput("PosCount",
              "(Tensor) A tensor with shape [Ncls, 1], store the "
              "input positive example count of each class, Ncls is the count of "
@@ -145,7 +148,7 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
         "(float) "
         "The lower bound jaccard overlap threshold of detection output and "
         "ground-truth data.")
-        .SetDefault(.3f);
+        .SetDefault(.5f);
     AddAttr<bool>("evaluate_difficult",
                   "(bool, default true) "
                   "Switch to control whether the difficult data is evaluated.")
diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h
index 92e05108393f9c8fa9259d16d2dd4c6685f2c1e2..b2b0995b35bf16432e73ebcfd3341adb00c11fd8 100644
--- a/paddle/fluid/operators/detection_map_op.h
+++ b/paddle/fluid/operators/detection_map_op.h
@@ -87,7 +87,13 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
     std::map<int, std::vector<std::pair<T, int>>> true_pos;
     std::map<int, std::vector<std::pair<T, int>>> false_pos;
 
-    if (in_pos_count != nullptr) {
+    auto* has_state = ctx.Input<framework::LoDTensor>("HasState");
+    int state = 0;
+    if (has_state) {
+      state = has_state->data<int>()[0];
+    }
+
+    if (in_pos_count != nullptr && state) {
       GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
                   true_pos, false_pos);
     }
@@ -202,6 +208,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
 
     int* pos_count_data = output_pos_count.mutable_data<int>(
         framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace());
+
     T* true_pos_data = output_true_pos.mutable_data<T>(
         framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
     T* false_pos_data = output_false_pos.mutable_data<T>(
diff --git a/python/paddle/fluid/evaluator.py b/python/paddle/fluid/evaluator.py
index 38c6a98279021a1a28a46078a0af47ea97bd4aeb..364789233bf9cd8667b2e0c0b927adee9dd9a65c 100644
--- a/python/paddle/fluid/evaluator.py
+++ b/python/paddle/fluid/evaluator.py
@@ -18,11 +18,13 @@ import layers
 from framework import Program, Variable, program_guard
 import unique_name
 from layer_helper import LayerHelper
+from initializer import Constant
 
 __all__ = [
     'Accuracy',
     'ChunkEvaluator',
     'EditDistance',
+    'DetectionMAP',
 ]
 
 
@@ -285,3 +287,120 @@ class EditDistance(Evaluator):
             result = executor.run(
                 eval_program, fetch_list=[avg_distance, avg_instance_error])
         return np.array(result[0]), np.array(result[1])
+
+
+class DetectionMAP(Evaluator):
+    """
+    Calculate the detection mean average precision (mAP).
+
+    TODO (Dang Qingqing): update the following doc.
+    The general steps are as follows:
+    1. calculate the true positive and false positive according to the input
+        of detection and labels.
+    2. calculate mAP value, support two versions: '11 point' and 'integral'.
+
+    Please get more information from the following articles:
+      https://sanchom.wordpress.com/tag/average-precision/
+      https://arxiv.org/abs/1512.02325
+
+    Args:
+        input (Variable): The detection results, which is a LoDTensor with shape
+            [M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax].
+        gt_label (Variable): The ground truth label index, which is a LoDTensor
+            with shape [N, 1]. 
+        gt_difficult (Variable): Whether this ground truth is a difficult
+            bounding box (bbox), which is a LoDTensor [N, 1].
+        gt_box (Variable): The ground truth bounding box (bbox), which is a
+            LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax].
+        overlap_threshold (float): The threshold for deciding true/false
+            positive, 0.5 by defalut.
+        evaluate_difficult (bool): Whether to consider difficult ground truth
+            for evaluation, True by defalut.
+        ap_version (string): The average precision calculation ways, it must be
+            'integral' or '11point'. Please check
+            https://sanchom.wordpress.com/tag/average-precision/ for details.
+            - 11point: the 11-point interpolated average precision.
+            - integral: the natural integral of the precision-recall curve.
+
+    Example:
+
+        exe = fluid.executor(place)
+        map_evaluator = fluid.Evaluator.DetectionMAP(input,
+            gt_label, gt_difficult, gt_box)
+        cur_map, accum_map = map_evaluator.get_map_var()
+        fetch = [cost, cur_map, accum_map]
+        for epoch in PASS_NUM:
+            map_evaluator.reset(exe)
+            for data in batches:
+                loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
+
+        In the above example:
+
+        'cur_map_v' is the mAP of current mini-batch.
+        'accum_map_v' is the accumulative mAP of one pass.
+    """
+
+    def __init__(self,
+                 input,
+                 gt_label,
+                 gt_box,
+                 gt_difficult,
+                 overlap_threshold=0.5,
+                 evaluate_difficult=True,
+                 ap_version='integral'):
+        super(DetectionMAP, self).__init__("map_eval")
+
+        gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
+        gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype)
+        label = layers.concat([gt_label, gt_difficult, gt_box], axis=1)
+
+        # calculate mean average precision (mAP) of current mini-batch
+        map = layers.detection_map(
+            input,
+            label,
+            overlap_threshold=overlap_threshold,
+            evaluate_difficult=evaluate_difficult,
+            ap_version=ap_version)
+
+        self.create_state(dtype='int32', shape=None, suffix='accum_pos_count')
+        self.create_state(dtype='float32', shape=None, suffix='accum_true_pos')
+        self.create_state(dtype='float32', shape=None, suffix='accum_false_pos')
+
+        self.has_state = None
+        var = self.helper.create_variable(
+            persistable=True, dtype='int32', shape=[1])
+        self.helper.set_variable_initializer(
+            var, initializer=Constant(value=int(0)))
+        self.has_state = var
+
+        # calculate accumulative mAP
+        accum_map = layers.detection_map(
+            input,
+            label,
+            overlap_threshold=overlap_threshold,
+            evaluate_difficult=evaluate_difficult,
+            has_state=self.has_state,
+            input_states=self.states,
+            out_states=self.states,
+            ap_version=ap_version)
+
+        layers.fill_constant(
+            shape=self.has_state.shape,
+            value=1,
+            dtype=self.has_state.dtype,
+            out=self.has_state)
+
+        self.cur_map = map
+        self.accum_map = accum_map
+
+    def get_map_var(self):
+        return self.cur_map, self.accum_map
+
+    def reset(self, executor, reset_program=None):
+        if reset_program is None:
+            reset_program = Program()
+        with program_guard(main_program=reset_program):
+            var = _clone_var_(reset_program.current_block(), self.has_state)
+            layers.fill_constant(
+                shape=var.shape, value=0, dtype=var.dtype, out=var)
+        executor.run(reset_program)
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index d16b4dc3a482d657158636f0e3a2c1d07bae7d8c..420d3de7f73aaad2ffbd4f285a2e28672fd853f1 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -151,23 +151,34 @@ def detection_output(loc,
 @autodoc()
 def detection_map(detect_res,
                   label,
-                  pos_count=None,
-                  true_pos=None,
-                  false_pos=None,
                   overlap_threshold=0.3,
                   evaluate_difficult=True,
-                  ap_type='integral'):
+                  has_state=None,
+                  input_states=None,
+                  out_states=None,
+                  ap_version='integral'):
     helper = LayerHelper("detection_map", **locals())
 
-    map_out = helper.create_tmp_variable(dtype='float32')
-    accum_pos_count_out = helper.create_tmp_variable(dtype='int32')
-    accum_true_pos_out = helper.create_tmp_variable(dtype='float32')
-    accum_false_pos_out = helper.create_tmp_variable(dtype='float32')
+    def __create_var(type):
+        return helper.create_tmp_variable(dtype=type)
+
+    map_out = __create_var('float32')
+    accum_pos_count_out = out_states[0] if out_states else __create_var('int32')
+    accum_true_pos_out = out_states[1] if out_states else __create_var(
+        'float32')
+    accum_false_pos_out = out_states[2] if out_states else __create_var(
+        'float32')
+
+    pos_count = input_states[0] if input_states else None
+    true_pos = input_states[1] if input_states else None
+    false_pos = input_states[2] if input_states else None
+
     helper.append_op(
         type="detection_map",
         inputs={
             'Label': label,
             'DetectRes': detect_res,
+            'HasState': has_state,
             'PosCount': pos_count,
             'TruePos': true_pos,
             'FalsePos': false_pos
@@ -181,9 +192,9 @@ def detection_map(detect_res,
         attrs={
             'overlap_threshold': overlap_threshold,
             'evaluate_difficult': evaluate_difficult,
-            'ap_type': ap_type
+            'ap_type': ap_version
         })
-    return map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out
+    return map_out
 
 
 def bipartite_match(dist_matrix,
diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py
index 0d2d653c01cb994bb10ea6bc5c6c8b8ce5004d7c..b183db55b2afa23136a07a5dbc8d0925d0868f95 100644
--- a/python/paddle/fluid/tests/test_detection.py
+++ b/python/paddle/fluid/tests/test_detection.py
@@ -158,26 +158,9 @@ class TestDetectionMAP(unittest.TestCase):
                 append_batch_size=False,
                 dtype='float32')
 
-            map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out = layers.detection_map(
-                detect_res=detect_res, label=label)
+            map_out = layers.detection_map(detect_res=detect_res, label=label)
             self.assertIsNotNone(map_out)
-            self.assertIsNotNone(accum_pos_count_out)
-            self.assertIsNotNone(accum_true_pos_out)
-            self.assertIsNotNone(accum_false_pos_out)
             self.assertEqual(map_out.shape, (1, ))
-            map_out, accum_pos_count_out2, accum_true_pos_out2, accum_false_pos_out2 = layers.detection_map(
-                detect_res=detect_res, label=label)
-            self.assertIsNotNone(map_out)
-            self.assertIsNotNone(accum_pos_count_out2)
-            self.assertIsNotNone(accum_true_pos_out2)
-            self.assertIsNotNone(accum_false_pos_out2)
-            self.assertEqual(map_out.shape, (1, ))
-            self.assertEqual(accum_pos_count_out.shape,
-                             accum_pos_count_out2.shape)
-            self.assertEqual(accum_true_pos_out.shape,
-                             accum_true_pos_out2.shape)
-            self.assertEqual(accum_false_pos_out.shape,
-                             accum_false_pos_out2.shape)
         print(str(program))
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_detection_map_op.py b/python/paddle/fluid/tests/unittests/test_detection_map_op.py
index 70ccd885d89f245df492bad0fbcecc093dc1928c..9857cc58456b51f99bd06aa4496298d33c4abcd6 100644
--- a/python/paddle/fluid/tests/unittests/test_detection_map_op.py
+++ b/python/paddle/fluid/tests/unittests/test_detection_map_op.py
@@ -34,10 +34,12 @@ class TestDetectionMAPOp(OpTest):
                 'int32')
             self.true_pos = np.array(self.true_pos).astype('float32')
             self.false_pos = np.array(self.false_pos).astype('float32')
+            self.has_state = np.array([1]).astype('int32')
 
             self.inputs = {
                 'Label': (self.label, self.label_lod),
                 'DetectRes': (self.detect, self.detect_lod),
+                'HasState': self.has_state,
                 'PosCount': self.class_pos_count,
                 'TruePos': (self.true_pos, self.true_pos_lod),
                 'FalsePos': (self.false_pos, self.false_pos_lod)