diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 02a1b421526dfc88a39edf1b82c394f2c816187a..56429b748064daeac2780d5414513fffa9003b58 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -526,7 +526,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
 
         self.set_virtual_pipeline_rank(0)
         self.input_tensors[0].append(
-            p2p.recv_forward(self.is_pipeline_first_stage()))
+            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False))
 
         # run startup steps
         for micro_step in range(startup_steps):
@@ -647,7 +647,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
         if not forward_only:
             if all_startup_steps:
                 self.output_tensor_grads[self.num_model_chunks - 1].append(
-                    p2p.recv_backward(self.is_pipeline_last_stage()))
+                    p2p.recv_backward(self.is_pipeline_last_stage(),
+                                      sync_recv=False))
 
             for micro_step in range(steady_steps, num_steps):
                 # cooldown loop
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index 160c5f1511220aa8ec36f670d10f96f1b5e90cd9..7962e2dd4373e643a4510bc5c56bf53b3c02f88e 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -207,6 +207,7 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
                      rank_id):
     src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
+        assert use_calc_stream, "legacy dygraph only supports sync recv"
         return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
                                           'peer', src_rank_in_group, 'num',
@@ -216,8 +217,11 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        return group.process_group.recv_partial(tensor, src_rank_in_group,
+        task = group.process_group.recv_partial(tensor, src_rank_in_group,
                                                 nranks, rank_id)
+        if use_calc_stream:
+            task.wait()
+        return task
 
 
 def recv_partial(tensor,
@@ -238,7 +242,7 @@ def recv_partial(tensor,
         return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
                                 src_rank, nranks, rank_id)
     else:
-        if _in_legacy_dygraph():
+        if _in_legacy_dygraph() or use_calc_stream:
             recv_op = paddle.distributed.recv
         elif in_dygraph_mode():
             recv_op = paddle.distributed.irecv
@@ -275,7 +279,11 @@ def allgather_partial(tensor,
                                  nranks, rank_id)
 
 
-def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
+def _p2p_helper(tensor_send_next,
+                tensor_send_prev,
+                recv_prev,
+                recv_next,
+                sync_recv=True):
     global _hcg
 
     tensor_recv_prev = None
@@ -354,7 +362,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                  nranks=mp_degree,
                                  rank_id=mp_rank,
                                  group=_hcg.recv_prev_group,
-                                 use_calc_stream=True))
+                                 use_calc_stream=sync_recv))
         else:
             tasks.append(
                 recv_partial(tensor_recv_prev,
@@ -362,7 +370,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                              nranks=mp_degree,
                              rank_id=mp_rank,
                              group=_hcg.recv_prev_group,
-                             use_calc_stream=True))
+                             use_calc_stream=sync_recv))
 
     if tensor_send_next is not None:
         if isinstance(tensor_send_next, tuple):
@@ -394,7 +402,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                  nranks=mp_degree,
                                  rank_id=mp_rank,
                                  group=_hcg.recv_next_group,
-                                 use_calc_stream=True))
+                                 use_calc_stream=sync_recv))
 
         else:
             tasks.append(
@@ -403,10 +411,10 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                              nranks=mp_degree,
                              rank_id=mp_rank,
                              group=_hcg.recv_next_group,
-                             use_calc_stream=True))
+                             use_calc_stream=sync_recv))
 
-    if in_dygraph_mode():
-        # wait isend/irecv tasks in eager dygraph mode with new comm library
+    if not sync_recv and in_dygraph_mode():
+        # wait for irecv tasks in eager dygraph mode with the new comm library
         for task in tasks:
             assert task is not None
             task.wait()
@@ -443,7 +451,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(pp_first_stage):
+def recv_forward(pp_first_stage, sync_recv=True):
     if pp_first_stage:
         input_tensor = None
     else:
@@ -454,18 +462,20 @@ def recv_forward(pp_first_stage):
         input_tensor, _ = _p2p_helper(tensor_send_next=None,
                                       tensor_send_prev=None,
                                       recv_prev=True,
-                                      recv_next=False)
+                                      recv_next=False,
+                                      sync_recv=sync_recv)
     return input_tensor
 
 
-def recv_backward(pp_last_stage):
+def recv_backward(pp_last_stage, sync_recv=True):
     if pp_last_stage:
         output_tensor_grad = None
     else:
         _, output_tensor_grad = _p2p_helper(tensor_send_next=None,
                                             tensor_send_prev=None,
                                             recv_prev=False,
-                                            recv_next=True)
+                                            recv_next=True,
+                                            sync_recv=sync_recv)
     return output_tensor_grad
 
 
@@ -527,7 +537,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
-        recv_next=recv_next)
+        recv_next=recv_next,
+        sync_recv=False)
     return input_tensor, output_tensor_grad
 
 
@@ -544,7 +555,8 @@ def send_forward_recv_forward(output_tensor, recv_prev):
     input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor,
                                   tensor_send_prev=None,
                                   recv_prev=recv_prev,
-                                  recv_next=False)
+                                  recv_next=False,
+                                  sync_recv=False)
 
     return input_tensor
 
@@ -553,5 +565,6 @@ def send_backward_recv_backward(input_tensor_grad, recv_next):
     _, output_tensor_grad = _p2p_helper(tensor_send_next=None,
                                         tensor_send_prev=input_tensor_grad,
                                         recv_prev=False,
-                                        recv_next=recv_next)
+                                        recv_next=recv_next,
+                                        sync_recv=False)
     return output_tensor_grad