diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 05ffdefe05b3d47cf43fb5eca67761d452bcad4b..7b8bf17f27ca308184fab03d3bbf90cb1c5943ee 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -279,11 +279,20 @@ class DistributeTranspiler:
 
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
+        assert len(grad_blocks) == len(param_blocks)
         # step2: Create new vars for the parameters and gradients blocks and
         # add ops to do the split.
-        grad_var_mapping = self._append_split_op(program, grad_blocks)
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                              param_blocks)
+        grad_var_mapping = self._create_vars_from_blocklist(
+            program, grad_blocks, add_trainer_suffix=self.trainer_num > 1)
+        grad_param_mapping = dict()
+        for g, p in zip(grad_blocks, param_blocks):
+            g_name, g_bid, _ = g.split(":")
+            p_name, p_bid, _ = p.split(":")
+            grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
+                param_var_mapping[p_name][int(p_bid)]
+
         rpc_client_var = program.global_block().create_var(
             name=RPC_CLIENT_VAR_NAME,
             persistable=True,
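
Note on the pairing built in the hunk above: the block descriptors produced by split_dense_variable appear to be colon-separated strings of the form "varname:block_id:size", so zipping grad_blocks with param_blocks pairs the i-th gradient block with the i-th parameter block (hence the assert on equal lengths). A self-contained sketch of that pairing, with stand-in dicts in place of the real program var mappings and the descriptor format as an assumption:

    # Sketch only: stand-in dicts replace grad_var_mapping/param_var_mapping,
    # and the "name:block_id:size" descriptor format is an assumption.
    grad_blocks = ["w@GRAD:0:500", "w@GRAD:1:500"]
    param_blocks = ["w:0:500", "w:1:500"]
    grad_var_mapping = {"w@GRAD": ["w@GRAD.block0", "w@GRAD.block1"]}
    param_var_mapping = {"w": ["w.block0", "w.block1"]}

    grad_param_mapping = dict()
    for g, p in zip(grad_blocks, param_blocks):
        g_name, g_bid, _ = g.split(":")
        p_name, p_bid, _ = p.split(":")
        grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
            param_var_mapping[p_name][int(p_bid)]

    # {'w@GRAD.block0': 'w.block0', 'w@GRAD.block1': 'w.block1'}
    print(grad_param_mapping)
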
@@ -304,15 +313,21 @@ class DistributeTranspiler:
 
         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
-        for varname, send_vars in grad_var_mapping.items():
+        send_vars = []
+        for varname, splited_vars in grad_var_mapping.items():
             index = find_op_by_output_arg(program.global_block(), varname)
-            eplist = ps_dispatcher.dispatch(send_vars)
+            eplist = ps_dispatcher.dispatch(splited_vars)
+            if len(splited_vars) > 1:
+                self._insert_split_op(program, varname, splited_vars)
+                index += 1
             program.global_block().insert_op(
-                index=index,
+                index=index + 1,
                 type="send_vars",
-                inputs={"X": send_vars},
+                inputs={"X": splited_vars},
                 outputs={"RPCClient": rpc_client_var},
                 attrs={"epmap": eplist})
+            for var in splited_vars:
+                send_vars.append(var)
 
         if self.sync_mode:
             program.global_block().append_op(
@@ -322,21 +337,12 @@ class DistributeTranspiler:
                 attrs={"endpoints": pserver_endpoints})
 
         # step 3.2: insert recv op to receive parameters from parameter server
-        ps_dispatcher.reset()
         recv_vars = []
-        for b in param_blocks:
-            varname, block_id, _ = b.split(":")
-            recv_vars.append(param_var_mapping[varname][int(block_id)])
-        for b in grad_blocks:
-            varname, block_id, _ = b.split(":")
-            send_vars.append(grad_var_mapping[varname][int(block_id)])
-
+        for var in send_vars:
+            recv_vars.append(grad_param_mapping[var])
+        ps_dispatcher.reset()
         eplist = ps_dispatcher.dispatch(recv_vars)
 
-        for i, ep in enumerate(eplist):
-            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
-            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
-
         program.global_block().append_op(
             type="recv",
             inputs={},
@@ -344,6 +350,10 @@ class DistributeTranspiler:
                      "RPCClient": rpc_client_var},
             attrs={"epmap": eplist})
 
+        for i, ep in enumerate(eplist):
+            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
+            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
+
         # TODO(Yancey1989): check dist lookup table
         if self.has_distributed_lookup_table:
             self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
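
Worth noting for the hunks above: send_vars keeps the order in which the gradient blocks were dispatched, and recv_vars is derived from it through grad_param_mapping before the dispatcher is reset and reused. With a deterministic dispatcher (e.g. round robin), that keeps each parameter block on the same endpoint its gradient block was sent to, which is what the param_grad_ep_mapping bookkeeping after the recv op relies on. A rough illustration with a hypothetical round-robin dispatcher (not the real ps_dispatcher):

    # Hypothetical dispatcher used only to illustrate the ordering argument.
    class RoundRobin(object):
        def __init__(self, eplist):
            self.eplist = eplist
            self.offset = 0

        def reset(self):
            self.offset = 0

        def dispatch(self, varlist):
            eps = []
            for _ in varlist:
                eps.append(self.eplist[self.offset])
                self.offset = (self.offset + 1) % len(self.eplist)
            return eps

    dispatcher = RoundRobin(["127.0.0.1:6170", "127.0.0.1:6171"])
    send_eps = dispatcher.dispatch(["w@GRAD.block0", "w@GRAD.block1"])
    dispatcher.reset()
    recv_eps = dispatcher.dispatch(["w.block0", "w.block1"])
    assert send_eps == recv_eps  # grad i and param i share an endpoint
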
@@ -848,6 +858,34 @@ class DistributeTranspiler:
             lod_level=var.lod_level,
             persistable=persistable)
 
+    def _insert_split_op(self, program, orig_varname, splited_vars):
+        orig_var = program.global_block().vars[orig_varname]
+        index = find_op_by_output_arg(program.global_block(), orig_varname)
+        if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+            height_sections = []
+            for v in splited_vars:
+                height_sections.append(v.shape[0])
+            program.global_block().insert_op(
+                index=index + 1,
+                type="split_selected_rows",
+                inputs={"X": orig_var},
+                outputs={"Out": splited_vars},
+                attrs={"height_sections": height_sections})
+        elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
+            sections = []
+            for v in splited_vars:
+                sections.append(v.shape[0])
+            program.global_block().insert_op(
+                index=index + 1,
+                type="split_byref",
+                inputs={"X": orig_var},
+                outputs={"Out": splited_vars},
+                attrs={"sections": sections}  # assume split evenly
+            )
+        else:
+            raise AssertionError("Variable type should be in set "
+                                 "[LOD_TENSOR, SELECTED_ROWS]")
+
     def _append_split_op(self, program, gradblocks):
         # Split variables that need to be split and append respective ops
         add_suffix = False
@@ -860,11 +898,13 @@ class DistributeTranspiler:
             if len(splited_vars) <= 1:
                 continue
             orig_var = program.global_block().vars[varname]
+            index = find_op_by_output_arg(program.global_block(), orig_var.name)
             if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
                 height_sections = []
                 for v in splited_vars:
                     height_sections.append(v.shape[0])
-                program.global_block().append_op(
+                program.global_block().insert_op(
+                    index=index + 1,
                     type="split_selected_rows",
                     inputs={"X": orig_var},
                     outputs={"Out": splited_vars},
@@ -873,7 +913,8 @@ class DistributeTranspiler:
                 sections = []
                 for v in splited_vars:
                     sections.append(v.shape[0])
-                program.global_block().append_op(
+                program.global_block().insert_op(
+                    index=index + 1,
                     type="split_byref",
                     inputs={"X": orig_var},
                     outputs={"Out": splited_vars},