diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index caaf418076bdd43a2d989c0ac318dbba85fa313c..85b649b2937f6a281b9ee1fe7bae8101169f6102 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -16,7 +16,7 @@ else()
     set(multi_devices_graph_builder_deps)
 endif()
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-            scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
+            scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 8a28b187156dc5421999056b2ab2aa1a43a976d4..8a5327011015b419ef88d3ba9d3aa3765024fff2 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -35,22 +35,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs, bool distributed)
+    platform::NCCLContextMap *nccl_ctxs)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),
-      distributed_(distributed),
       nccl_ctxs_(nccl_ctxs) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes, bool distributed)
+    const std::vector<Scope *> &local_scopes)
     : loss_var_name_(loss_var_name),
       places_(places),
-      local_scopes_(local_scopes),
-      distributed_(distributed) {
+      local_scopes_(local_scopes) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
@@ -99,7 +97,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
     // append send op if program is distributed trainer main program.
     // always use the first device
-    if (is_forwarding && distributed_ && op->Type() == "send") {
+    if (!is_forwarding && op->Type() == "send") {
       auto &p = places_[0];
       auto *s = local_scopes_[0];
       size_t i = 0;
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 004d6d50ab8e21888341072782cd430f3d41c1b8..de34caab1be85eecb741a5003f026eb982e178ea 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -34,14 +34,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          platform::NCCLContextMap *nccl_ctxs,
-                          bool distributed = false);
+                          platform::NCCLContextMap *nccl_ctxs);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
-                          const std::vector<Scope *> &local_scopes,
-                          bool distributed = false);
+                          const std::vector<Scope *> &local_scopes);
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@@ -55,7 +53,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   const std::vector<platform::Place> &places_;
   const std::vector<Scope *> &local_scopes_;
   std::unordered_set<std::string> grad_names_;
-  bool distributed_;
 
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index c048c3865f14822be4a0015e385ea1b8e05d0ced..b4f16dba858fb279ec23a8a04257dda6651148cc 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -48,13 +48,13 @@ class ParallelExecutor {
            const std::string& fetched_var_name,
            const std::unordered_map<std::string, LoDTensor>& feed_tensors);
 
+  void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
+
  private:
   void SplitTensorToPlaces(
       const std::unordered_map<std::string, LoDTensor>& feed_tensors);
 
   ParallelExecutorPrivate* member_;
-
-  void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/operators/detail/serde_test.cc b/paddle/fluid/operators/detail/serde_test.cc
index f8cae6b26acf9d37ca286487065d70ede4c03120..cb5f89583436b059ac4d6509dac9f2e3868561aa 100644
--- a/paddle/fluid/operators/detail/serde_test.cc
+++ b/paddle/fluid/operators/detail/serde_test.cc
@@ -107,7 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
   for (int i = 0; i < tensor_numel; ++i) {
     EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
   }
-  for (int64_t i = 0; i < rows2->size(); ++i) {
+  for (size_t i = 0; i < rows2->size(); ++i) {
     EXPECT_EQ(rows_data2[i], i);
   }
   EXPECT_EQ(slr2->height(), 1000);
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 392404045578489014f2283b885c388d5a4586cf..a9a5d87d77ef9074e1d073d6b16083360110f670 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -554,6 +554,7 @@ All parameter, weight, gradient are variables in Paddle.
                                   bcast_vars, main_program, loss_var_name,
                                   scope, local_scopes, allow_op_delay);
            })
+      .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
       .def("local_scopes",
            [](ParallelExecutor &self) -> std::vector<Scope *> * {
              return &self.GetLocalScopes();
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index b93f2f974ca28cfd8d03c0dbbf1d401620a15e53..a23cc9b772a8985028e05314bcc58932ec46b584 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -99,7 +99,7 @@ class ParallelExecutor(object):
         local_scopes = share_vars_from.executor.local_scopes(
         ) if share_vars_from else []
 
-        persistable_vars = [
+        self.persistable_vars = [
             v.name
             for v in filter(lambda var: var.persistable, main.list_vars())
         ]
@@ -112,7 +112,7 @@ class ParallelExecutor(object):
                 p.name for p in main.global_block().iter_parameters()
                 if not p.stop_gradient
             ]),
-            set(persistable_vars),
+            set(self.persistable_vars),
             main.desc,
             loss_name if loss_name else '',
             scope,
@@ -142,3 +142,6 @@ class ParallelExecutor(object):
         self.executor.run(fetch_list, fetch_var_name, feed_tensor_dict)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
         return [arr[i] for i in range(len(arr))]
+
+    def bcast_params(self):
+        self.executor.bcast_params(set(self.persistable_vars))
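
A minimal sketch of how the newly exposed `bcast_params()` might be called from Python in a distributed trainer loop. The executor construction, `avg_cost`, `feeder`, and `train_reader` are placeholder names for illustration only, and the exact `ParallelExecutor` constructor and `run()` signatures in this version may differ; only the `bcast_params()` call itself comes from this change.

import paddle.fluid as fluid

# Assumed setup: the main program has been transpiled for distributed
# training (so it contains send/recv ops) and `avg_cost` is the loss var.
pe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)

for data in train_reader():
    pe.run([avg_cost.name], feeder.feed(data))
    # After parameters come back from the parameter server into the first
    # device's scope, mirror every persistable variable to the remaining
    # GPU scopes before the next iteration.
    pe.bcast_params()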