diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index 09e808902f8fe3a7a07153d3432866c18e81dc7c..67f9854c02fa92d0141463088915e720733306fb 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/framework/executor.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/threadpool.h"
+#include "paddle/operators/detail/safe_ref.h"
 
 namespace paddle {
 namespace operators {
@@ -39,8 +40,10 @@ static void SplitTensorAndMoveTensorToScopes(
     const std::vector<std::string> &names) {
   size_t num_sub_scopes = 0;
   for (auto &argu : names) {
-    auto *var = scope.FindVar(argu);
-    const auto &tensor = var->Get<LoDTensor>();
+    const auto &tensor =
+        detail::Ref(scope.FindVar(argu),
+                    "Cannot find variable %s in the parent scope", argu)
+            .Get<LoDTensor>();
     auto lod_tensors = tensor.SplitLoDTensor(places);
 
     for (auto &lod : lod_tensors) {
@@ -60,7 +63,9 @@ static void SplitTensorAndMoveTensorToScopes(
     }
 
     for (size_t i = 0; i < lod_tensors.size(); ++i) {
-      *(*sub_scopes)[i]->Var(argu)->GetMutable<LoDTensor>() = lod_tensors[i];
+      *detail::Ref(sub_scopes->at(i)->Var(argu),
+                   "Cannot find variable %s in the sub-scope", argu)
+           .GetMutable<LoDTensor>() = lod_tensors[i];
     }
   }
 }
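
As used above, detail::Ref dereferences a pointer only after checking that it is non-null, reporting the formatted message otherwise, so a missing variable fails with a readable error instead of a crash. A rough sketch of the idea in Python (illustrative only, not a Paddle API):

```python
def ref(obj, msg, *args):
    # Fail up front with a descriptive message instead of crashing later
    # on a missing lookup result.
    if obj is None:
        raise ValueError(msg % args)
    return obj

# e.g. ref(scope.find_var(name), "Cannot find variable %s in the parent scope", name)
```
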
@@ -287,6 +292,17 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
                         this->InputGrad(input_param, false));
       }
     }
+    auto *g_block = this->grad_block_[0];
+
+    // All variable names needed by operators in the gradient block
+    std::unordered_set<std::string> all_inputs_in_grad_blocks;
+
+    for (size_t i = 0; i < g_block->OpSize(); ++i) {
+      auto *op = g_block->Op(i);
+      for (auto &var_name : op->InputArgumentNames()) {
+        all_inputs_in_grad_blocks.insert(var_name);
+      }
+    }
 
     for (auto &output_param : this->OutputNames()) {
       if (output_param == kParallelScopes) {
@@ -295,8 +311,17 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
                        this->Output(output_param));
       } else {
         grad->SetInput(output_param, this->Output(output_param));
-        grad->SetInput(framework::GradVarName(output_param),
-                       this->OutputGrad(output_param));
+        std::vector<std::string> og_names;
+        for (auto &og_name : this->OutputGrad(output_param)) {
+          if (all_inputs_in_grad_blocks.count(og_name) != 0) {
+            // Some operator in the gradient block reads this OG, so make it
+            // an input of parallel_do.
+            og_names.push_back(og_name);
+          }
+          // Otherwise no operator reads this OG; do not pass it in as an
+          // input.
+        }
+        grad->SetInput(framework::GradVarName(output_param), og_names);
       }
     }
     grad->SetAttrMap(this->Attrs());
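
The ParallelDoGradOpDescMaker change above connects only those output gradients (OGs) that some operator in the gradient block actually reads. A minimal Python sketch of the same set-membership filter (the names here are illustrative, not the framework API):

```python
def used_output_grads(grad_block_ops, output_grad_names):
    # Gather every input argument name appearing in the gradient block ...
    needed = set()
    for op in grad_block_ops:
        needed.update(op.input_argument_names())
    # ... and keep only the OGs that are actually consumed.
    return [og for og in output_grad_names if og in needed]
```
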
diff --git a/python/paddle/v2/fluid/tests/book/CMakeLists.txt b/python/paddle/v2/fluid/tests/book/CMakeLists.txt
index a35abe3e0c436be4eaed01c9b9183344c6d3b275..dda02c03fd531445c1b33b39a6ded10921991d9c 100644
--- a/python/paddle/v2/fluid/tests/book/CMakeLists.txt
+++ b/python/paddle/v2/fluid/tests/book/CMakeLists.txt
@@ -1,9 +1,33 @@
 file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
 string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
 
-list(REMOVE_ITEM TEST_OPS test_image_classification_train)
+list(REMOVE_ITEM TEST_OPS test_image_classification_train test_recognize_digits)
 py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet)
 py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg)
+py_test(test_recognize_digits_mlp_cpu
+  SRCS test_recognize_digits.py
+  ARGS mlp)
+py_test(test_recognize_digits_mlp_cuda
+  SRCS test_recognize_digits.py
+  ARGS mlp --use_cuda)
+py_test(test_recognize_digits_conv_cpu
+  SRCS test_recognize_digits.py
+  ARGS conv)
+py_test(test_recognize_digits_conv_cuda
+  SRCS test_recognize_digits.py
+  ARGS conv --use_cuda)
+py_test(test_recognize_digits_mlp_cpu_parallel
+  SRCS test_recognize_digits.py
+  ARGS mlp --parallel)
+py_test(test_recognize_digits_mlp_cuda_parallel
+  SRCS test_recognize_digits.py
+  ARGS mlp --use_cuda --parallel)
+py_test(test_recognize_digits_conv_cpu_parallel
+  SRCS test_recognize_digits.py
+  ARGS conv --parallel)
+py_test(test_recognize_digits_conv_cuda_parallel
+  SRCS test_recognize_digits.py
+  ARGS conv --use_cuda --parallel)
 
 # default test
 foreach(src ${TEST_OPS})
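
Each new target above runs the same script with a different combination of ARGS; together they cover the cross product of {mlp, conv} x {CPU, CUDA} x {serial, parallel}. A small sketch of the command lines this expands to (the interpreter and working directory are handled by py_test in practice):

```python
import sys

for nn_type in ("mlp", "conv"):
    for extra in ([], ["--use_cuda"], ["--parallel"], ["--use_cuda", "--parallel"]):
        cmd = [sys.executable, "test_recognize_digits.py", nn_type] + extra
        print(" ".join(cmd))  # e.g. python test_recognize_digits.py conv --use_cuda --parallel
```
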
diff --git a/python/paddle/v2/fluid/tests/book/__init__.py b/python/paddle/v2/fluid/tests/book/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b94a21a7e406b833797f8f521c62a2351c2bc30a
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/book/__init__.py
@@ -0,0 +1,13 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7ef4046f9ff55c2cbfc28b50784b9bffb80d53
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py
@@ -0,0 +1,149 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+import argparse
+import paddle.v2.fluid as fluid
+import paddle.v2 as paddle
+import sys
+import numpy
+
+
+def parse_arg():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "nn_type",
+        help="The neural network type, in ['mlp', 'conv']",
+        type=str,
+        choices=['mlp', 'conv'])
+    parser.add_argument(
+        "--parallel",
+        help='Run in parallel or not',
+        default=False,
+        action="store_true")
+    parser.add_argument(
+        "--use_cuda",
+        help="Run the program by using CUDA",
+        default=False,
+        action="store_true")
+    return parser.parse_args()
+
+
+BATCH_SIZE = 64
+
+
+def loss_net(hidden, label):
+    prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    return fluid.layers.mean(x=loss), fluid.layers.accuracy(
+        input=prediction, label=label)
+
+
+def mlp(img, label):
+    hidden = fluid.layers.fc(input=img, size=200, act='tanh')
+    hidden = fluid.layers.fc(input=hidden, size=200, act='tanh')
+    return loss_net(hidden, label)
+
+
+def conv_net(img, label):
+    conv_pool_1 = fluid.nets.simple_img_conv_pool(
+        input=img,
+        filter_size=5,
+        num_filters=20,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    conv_pool_2 = fluid.nets.simple_img_conv_pool(
+        input=conv_pool_1,
+        filter_size=5,
+        num_filters=50,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    return loss_net(conv_pool_2, label)
+
+
+def main():
+    args = parse_arg()
+    print("recognize digits with args: {0}".format(" ".join(sys.argv[1:])))
+
+    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+    if args.nn_type == 'mlp':
+        net_conf = mlp
+    else:
+        net_conf = conv_net
+
+    if args.parallel:
+        places = fluid.layers.get_places()
+        pd = fluid.layers.ParallelDo(places)
+        with pd.do():
+            img_ = pd.read_input(img)
+            label_ = pd.read_input(label)
+            for o in net_conf(img_, label_):
+                pd.write_output(o)
+
+        avg_loss, acc = pd()
+        # average the loss and accuracy across all devices
+        avg_loss = fluid.layers.mean(x=avg_loss)
+        acc = fluid.layers.mean(x=acc)
+    else:
+        avg_loss, acc = net_conf(img, label)
+
+    test_program = fluid.default_main_program().clone()
+
+    optimizer = fluid.optimizer.Adam(learning_rate=0.001)
+    optimizer.minimize(avg_loss)
+
+    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
+
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+
+    train_reader = paddle.batch(
+        paddle.reader.shuffle(
+            paddle.dataset.mnist.train(), buf_size=500),
+        batch_size=BATCH_SIZE)
+    test_reader = paddle.batch(
+        paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
+    feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
+
+    PASS_NUM = 100
+    for pass_id in range(PASS_NUM):
+        for batch_id, data in enumerate(train_reader()):
+            # train a mini-batch, fetch nothing
+            exe.run(feed=feeder.feed(data))
+            if (batch_id + 1) % 10 == 0:
+                acc_set = []
+                avg_loss_set = []
+                for test_data in test_reader():
+                    acc_np, avg_loss_np = exe.run(program=test_program,
+                                                  feed=feeder.feed(test_data),
+                                                  fetch_list=[acc, avg_loss])
+                    acc_set.append(float(acc_np))
+                    avg_loss_set.append(float(avg_loss_np))
+                # get test acc and loss
+                acc_val = numpy.array(acc_set).mean()
+                avg_loss_val = numpy.array(avg_loss_set).mean()
+                if float(acc_val) > 0.85:  # test acc > 85%
+                    exit(0)
+                else:
+                    print(
+                        'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
+                        format(pass_id, batch_id + 1,
+                               float(avg_loss_val), float(acc_val)))
+
+
+if __name__ == '__main__':
+    main()
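
The parallel variants exercise one core pattern: wrapping an existing network definition in fluid.layers.ParallelDo so each device gets a slice of the batch, then averaging the per-device outputs. Extracted from the test above as a standalone sketch (build_net stands in for mlp or conv_net; treat this as an illustration of the pattern, not a fixed recipe):

```python
import paddle.v2.fluid as fluid

def parallelize(build_net, img, label):
    places = fluid.layers.get_places()   # one entry per available device
    pd = fluid.layers.ParallelDo(places)
    with pd.do():
        img_ = pd.read_input(img)        # this device's slice of the batch
        label_ = pd.read_input(label)
        avg_loss, acc = build_net(img_, label_)
        pd.write_output(avg_loss)
        pd.write_output(acc)
    avg_loss, acc = pd()                 # per-device results gathered back
    # reduce across devices before handing avg_loss to the optimizer
    return fluid.layers.mean(x=avg_loss), fluid.layers.mean(x=acc)
```
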
diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py
deleted file mode 100644
index 4710d16c24e95a11108801a014f94687558fd91e..0000000000000000000000000000000000000000
--- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py
+++ /dev/null
@@ -1,74 +0,0 @@
-#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-import numpy as np
-import paddle.v2 as paddle
-import paddle.v2.fluid as fluid
-
-images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32')
-label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-conv_pool_1 = fluid.nets.simple_img_conv_pool(
-    input=images,
-    filter_size=5,
-    num_filters=20,
-    pool_size=2,
-    pool_stride=2,
-    act="relu")
-conv_pool_2 = fluid.nets.simple_img_conv_pool(
-    input=conv_pool_1,
-    filter_size=5,
-    num_filters=50,
-    pool_size=2,
-    pool_stride=2,
-    act="relu")
-
-predict = fluid.layers.fc(input=conv_pool_2, size=10, act="softmax")
-cost = fluid.layers.cross_entropy(input=predict, label=label)
-avg_cost = fluid.layers.mean(x=cost)
-optimizer = fluid.optimizer.Adam(learning_rate=0.01)
-optimizer.minimize(avg_cost)
-
-accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
-
-BATCH_SIZE = 50
-PASS_NUM = 3
-train_reader = paddle.batch(
-    paddle.reader.shuffle(
-        paddle.dataset.mnist.train(), buf_size=500),
-    batch_size=BATCH_SIZE)
-
-place = fluid.CPUPlace()
-exe = fluid.Executor(place)
-feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
-exe.run(fluid.default_startup_program())
-
-for pass_id in range(PASS_NUM):
-    accuracy.reset(exe)
-    for data in train_reader():
-        loss, acc = exe.run(fluid.default_main_program(),
-                            feed=feeder.feed(data),
-                            fetch_list=[avg_cost] + accuracy.metrics)
-        pass_acc = accuracy.eval(exe)
-        print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" +
-              str(pass_acc))
-        # print loss, acc
-        if loss < 10.0 and pass_acc > 0.9:
-            # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
-            exit(0)
-
-    pass_acc = accuracy.eval(exe)
-    print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc))
-
-exit(1)
diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
deleted file mode 100644
index 236ee4f3398538403228a00ee9c4b72c7c8231cf..0000000000000000000000000000000000000000
--- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
+++ /dev/null
@@ -1,111 +0,0 @@
-#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-import numpy as np
-import paddle.v2 as paddle
-import paddle.v2.fluid as fluid
-
-BATCH_SIZE = 128
-image = fluid.layers.data(name='x', shape=[784], dtype='float32')
-
-regularizer = fluid.regularizer.L2Decay(0.0005 * BATCH_SIZE)
-
-hidden1 = fluid.layers.fc(input=image,
-                          size=128,
-                          act='relu',
-                          param_attr=fluid.ParamAttr(
-                              regularizer=regularizer,
-                              gradient_clip=fluid.clip.ClipByValue(10)))
-
-hidden2 = fluid.layers.fc(input=hidden1,
-                          size=64,
-                          act='relu',
-                          param_attr=regularizer)
-
-predict = fluid.layers.fc(input=hidden2,
-                          size=10,
-                          act='softmax',
-                          param_attr=regularizer)
-
-label = fluid.layers.data(name='y', shape=[1], dtype='int64')
-
-cost = fluid.layers.cross_entropy(input=predict, label=label)
-avg_cost = fluid.layers.mean(x=cost)
-
-optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
-opts = optimizer.minimize(avg_cost)
-
-accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
-
-inference_program = fluid.default_main_program().clone()
-with fluid.program_guard(inference_program):
-    test_accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
-    test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
-    inference_program = fluid.io.get_inference_program(test_target)
-
-train_reader = paddle.batch(
-    paddle.reader.shuffle(
-        paddle.dataset.mnist.train(), buf_size=8192),
-    batch_size=BATCH_SIZE)
-
-test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
-
-place = fluid.CPUPlace()
-exe = fluid.Executor(place)
-feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
-exe.run(fluid.default_startup_program())
-
-PASS_NUM = 100
-for pass_id in range(PASS_NUM):
-    accuracy.reset(exe)
-    for data in train_reader():
-        out, acc = exe.run(fluid.default_main_program(),
-                           feed=feeder.feed(data),
-                           fetch_list=[avg_cost] + accuracy.metrics)
-        pass_acc = accuracy.eval(exe)
-
-        test_accuracy.reset(exe)
-        for data in test_reader():
-            out, acc = exe.run(inference_program,
-                               feed=feeder.feed(data),
-                               fetch_list=[avg_cost] + test_accuracy.metrics)
-
-        test_pass_acc = test_accuracy.eval(exe)
-        print("pass_id=" + str(pass_id) + " train_cost=" + str(
-            out) + " train_acc=" + str(acc) + " train_pass_acc=" + str(pass_acc)
-              + " test_acc=" + str(test_pass_acc))
-
-        if test_pass_acc > 0.7:
-            fluid.io.save_inference_model(
-                "./recognize_digits_mlp.inference.model/", ["x"], [predict],
-                exe)
-            break
-
-# Use load_inference_model to obtain the inference program desc,
-# the feed_target_names (the names of variables that will be feeded 
-# data using feed operators), and the fetch_targets (variables that 
-# we want to obtain data from using fetch operators).
-[infer_prog, feed_target_names, fetch_targets] = fluid.io.load_inference_model(
-    "./recognize_digits_mlp.inference.model/", exe)
-
-tensor_x = np.random.rand(1, 784).astype("float32")
-# Construct feed as a dictionary of {feed_target_name: feed_target_data}
-# and results will contain a list of data corresponding to fetch_targets.
-results = exe.run(infer_prog,
-                  feed={feed_target_names[0]: tensor_x},
-                  fetch_list=fetch_targets)
-print(results[0])
-
-exit(0)