diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 805ef1c3e91e4d2cb892521242ab530a43c16fd6..6363eedc80a20ff05ff88856f4218c2613a85c58 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -29,14 +29,18 @@ message RecomputeConfig {
 }
 
 message ShardingConfig {
-  optional float segment_broadcast_MB = 1 [ default = 32.0 ];
-  optional bool hybrid_dp = 2 [ default = false ];
-  optional int32 sharding_degree = 3 [ default = 8 ];
-  optional int32 mp_degree = 4 [ default = 1 ];
-  optional string sharding_segment_strategy = 5
+  optional string sharding_segment_strategy = 1
       [ default = 'segment_broadcast_MB' ];
-  repeated string segment_anchors = 6;
-  optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
+  optional float segment_broadcast_MB = 2 [ default = 32.0 ];
+  repeated string segment_anchors = 3;
+  optional int32 sharding_degree = 4 [ default = 8 ];
+  optional int32 mp_degree = 5 [ default = 1 ];
+  optional int32 dp_degree = 6 [ default = 1 ];
+  optional bool hybrid_dp = 7 [ default = false ];
+  optional int32 gradient_merge_acc_step = 8 [ default = 1 ];
+  optional bool optimize_offload = 9 [ default = false ];
+  optional bool pp_allreduce_in_optimize = 10 [ default = false ];
+  optional int32 pp_degree = 11 [ default = 1 ];
 }
 
 message AMPConfig {
diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
old mode 100644
new mode 100755
index 6cb7593b6bf7c7a252bfddea850aaf243304b187..ae2daa9b9d8592a5be8ea57919c267a0c2a669d1
--- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
@@ -45,11 +45,16 @@ class PipelineOptimizer(MetaOptimizerBase):
             'accumulate_steps']
         self.schedule_mode = user_defined_strategy.pipeline_configs[
             'schedule_mode']
+        self.use_sharding = user_defined_strategy.sharding
 
     def _can_apply(self):
         if not self.role_maker._is_collective:
             return False
 
+        # FIXME revise for hybrid parallelism
+        if self.use_sharding:
+            return False
+
         if self.user_defined_strategy.pipeline == True:
             return True
         return False
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
index cf399f66946bd7eba79902e0dcfe5873d86a1ec1..40ba77815663f0bcb1cfcdb1d4562a1a7579424f 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
@@ -81,7 +81,10 @@ class FP16Utils(object):
             if not FP16Utils.is_fp32_cast_op(block, op):
                 continue
             output_name = op.desc.output_arg_names()[0]
-            param_name = output_name.strip("@GRAD")
+            # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+            param_name = output_name.strip(
+                "@GRAD@MERGED"
+            ) if "@MERGED" in output_name else output_name.strip("@GRAD")
             if param_name not in shard.global_params:
                 raise ValueError("Output 'X' of cast_op must be a grad of"
                                  "model param, but {} is not a grad".format(
@@ -105,7 +108,11 @@ class FP16Utils(object):
                 reversed_x = []
                 reversed_x_paramname = []
                 for input_name in op.desc.input('X'):
-                    param_name = input_name.strip("@GRAD")
+                    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+                    if "@MERGED" in input_name:
+                        param_name = input_name.strip("@GRAD@MERGED")
+                    else:
+                        param_name = input_name.strip("@GRAD")
                     if param_name not in shard.global_params:
                         raise ValueError(
                             "Input 'X' of check_finite_and_unscale must"
@@ -169,3 +176,58 @@ class FP16Utils(object):
                 OP_ROLE_KEY: OpRole.Optimize
             })
         block._sync_with_cpp()
+
+    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+    @staticmethod
+    def sync_amp_check_nan_inf(block, ring_id):
+        update_loss_scaling_op_idx = -1
+
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if op.type == "update_loss_scaling":
+                update_loss_scaling_op_idx = idx
+                inf_var_name = op.desc.input('FoundInfinite')[0]
+                op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
+
+        # not use amp
+        if update_loss_scaling_op_idx == -1:
+            return
+        inf_var = block.var(inf_var_name)
+        inf_var_int32 = block.create_var(
+            name=inf_var_name + "@cast_int32",
+            shape=inf_var.shape,
+            dtype=core.VarDesc.VarType.INT32)
+        inf_var_global = block.create_var(
+            name=inf_var_name + "@GLOBAL_WORLD",
+            shape=inf_var.shape,
+            dtype=inf_var.dtype)
+        block._insert_op_without_sync(
+            update_loss_scaling_op_idx,
+            type='cast',
+            inputs={'X': inf_var},
+            outputs={'Out': inf_var_int32},
+            attrs={
+                "in_dtype": inf_var.dtype,
+                "out_dtype": inf_var_int32.dtype,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
+        block._insert_op_without_sync(
+            update_loss_scaling_op_idx + 1,
+            type='c_allreduce_max',
+            inputs={'X': inf_var_int32},
+            outputs={'Out': inf_var_int32},
+            attrs={
+                'ring_id': ring_id,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
+        block._insert_op_without_sync(
+            update_loss_scaling_op_idx + 2,
+            type='cast',
+            inputs={'X': inf_var_int32},
+            outputs={'Out': inf_var_global},
+            attrs={
+                "in_dtype": inf_var_int32.dtype,
+                "out_dtype": inf_var_global.dtype,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
+        block._sync_with_cpp()
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 5082bc33830198187001e60a4b73cac8bd8f0f22..d5a012b147a99ef6b6035b1b317e9cc99b5c1d93 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -32,6 +32,7 @@ class GradientClipHelper(object):
         deperated_vars = set()
         deperate_op_idx = set()
         reversed_x_paramname = []
+        global_norm_sum_op_idx = -1
         for idx, op in enumerate(block.ops):
             if not self._is_gradient_clip_op(op):
                 continue
@@ -41,7 +42,11 @@ class GradientClipHelper(object):
             for input_name in op.desc.input_arg_names():
                 if input_name in deperated_vars:
                     deperate_op = True
-                param_name = input_name.strip("@GRAD")
+                # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+                if "@MERGED" in input_name:
+                    param_name = input_name.strip("@GRAD@MERGED")
+                else:
+                    param_name = input_name.strip("@GRAD")
                 if shard.is_param(param_name) and \
                   not shard.has_param(param_name):
                     deperate_op = True
@@ -51,7 +56,8 @@ class GradientClipHelper(object):
             if deperate_op:
                 deperate_op_idx.add(idx)
                 for output_name in op.desc.output_arg_names():
-                    deperated_vars.add(output_name)
+                    if output_name not in op.desc.input_arg_names():
+                        deperated_vars.add(output_name)
 
         if not deperated_vars:
             # got no gradient_clip op
@@ -65,6 +71,7 @@ class GradientClipHelper(object):
                 continue
             reversed_inputs = []
             if op.type == "sum":
+                global_norm_sum_op_idx = idx
                 for input_name in op.desc.input_arg_names():
                     if input_name not in deperated_vars:
                         reversed_inputs.append(input_name)
@@ -86,20 +93,20 @@ class GradientClipHelper(object):
                         OP_ROLE_KEY: OpRole.Optimize,
                     })
 
-        # global norm should only be sum within each model parallelism word size when use global group
-        if pure_dp_degree > 1:
-            block._insert_op_without_sync(
-                idx + 2,
-                type='scale',
-                inputs={'X': sum_res},
-                outputs={'Out': sum_res},
-                attrs={
-                    'scale': 1.0 / float(pure_dp_degree),
-                    'op_namescope': "/gradient_clip_model_parallelism",
-                    'bias': 0.0,
-                    'bias_after_scale': False,
-                    OP_ROLE_KEY: OpRole.Optimize
-                })
+                # global norm should only be sum within each model parallelism word size when use global group
+                if pure_dp_degree > 1:
+                    block._insert_op_without_sync(
+                        idx + 2,
+                        type='scale',
+                        inputs={'X': sum_res},
+                        outputs={'Out': sum_res},
+                        attrs={
+                            'scale': 1.0 / float(pure_dp_degree),
+                            'op_namescope': "/gradient_clip_model_parallelism",
+                            'bias': 0.0,
+                            'bias_after_scale': False,
+                            OP_ROLE_KEY: OpRole.Optimize
+                        })
 
         # the grad sum here should take the all and only param in the current shard
         to_check_param = set(reversed_x_paramname)
@@ -115,3 +122,45 @@ class GradientClipHelper(object):
             block._remove_var(var_name, sync=False)
         block._sync_with_cpp()
         return
+
+    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+    def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
+        """
+        prune gradient_clip related ops for params that not belong to cur shard
+        prune: square, reduce_sum, elementwise_mul
+        keep: sum, sqrt, elementwise_max, elementwise_div
+        """
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if not self._is_gradient_clip_op(op):
+                continue
+
+            if op.type == "sum":
+                sum_res = op.desc.output_arg_names()[0]
+                block._insert_op_without_sync(
+                    idx + 1,
+                    type='c_allreduce_sum',
+                    inputs={'X': sum_res},
+                    outputs={'Out': sum_res},
+                    attrs={
+                        'ring_id': ring_id,
+                        'op_namescope': "/gradient_clip_model_parallelism",
+                        'use_calc_stream': True,
+                        OP_ROLE_KEY: OpRole.Optimize,
+                    })
+
+                # global norm should only be sum within each model parallelism word size
+                if pure_dp_degree > 1:
+                    block._insert_op_without_sync(
+                        idx + 2,
+                        type='scale',
+                        inputs={'X': sum_res},
+                        outputs={'Out': sum_res},
+                        attrs={
+                            'scale': 1.0 / float(pure_dp_degree),
+                            'op_namescope': "/gradient_clip_model_parallelism",
+                            'bias': 0.0,
+                            'bias_after_scale': False,
+                            OP_ROLE_KEY: OpRole.Optimize
+                        })
+
+        return
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
new file mode 100755
index 0000000000000000000000000000000000000000..76803818453c929d1dbf503159c59e1325c8337e
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
@@ -0,0 +1,281 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
+from paddle.fluid import core, unique_name
+
+
+class OffloadHelper(object):
+    cpu_place_type = 0
+    cuda_place_type = 1
+    cuda_pinned_place_type = 2
+
+    def __init__(self):
+        pass
+        "0: dst is on CPUPlace. "
+        "1: dst is on CUDAPlace. "
+        "2: dst is on CUDAPinnedPlace. "
+
+    def _insert_cast_op(self, block, idx, src_name, dst_name):
+        src_var = block.var(src_name)
+        if not block.has_var(dst_name):
+            block.create_var(
+                name=dst_name,
+                shape=src_var.shape,
+                dtype=core.VarDesc.VarType.FP16,
+                persistable=True)
+        dst_var = block.var(dst_name)
+        assert dst_var.dtype == core.VarDesc.VarType.FP16
+        block._insert_op_without_sync(
+            idx,
+            type='cast',
+            inputs={'X': src_var},
+            outputs={'Out': dst_var},
+            attrs={
+                'in_dtype': src_var.dtype,
+                'out_dtype': dst_var.dtype,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
+
+    def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
+        src_var = block.var(src_name)
+        dst_var = block.var(dst_name)
+        block._insert_op_without_sync(
+            idx,
+            type='memcpy',
+            inputs={'X': src_var},
+            outputs={'Out': dst_var},
+            attrs={
+                'dst_place_type': dst_place_type,
+                OP_ROLE_KEY: OpRole.Optimize,
+            })
+
+    def _insert_fetch_op(self, block, idx, src_name, dst_name):
+        self._insert_memcpy_op(block, idx, src_name, dst_name,
+                               OffloadHelper.cuda_place_type)
+
+    def _insert_offload_op(self, block, idx, src_name, dst_name):
+        self._insert_memcpy_op(block, idx, src_name, dst_name,
+                               OffloadHelper.cuda_pinned_place_type)
+
+    def _get_offload_var_name(self, name):
+        return unique_name.generate(name + '@offload')
+
+    def _create_offload_var(self, var_name, offload_var_name, blocks):
+        for block in blocks:
+            var = block.var(var_name)
+            var.persistable = False
+            offload_var = block.create_var(
+                name=offload_var_name,
+                shape=var.shape,
+                dtype=var.dtype,
+                persistable=True)
+
+    def offload_fp32param(self, block, startup_block):
+        """
+        (p_fp16) = cast(p)
+        (p_fp16_recompute) = cast(p)
+        (pout,) = adam(p)
+        ===========================>
+        rename(p_fp16_recompute, p_fp16)
+
+        (p,) = prefetch(p@offload)
+        (pout,) = adam(p)
+        (p_fp16) = cast(p)
+        (p@offload) = memcpy(p)
+        """
+        param_to_idx = dict()
+        param_to_fp16 = dict()
+        # recompute_var which need rename to fp16_param
+        fp16_param_to_recompute = dict()
+        recompute_to_fp16 = dict()
+
+        def remove_param(input_name):
+            param_to_idx.pop(input_name)
+            if input_name in param_to_fp16:
+                fp16_param = param_to_fp16.pop(input_name)
+                if fp16_param in fp16_param_to_recompute:
+                    recompute = fp16_param_to_recompute.pop(fp16_param)
+                    recompute_to_fp16.pop(recompute)
+
+        # step1: record param
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if op.type in ('adam', 'momentum', 'lars', 'lamb'):
+                param = op.desc.input("Param")[0]
+                param_to_idx[param] = idx
+
+        # step2: remove param which can't offload
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                break
+            for input_name in op.desc.input_arg_names():
+                if input_name not in param_to_idx:
+                    continue
+
+                # param is real used by fp32 op
+                if op.type != 'cast':
+                    remove_param(input_name)
+                    continue
+
+                # param is only used by cast op,
+                # which to cast fp32_param to fp16_param
+                output_name = op.output_arg_names[0]
+                if 'cast_fp16' not in output_name:
+                    remove_param(input_name)
+                    continue
+
+                if 'subprog' not in output_name:
+                    assert output_name == input_name + '.cast_fp16'
+                    assert input_name not in param_to_fp16, \
+                        "There must be only one cast op from fp32 param to fp16 param."
+                    param_to_fp16[input_name] = output_name
+                else:
+                    # fp16-->recompute_var
+                    assert input_name in param_to_fp16, \
+                        "param must first be cast to fp16"
+                    fp16_param = param_to_fp16[input_name]
+                    fp16_param_to_recompute[fp16_param] = output_name
+                    recompute_to_fp16[output_name] = fp16_param
+
+        param_name_to_offload_name = dict()
+        # step3: main_block add offload, cast op
+        # change recompute to fp16, remove cast(param) to fp16
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if op.type in ('adam', 'momentum', 'lars', 'lamb'):
+                param = op.desc.input("Param")[0]
+                if param not in param_to_idx: continue
+                # step3.1: create offload_var
+                offload_var_name = self._get_offload_var_name(param)
+                param_name_to_offload_name[param] = offload_var_name
+                self._create_offload_var(param, offload_var_name,
+                                         [block, startup_block])
+
+                # step3.2: insert cast op and offload op
+                self._insert_offload_op(block, idx + 1, param, offload_var_name)
+
+                assert param in param_to_fp16
+                fp16_param_name = param_to_fp16[param]
+                fp16_param_var = block.var(fp16_param_name)
+                fp16_param_var.persistable = True
+                self._insert_cast_op(block, idx + 1, param,
+                                     param_to_fp16[param])
+
+                # step3.3: insert fetch op
+                self._insert_fetch_op(block, idx, offload_var_name, param)
+                continue
+
+            # step3.4: remove cast op
+            if op.type == 'cast':
+                input_name = op.desc.input_arg_names()[0]
+                if input_name in param_to_idx:
+                    block._remove_op(idx, sync=False)
+                    continue
+
+            # step3.5: change recompute_param to fp16_param
+            for input_name in op.desc.input_arg_names():
+                if input_name in recompute_to_fp16:
+                    op._rename_input(input_name, recompute_to_fp16[input_name])
+            for output_name in op.desc.output_arg_names():
+                if output_name in recompute_to_fp16:
+                    op._rename_output(output_name,
+                                      recompute_to_fp16[output_name])
+
+        # step4: remove recompute_param
+        for name in recompute_to_fp16.keys():
+            block._remove_var(name, sync=False)
+
+        # step5: startup_block add offload
+        visited_vars = set()
+        for idx, op in reversed(list(enumerate(startup_block.ops))):
+            for out_name in op.output_arg_names:
+                if out_name in visited_vars:
+                    continue
+
+                if out_name in param_name_to_offload_name:
+                    var_name = out_name
+                    offload_var_name = param_name_to_offload_name[var_name]
+                    self._insert_offload_op(startup_block, idx + 1, var_name,
+                                            offload_var_name)
+                    self._insert_cast_op(startup_block, idx + 1, var_name,
+                                         param_to_fp16[var_name])
+
+                visited_vars.add(out_name)
+
+        block._sync_with_cpp()
+        startup_block._sync_with_cpp()
+
+    def offload(self, block, startup_block):
+        """
+        (m1, m2) = prefetch(m1@offload, m2@offload)
+        (m1out, m2out, pout) = adam(m1, m2, p)
+        (m1@offload, m2@offload) = memcpy(m1, m2)
+        """
+        vars_name_to_offload_name = dict()
+
+        # main_block add offload
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if not is_optimizer_op(op):
+                break
+
+            vars_name = []
+            if op.type == "adam":
+                # {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
+                # adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
+                vars_name.append(op.desc.input("Moment1")[0])
+                vars_name.append(op.desc.input("Moment2")[0])
+            elif op.type == 'momentum':
+                pass
+            elif op.type == 'lars':
+                pass
+            elif op.type == 'lamb':
+                pass
+
+            # step1: create and init offload_var
+            for var_name in vars_name:
+                assert var_name not in vars_name_to_offload_name
+
+                offload_var_name = self._get_offload_var_name(var_name)
+                vars_name_to_offload_name[var_name] = offload_var_name
+
+                self._create_offload_var(var_name, offload_var_name,
+                                         [block, startup_block])
+
+            # step2: insert offload op
+            for var_name in vars_name:
+                offload_var_name = vars_name_to_offload_name[var_name]
+                self._insert_offload_op(block, idx + 1, var_name,
+                                        offload_var_name)
+
+            # step3: insert fetch op
+            for var_name in vars_name:
+                offload_var_name = vars_name_to_offload_name[var_name]
+                self._insert_fetch_op(block, idx, offload_var_name, var_name)
+
+        # startup_block add offload
+        visited_vars = set()
+        for idx, op in reversed(list(enumerate(startup_block.ops))):
+            for out_name in op.output_arg_names:
+                if out_name in visited_vars:
+                    continue
+
+                if out_name in vars_name_to_offload_name:
+                    var_name = out_name
+                    offload_var_name = vars_name_to_offload_name[var_name]
+                    # insert offload op after var is generated
+                    self._insert_offload_op(startup_block, idx + 1, var_name,
+                                            offload_var_name)
+                visited_vars.add(out_name)
+
+        block._sync_with_cpp()
+        startup_block._sync_with_cpp()
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
old mode 100644
new mode 100755
index 70753b59ccc318a25661e084bd305d7d76b0e2a6..5a43367cf1ad123501883d93fffbcf096db8b66f
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
@@ -126,6 +126,10 @@ class ProgramDeps(object):
 
     def should_remove_op(self, op_idx):
         op = self._block.ops[op_idx]
+        # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+        # remove check_finite_and_unscale op if its input 'X' is empty
+        if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
+            return True
         for output_name in op.desc.output_arg_names():
             if output_name not in self._should_removed_var:
                 return False
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
index 8b111026bdb916adcd4eaebc3e45c1876fd07e90..f4ceb2d287a56c7b955817263f751e32dbf23e77 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -274,6 +274,10 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
     """
     insert sync_comm_op for vars
     """
+    # NOTE (JZ-LIANG) to be check, may result undefined case 
+    if len(comm_dep_vars) == 0:
+        return 0
+
     op_role = get_valid_op_role(block, insert_idx)
     block._insert_op_without_sync(
         insert_idx,
@@ -324,27 +328,45 @@ def insert_cast_ops(block, insert_idx, cast_ops):
     return
 
 
-def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
+def insert_allreduce_ops(block,
+                         insert_idx,
+                         ring_id,
+                         allreduce_vars,
+                         op_role=OpRole.Backward,
+                         use_calc_stream=False):
     """
     _add_allreduce_ops
     """
+    if len(allreduce_vars) == 0:
+        return
+
     for var in allreduce_vars:
         block._insert_op_without_sync(
             insert_idx,
             type='c_allreduce_sum',
             inputs={'X': var},
             outputs={'Out': var},
-            attrs={'ring_id': ring_id,
-                   OP_ROLE_KEY: OpRole.Backward})
+            attrs={
+                'ring_id': ring_id,
+                'use_calc_stream': use_calc_stream,
+                OP_ROLE_KEY: op_role
+            })
 
     return
 
 
-def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
+def insert_reduce_ops(block,
+                      insert_idx,
+                      ring_id,
+                      reduce_vars,
+                      shard,
+                      op_role=OpRole.Backward,
+                      use_calc_stream=False):
     """
     _add_allreduce_ops
     """
     for var in reduce_vars:
+
         root_id = get_grad_device(var, shard)
         assert root_id >= 0, "root id should be a positive int".format(var)
         block._insert_op_without_sync(
@@ -355,12 +377,40 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
             attrs={
                 'ring_id': ring_id,
                 'root_id': root_id,
-                OP_ROLE_KEY: OpRole.Backward
+                'use_calc_stream': use_calc_stream,
+                OP_ROLE_KEY: op_role
             })
-
     return
 
 
+def get_grad_device(grad_name, shard):
+    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
+        grad_name)
+    base_name = None
+    # mind the traversal order 
+    possible_suffixes = [
+        '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
+    ]
+    for suffix in possible_suffixes:
+        if suffix in grad_name:
+            base_name = re.sub(suffix, '', grad_name)
+            break
+
+    assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
+        base_name)
+
+    return shard.global_param2device[base_name]
+
+
+def get_first_check_finite_and_unscale_op_idx(block):
+
+    for idx, op in enumerate(block.ops):
+        if op.type == "check_finite_and_unscale":
+            return idx
+
+    raise ValueError("check_finite_and_unscale does not exist in block")
+
+
 def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
     """
     _add_broadcast_ops
@@ -420,6 +470,7 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
                 outputs={'Out': loss_grad_var},
                 attrs={'scale': scale,
                        OP_ROLE_KEY: OpRole.Backward})
+            break
 
 
 def comm_analyse(main_program):
@@ -502,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
     and part of persistable vars are duplicated and exist in all the ranks with different values.
     This function handles the model saving for sharding training.
     """
+    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+    if main_program._pipeline_opt:
+        main_program = main_program._pipeline_opt['section_program']['program']
 
     def is_opt_vars(var):
         # NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index cf3f75740ee3ddc0dde97b9aa1df861dfa1067b4..a83ae226a9df1eeec1239881028893278412c44c 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -16,16 +16,16 @@ import paddle
 from paddle.fluid import unique_name, core
 import paddle.fluid as fluid
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
-from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
+from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
 from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
 from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
 from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
 from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
 from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
+from .sharding.offload_helper import OffloadHelper
 from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
 from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
-
 from paddle.fluid import layers
 
 import logging
@@ -38,6 +38,8 @@ __all__ = ["ShardingOptimizer"]
 
 
 class ShardingOptimizer(MetaOptimizerBase):
+    """Sharding Optimizer."""
+
     def __init__(self, optimizer):
         super(ShardingOptimizer, self).__init__(optimizer)
         self.inner_opt = optimizer
@@ -46,7 +48,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             "AMPOptimizer",
             "LarsOptimizer",
             "LambOptimizer",
-            "ModelParallelOptimizer",
+            # "ModelParallelOptimizer",
+            # "PipelineOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
         self._main_program = None
@@ -88,26 +91,6 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._nrings_sharding = 1
         self._nrings_dp = 1
 
-        # parallelism
-        self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
-            "sharding_degree"])
-        assert self.sharding_degree > 1, "sharding degree must be larger than zero"
-        self.mp_degree = int(self.user_defined_strategy.sharding_configs[
-            "mp_degree"])
-        self.hybrid_dp = self.user_defined_strategy.sharding_configs[
-            "hybrid_dp"]
-
-        self.pp_degree = 1
-
-        # dp here is the pure dp as the outest parallelism
-        self.dp_degree = int(self.role_maker._worker_num() // self.mp_degree //
-                             self.sharding_degree)
-        assert self.role_maker._worker_num(
-        ) == self.dp_degree * self.mp_degree * self.sharding_degree * self.pp_degree
-        if self.hybrid_dp:
-            assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
-                self.dp_degree)
-
         # segment
         self._sharding_segment_strategy = str(
             self.user_defined_strategy.sharding_configs[
@@ -128,55 +111,231 @@ class ShardingOptimizer(MetaOptimizerBase):
                 "the sharding segment strategy [{}] is not implemented".format(
                     str(self._sharding_segment_strategy)))
 
+        # parallelism
+        self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
+            "sharding_degree"])
+        assert self.sharding_degree > 0, "sharding degree must be larger than zero"
+        self.mp_degree = int(self.user_defined_strategy.sharding_configs[
+            "mp_degree"])
+        # pipeline setting
+        # TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
+        self.pp_degree = int(self.user_defined_strategy.sharding_configs[
+            "pp_degree"])
+        if self.pp_degree > 1:
+            assert self.user_defined_strategy.pipeline == True
+
+        self.dp_degree = int(self.user_defined_strategy.sharding_configs[
+            'dp_degree'])
+        assert self.role_maker._worker_num(
+        ) == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
+            self.role_maker._worker_num(),
+            self.mp_degree,
+            self.sharding_degree,
+            self.pp_degree,
+            self.dp_degree, )
+
+        self.hybrid_dp = self.user_defined_strategy.sharding_configs[
+            "hybrid_dp"]
+        # NOTE (JZ-LIANG) 
+        # there 2 kind of modes for gradient-merge and hybrid-dp in mixed parallism [sharding] and [pipeline].
+        # we distinguish this two modes since the gm/hybrid-dp related allreduce should be insert in different place according different mode to have best performance:
+        # sharding: communication within node, and therefore should insert within backward segment to overlap with bw calc, conduct every micro step 
+        # pipeline: communication accross nodes, and therefore should insert in update segemnt, conduct just once per global step        
+        self.hybrid_dp_mode = None
+        # dp here is the pure dp as the outest parallelism
+        if self.hybrid_dp:
+            assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
+                self.dp_degree)
+            if self.pp_degree > 1:
+                self.hybrid_dp_mode = "pp_hybrid_dp"
+            else:
+                assert self.sharding_degree > 1, "by now we only support five kind of hybrid dp: sharding_hybrid_dp, mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
+                self.hybrid_dp_mode = "sharding_hybrid_dp"
+
         # gradient merge
         self._gradient_merge_acc_step = int(
             self.user_defined_strategy.sharding_configs[
                 "gradient_merge_acc_step"])
-        self._grad2merged_grad = dict()
+        self.gradient_merge_mode = None
+        if self.pp_degree <= 1:
+            self.gradient_merge_mode = "sharding_gm"
+            self._grad2merged_grad = dict()
+        else:
+            self.gradient_merge_mode = "pp_gm"
+            self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[
+                'accumulate_steps']
+        if self._gradient_merge_acc_step > 1:
+            logging.info("Gradient merge in [{}], acc step = [{}]".format(
+                self.gradient_merge_mode, self._gradient_merge_acc_step))
+
+        # optimize offload
+        self.optimize_offload = self.user_defined_strategy.sharding_configs[
+            "optimize_offload"]
+
+        # this feature is design for ascend, and should NOT be used in GPU training
+        self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
+            "pp_allreduce_in_optimize"]
 
         if self.inner_opt is None:
             raise ValueError(
                 "self.inner_opt of ShardingOptimizer should not be None.")
-        optimize_ops, params_grads = self.inner_opt.minimize(
-            loss, startup_program, parameter_list, no_grad_set)
+
+        if self.pp_degree > 1:
+            pp_optimizer = fluid.optimizer.PipelineOptimizer(
+                self.inner_opt, self._gradient_merge_acc_step)
+            main_program = loss.block.program
+            main_program._pipeline_opt = dict()
+            self.schedule_mode = self.user_defined_strategy.pipeline_configs[
+                'schedule_mode']
+            main_program._pipeline_opt['schedule_mode'] = self.schedule_mode
+            main_program._pipeline_opt[
+                'micro_batch_size'] = self.user_defined_strategy.pipeline_configs[
+                    'micro_batch_size']
+            self.pp_rank_ = self.role_maker._worker_index() // (
+                self.sharding_degree * self.mp_degree) % self.pp_degree
+            main_program._pipeline_opt['local_rank'] = self.pp_rank_
+            main_program._pipeline_opt[
+                'global_rank'] = self.role_maker._worker_index()
+            main_program._pipeline_opt['use_sharding'] = True
+            # TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
+            main_program._pipeline_opt['ring_id'] = 20
+            main_program._pipeline_opt['global_ring_id'] = 3
+
+            optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
+                loss, startup_program, parameter_list, no_grad_set)
+            self.pp_degree = len(program_list)
+        else:
+            optimize_ops, params_grads = self.inner_opt.minimize(
+                loss, startup_program, parameter_list, no_grad_set)
 
         if startup_program is None:
             startup_program = default_startup_program()
-        main_block = loss.block
+
+        if self.pp_degree > 1:
+            startup_program = startup_program._pipeline_opt['startup_program']
+            #main_program = main_program._pipeline_opt['section_program']['program']
+            print("pp_rank:", self.pp_rank_)
+            main_program = program_list[self.pp_rank_]
+            with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
+                f.writelines(str(main_program))
+            main_block = main_program.global_block()
+            new_params_grads = []
+            for param, grad in params_grads:
+                if main_block.has_var(param.name):
+                    new_params_grads.append((param, grad))
+            params_grads = new_params_grads
+
+        else:
+            main_block = loss.block
+
         startup_block = startup_program.global_block()
         self._main_program = main_block.program
         self._startup_program = startup_program
 
+        if self.pp_degree > 1:
+            pp_optimizer._rename_gradient_var_name(main_block)
+            with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
+                f.writelines(str(main_program))
+
         # step0: _init_comm
         self._init_comm()
 
-        # step1: _build_shard
-        self._build_shard(params_grads)
-
-        # step2: split_program
-        self._split_program(main_block)
+        if self.sharding_degree > 1:
 
-        # step3: add broadcast and reduce ops
-        self._add_broadcast_allreduce(main_block)
-        main_block._sync_with_cpp()
-        startup_block._sync_with_cpp()
+            # step1: build shard
+            self._build_shard(params_grads)
+
+            # step2: split_program
+            self._split_program(main_block)
+
+            # step3: add broadcast and reduce ops
+            self._add_broadcast_allreduce(main_block)
+            main_block._sync_with_cpp()
+            startup_block._sync_with_cpp()
+
+            main_block._sync_with_cpp()
+
+            # step4: remove unneeded ops and vars from block
+            self._prune_main_program(main_block)
+            self._prune_startup_program(startup_block)
+
+        if self.pp_degree > 1:
+            # sharding-pp related logic
+            # pp_optimizer._rename_gradient_var_name(main_block)
+            # crop ops
+            if self.sharding_degree > 1:
+                for idx, op in reversed(list(enumerate(main_block.ops))):
+                    if is_update_op(op):
+                        op_role_var = op.attr('op_role_var')
+                        param_name = op_role_var[0]
+                        if not self._shard.has_param(param_name):
+                            main_block._remove_op(idx)
+
+                for idx, op in reversed(list(enumerate(main_block.ops))):
+                    if op.type != 'cast': continue
+                    in_name = op.input_arg_names[0]
+                    if in_name not in self._params: continue
+                    #if self._shard.has_param(param_name): continue
+                    if in_name not in main_block.vars:
+                        main_block._remove_op(idx)
+
+            accumulated_grad_names = pp_optimizer._accumulate_gradients(
+                main_block)
+            # accumulated_grad_names = sorted(accumulated_grad_names)
+            if self.pp_allreduce_in_optimize:
+                print("persistable FP32 grad: ")
+                print(accumulated_grad_names)
+                first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
+                    main_block)
+                insert_reduce_ops(
+                    main_block,
+                    first_optimize_op_index,
+                    self.sharding_ring_id,
+                    accumulated_grad_names,
+                    self._shard,
+                    core.op_proto_and_checker_maker.OpRole.Optimize,
+                    use_calc_stream=True)
+            if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
+                first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
+                    main_block)
+                insert_allreduce_ops(
+                    main_block,
+                    first_optimize_op_index,
+                    self.dp_ring_id,
+                    accumulated_grad_names,
+                    core.op_proto_and_checker_maker.OpRole.Optimize,
+                    use_calc_stream=True)
+
+        # if not use sharding, adapt amp/clip, for remain parallelism.
+        # cast --> amp --> clip --> opt
+        if self.sharding_degree <= 1:
+            # amp
+            FP16Utils.sync_amp_check_nan_inf(main_block, self.global_ring_id)
+
+            # clip
+            gradientclip_helper = GradientClipHelper(self.global_ring_id)
+            gradientclip_helper.sync_global_norm(
+                main_block, self.global_ring_id, self.dp_degree)
 
-        # step4: scale the loss by the num of dp degree
-        # sharding is also a senario of dp
-        scale_ = self.dp_degree * self.sharding_degree
-        if scale_ > 1:
-            insert_scale_loss_grad_ops(main_block, scale=1.0 / scale_)
+        # step6: loss div dp_degree 
+        global_dp_degree = self.sharding_degree * self.dp_degree
+        assert int(global_dp_degree) == global_dp_degree
+        if global_dp_degree > 1:
+            insert_scale_loss_grad_ops(main_block, scale=1.0 / global_dp_degree)
 
         main_block._sync_with_cpp()
 
-        # step5: remove unneeded ops and vars from block
-        self._prune_main_program(main_block)
-        self._prune_startup_program(startup_block)
-        if self.hybrid_dp:
-            self._initialization_broadcast(startup_program)
-
-        # step6: optional gradient merge
-        if self._gradient_merge_acc_step > 1:
+        # TODO(wangxi): add optimize offload
+        # opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. >> 100) 
+        # sync its memcpy could not be overlap with calc, otherwise it will slower down training severely. 
+        if self.optimize_offload:
+            logging.info("Sharding with optimize offload !")
+            offload_helper = OffloadHelper()
+            offload_helper.offload(main_block, startup_block)
+            offload_helper.offload_fp32param(main_block, startup_block)
+
+        # step6: (optional) sharding gradient merge
+        if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
             self._sharding_gradient_merge(main_block)
 
         # # check op dependecy
@@ -184,14 +343,29 @@ class ShardingOptimizer(MetaOptimizerBase):
         # check_broadcast(main_block)
         # check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
         #                     self.dp_ring_id)
+
+        if self.hybrid_dp:
+            # NOTE(JZ-LIANG) ensure in both sharding_hybrid_dp & pp_hybrid_dp 
+            # init param broadcast should be called after startup pruning             
+            self._initialization_broadcast(startup_block)
+
+        with open("start_sharding_%d" % self.role_maker._worker_index(),
+                  'w') as f:
+            f.writelines(str(startup_block.program))
+        with open("main_sharding_%d" % self.role_maker._worker_index(),
+                  'w') as f:
+            f.writelines(str(main_block.program))
+
         self._wait()
 
         return optimize_ops, params_grads
 
     def _init_comm(self):
+
         # config sharding & dp groups
-        self._build_group()
+        self._build_groups()
 
+        # sync var
         startup_block = self._startup_program.global_block()
         self.startup_prog_sync_var = startup_block.create_var(
             name="startup_prog_sync_var",
@@ -199,7 +373,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             dtype=core.VarDesc.VarType.INT32,
             persistable=False)
 
-        # global
+        # global ring
         self._collective_helper._init_communicator(
             self._startup_program,
             self.current_endpoint,
@@ -212,7 +386,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         append_naive_sync(startup_block, self.startup_prog_sync_var,
                           self.global_ring_id)
 
-        # mp
+        # mp ring
         if self.mp_degree > 1:
             self._collective_helper._init_communicator(
                 self._startup_program,
@@ -226,7 +400,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             append_naive_sync(startup_block, self.startup_prog_sync_var,
                               self.global_ring_id)
 
-        # sharding
+        # sharding ring
         if self.sharding_degree > 1:
             self._collective_helper._init_communicator(
                 self._startup_program,
@@ -240,7 +414,65 @@ class ShardingOptimizer(MetaOptimizerBase):
             append_naive_sync(startup_block, self.startup_prog_sync_var,
                               self.global_ring_id)
 
-        # dp
+        # pp ring
+        if self.pp_degree > 1:
+            if self.schedule_mode == 'F-then-B':  # GPipe
+                self._collective_helper._init_communicator(
+                    self._startup_program,
+                    self.current_endpoint,
+                    self.pp_group_endpoints,
+                    self.pp_rank,
+                    self.pp_ring_id,
+                    False,
+                    global_ring_id=self.global_ring_id,
+                    sync=False)
+                # append_naive_sync(startup_block, self.startup_prog_sync_var,
+                #                   self.global_ring_id)
+                self._collective_helper._init_communicator(
+                    self._startup_program,
+                    self.current_endpoint,
+                    self.pp_group_endpoints,
+                    self.pp_rank,
+                    self.pp_ring_id + 2,
+                    False,
+                    global_ring_id=self.global_ring_id,
+                    sync=False)
+                # append_naive_sync(startup_block, self.startup_prog_sync_var,
+                #                   self.global_ring_id)
+            else:
+                assert self.schedule_mode == '1F1B'
+                for pair in self.pipeline_pair:
+                    pair_key = pair[0] * 1000 + pair[1]
+                    ring_id = self.pp_ring_map[pair_key]
+                    print("pp pair:{}, ring_id: {}".format(pair, ring_id))
+                    if self.pp_rank not in pair: continue
+                    pp_group_endpoints = [
+                        self.pp_group_endpoints[pair[0]],
+                        self.pp_group_endpoints[pair[1]],
+                    ]
+                    if pair[0] < pair[1]:
+                        start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
+                    else:
+                        start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
+                            1] - 1
+                    pp_rank = 0 if self.pp_rank == pair[0] else 1
+                    self._collective_helper._init_communicator(
+                        self._startup_program,
+                        self.current_endpoint,
+                        pp_group_endpoints,
+                        pp_rank,
+                        ring_id,
+                        False,
+                        global_ring_id=self.global_ring_id,
+                        sync=False)
+                    # append_naive_sync(startup_block, self.startup_prog_sync_var,
+                    #                   self.global_ring_id)
+
+                # TODO (JZ-LIANG) to unify this shit 
+            assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
+                self.pp_rank_, self.pp_rank)
+
+        # pure dp ring
         if self.dp_degree > 1:
             self._collective_helper._init_communicator(
                 self._startup_program,
@@ -360,17 +592,22 @@ class ShardingOptimizer(MetaOptimizerBase):
                     self._main_program.global_block().var(input_name))
 
             # find reduce vars
-            if is_backward_op(op) and \
-                    OP_ROLE_VAR_KEY in op.attr_names:
-                op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
-                if len(op_role_var) != 0:
-                    assert len(op_role_var) % 2 == 0
-                    for i in range(0, len(op_role_var), 2):
-                        param, reduced_grad = op_role_var[i], op_role_var[i + 1]
-                        segment._allreduce_vars.append(reduced_grad)
-                        assert (
-                            reduced_grad not in self._reduced_grads_to_param)
-                        self._reduced_grads_to_param[reduced_grad] = param
+            if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
+                # place pipeline gradient allreduce in optimize
+                pass
+            else:
+                if is_backward_op(op) and \
+                        OP_ROLE_VAR_KEY in op.attr_names:
+                    op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
+                    if len(op_role_var) != 0:
+                        assert len(op_role_var) % 2 == 0
+                        for i in range(0, len(op_role_var), 2):
+                            param, reduced_grad = op_role_var[i], op_role_var[
+                                i + 1]
+                            segment._allreduce_vars.append(reduced_grad)
+                            assert (reduced_grad not in
+                                    self._reduced_grads_to_param)
+                            self._reduced_grads_to_param[reduced_grad] = param
 
             # find cast op
             if FP16Utils.is_fp16_cast_op(block, op, self._params):
@@ -462,8 +699,13 @@ class ShardingOptimizer(MetaOptimizerBase):
         # Prune
         for idx, op in reversed(list(enumerate(block.ops))):
             if op.type in [
-                    "c_allreduce_sum", "c_sync_comm_stream",
-                    "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
+                    "c_allreduce_sum",
+                    "c_sync_comm_stream",
+                    "c_calc_comm_stream",
+                    "c_gen_nccl_id",
+                    "c_comm_init",
+                    'send_v2',
+                    'recv_v2',
             ]:
                 pass
             elif op.type == "conditional_block":
@@ -500,6 +742,16 @@ class ShardingOptimizer(MetaOptimizerBase):
                 if program_deps.should_remove_op(idx):
                     program_deps.remove_op(idx)
 
+        # NOTE (JZ-LIANG) revise and unify logic here
+        # sharding support fp16_allreduce logic            
+        block._sync_with_cpp()
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if op.type == 'concat' and is_optimizer_op(op):
+                # remove inputs that not on this card
+                reserved_x = []
+                for var_name in op.desc.input("X"):
+                    if block.has_var(var_name): reserved_x.append(var_name)
+                op.desc.set_input('X', reserved_x)
         block._sync_with_cpp()
         return
 
@@ -507,21 +759,41 @@ class ShardingOptimizer(MetaOptimizerBase):
         """
         add broadcast allreduce op
         if enable gradient_merge, insert related ops
+
+        if combined with pipeline(grad accumulate), 
+        the grad allreduce should be done in optimize role
         """
         if len(self._segments) < 1:
             return
         # sharding
+        if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
+            for idx in range(len(self._segments)):
+                assert len(self._segments[idx]._allreduce_vars) == 0
+
+        # NOTE (JZ-LIANG) revise and unify logic here
+        # fix the _end_idx for segments[-1] if pp is used.
+        new_end_idx = self._segments[-1]._end_idx
+        for idx in range(self._segments[-1]._end_idx - 1,
+                         self._segments[-1]._start_idx - 1, -1):
+            op = block.ops[idx]
+            if op.type == "fill_constant" or op.type == "sum":
+                if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1
+            elif op.type == "cast":
+                if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1
+        self._segments[-1]._end_idx = new_end_idx
+
         if self._segments[-1]._allreduce_vars:
             shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
                                                            ._allreduce_vars)
-            if self._gradient_merge_acc_step <= 1:
-                if self.hybrid_dp and len(shard_allredue_vars) >= 1:
+            if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
+                if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
+                        shard_allredue_vars) >= 1:
                     insert_sync_comm_ops(block, self._segments[-1]._end_idx,
                                          self.dp_ring_id, shard_allredue_vars)
                     insert_allreduce_ops(block, self._segments[-1]._end_idx,
                                          self.dp_ring_id, shard_allredue_vars)
             # gradient merge 
-            else:
+            elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
                 self.create_persistable_gradients_and_insert_merge_ops(
                     block,
                     self._startup_program.global_block(),
@@ -532,9 +804,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                                  self.sharding_ring_id,
                                  self._segments[-1]._allreduce_vars)
             # allreduce --> reduce 
-            insert_reduce_ops(block, self._segments[-1]._end_idx,
-                              self.sharding_ring_id,
-                              self._segments[-1]._allreduce_vars, self._shard)
+            insert_reduce_ops(
+                block,
+                self._segments[-1]._end_idx,
+                self.sharding_ring_id,
+                self._segments[-1]._allreduce_vars,
+                self._shard,
+                op_role=OpRole.Backward,
+                use_calc_stream=False)
 
         for idx, segment in reversed(list(enumerate(self._segments))):
             allreduce_vars = self._segments[
@@ -574,8 +851,9 @@ class ShardingOptimizer(MetaOptimizerBase):
             # step2: add Sync ops
             shard_allredue_vars = self._shard.filter_grads(allreduce_vars)
 
-            if self._gradient_merge_acc_step <= 1:
-                if self.hybrid_dp and len(shard_allredue_vars) >= 1:
+            if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
+                if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
+                        shard_allredue_vars) >= 1:
                     insert_sync_comm_ops(block, segment._end_idx,
                                          self.dp_ring_id, shard_allredue_vars)
 
@@ -593,7 +871,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                                              self.sharding_ring_id,
                                              comm_dep_vars)
             # gradient merge
-            else:
+            elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
                 broad_cast_vars = [x[0] for x in broadcast_vars]
                 if len(broad_cast_vars) > 0:
                     insert_sync_comm_ops(block, segment._end_idx,
@@ -616,7 +894,7 @@ class ShardingOptimizer(MetaOptimizerBase):
 
             # step5: add broadcast ops
             # gradient merge
-            if self._gradient_merge_acc_step > 1:
+            if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
                 self.create_persistable_gradients_and_insert_merge_ops(
                     block,
                     self._startup_program.global_block(), segment._start_idx,
@@ -627,20 +905,29 @@ class ShardingOptimizer(MetaOptimizerBase):
 
             # step6: add all_reduce ops
             # dp
-            if self._gradient_merge_acc_step <= 1:
-                if self.hybrid_dp and len(shard_allredue_vars) >= 1:
+            if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
+                if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
+                        shard_allredue_vars) >= 1:
                     insert_allreduce_ops(block, segment._start_idx,
                                          self.dp_ring_id, shard_allredue_vars)
                     insert_sync_comm_ops(block, segment._start_idx,
                                          self.sharding_ring_id, allreduce_vars)
             # gradient merge
-            else:
+            elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
                 insert_sync_comm_ops(block, segment._start_idx,
                                      self.sharding_ring_id, allreduce_vars)
             # sharding
             # allreduce --> reduce 
-            insert_reduce_ops(block, segment._start_idx, self.sharding_ring_id,
-                              allreduce_vars, self._shard)
+            # TODO temp change
+            if len(allreduce_vars) > 0:
+                insert_reduce_ops(
+                    block,
+                    segment._start_idx,
+                    self.sharding_ring_id,
+                    allreduce_vars,
+                    self._shard,
+                    op_role=OpRole.Backward,
+                    use_calc_stream=False)
 
             block._sync_with_cpp()
 
@@ -691,14 +978,14 @@ class ShardingOptimizer(MetaOptimizerBase):
             block._remove_var(var_name, sync=False)
         block._sync_with_cpp()
 
-    def _build_group(self):
+    def _build_groups(self):
         """
         pre-assign ring ids
-        mp: 0
-        sharding: 1
-        pure-dp: 2
-        global: 3
-        pp: >= 20
+            mp: 0
+            sharding: 1
+            pure-dp: 2
+            global: 3
+            pp: >= 20
         if one parallelism is not enable: -1
         and only support parallelism hierarchy: mp --> sharding --> pp --> dp        
         """
@@ -768,6 +1055,30 @@ class ShardingOptimizer(MetaOptimizerBase):
             self.sharding_group_id = -1
             self.sharding_group_endpoints = []
 
+        # pp
+        if self.pp_degree > 1:
+            self.pp_ring_id = 20
+            self.pp_rank = self.global_rank // (self.sharding_degree *
+                                                self.mp_degree) % self.pp_degree
+            # (NOTE): Already adjust for (outter-pure) dp
+            self.pp_group_id = self.global_rank // (
+                self.mp_degree * self.sharding_degree * self.pp_degree)
+            pp_first_stage_idx = self.global_rank % (
+                self.sharding_degree * self.mp_degree) + self.pp_group_id * (
+                    self.mp_degree * self.sharding_degree * self.pp_degree)
+            pp_stage_offset = self.sharding_degree * self.mp_degree
+            self.pp_group_endpoints = []
+            for i in range(self.pp_degree):
+                self.pp_group_endpoints.append(self.global_endpoints[
+                    pp_first_stage_idx + pp_stage_offset * i])
+            assert self.current_endpoint in self.pp_group_endpoints
+        else:
+            self.pp_degree = 1
+            self.pp_ring_id = -1
+            self.pp_rank = -1
+            self.pp_group_id = -1
+            self.pp_group_endpoints = []
+
         # outter-pure-dp group
         # NOTE (JZ-LIANG) support outter-pure-dp to scale the throughput in 3D parallelism
         # e.g. mp-sharding-pp-dp
@@ -775,6 +1086,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
             self.mp_degree, self.sharding_degree, self.pp_degree,
             self.dp_degree, self.global_word_size)
+
         if self.dp_degree > 1:
             self.dp_ring_id = 2
             self.dp_rank = self.global_rank // (self.sharding_degree *
@@ -794,6 +1106,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             self.dp_group_endpoints = []
 
         # global group
+        # use for gen_nccl_comm_sync, amp check nan inf, clip by global norm
+        # NOTE (JZ-LIANG) when use global ring for calc global norm and dp_degree > 1, the allreduce result should be devided by dp_degree
         self.global_ring_id = 3
 
         logging.info("global word size: {}".format(self.global_word_size))
@@ -817,25 +1131,31 @@ class ShardingOptimizer(MetaOptimizerBase):
         logging.info("sharding ring id: {}".format(self.sharding_ring_id))
         logging.info("#####" * 6)
 
-        logging.info("outter pure dp group size: {}".format(self.dp_degree))
-        logging.info("outter pure dp rank: {}".format(self.dp_rank))
-        logging.info("outter pure dp group endpoints: {}".format(
+        logging.info("pp group size: {}".format(self.pp_degree))
+        logging.info("pp rank: {}".format(self.pp_rank))
+        logging.info("pp group id: {}".format(self.pp_group_id))
+        logging.info("pp group endpoints: {}".format(self.pp_group_endpoints))
+        logging.info("pp ring id: {}".format(self.pp_ring_id))
+        logging.info("#####" * 6)
+
+        logging.info("pure dp group size: {}".format(self.dp_degree))
+        logging.info("pure dp rank: {}".format(self.dp_rank))
+        logging.info("pure dp group endpoints: {}".format(
             self.dp_group_endpoints))
-        logging.info("outter pure dp ring id: {}".format(self.dp_ring_id))
+        logging.info("pure dp ring id: {}".format(self.dp_ring_id))
         logging.info("#####" * 6)
 
         return
 
-    def _initialization_broadcast(self, startup_prog):
+    def _initialization_broadcast(self, startup_block):
         """
         this funtion is to ensure the initialization between dp group to be 
         identical when hybrid-dp is used.
         """
-        block = startup_prog.global_block()
         params = []
-        for param in block.iter_parameters():
+        for param in startup_block.iter_parameters():
             params.append(param)
-            block.append_op(
+            startup_block.append_op(
                 type='c_broadcast',
                 inputs={'X': param},
                 outputs={'Out': param},
@@ -844,15 +1164,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                     'root': 0,
                     OP_ROLE_KEY: OpRole.Forward
                 })
-        block.append_op(
+        startup_block.append_op(
             type='c_sync_comm_stream',
             inputs={'X': params},
             outputs={'Out': params},
             attrs={'ring_id': self.dp_ring_id,
                    OP_ROLE_KEY: OpRole.Forward})
-
         # sync within global group
-        append_naive_sync(block, self.startup_prog_sync_var,
+        append_naive_sync(startup_block, self.startup_prog_sync_var,
                           self.global_ring_id)
 
     # sharding gradient merge
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index b3a1834d49d3b3cd28ae44ba89d4cb93c9ae1d1c..572ebb26d73cb435aaa1fb2d69b059511c193818 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
+            if desc.has_attr('op_device'):
+                new_op_desc._set_attr('op_device', desc.attr('op_device'))
             result_descs.append(new_op_desc)
     return result_descs
 
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
         new_op_desc = block.desc.append_op()
         new_op_desc.copy_from(desc)
         new_op_desc._set_attr(op_role_attr_name, backward)
+        if desc.has_attr('op_device'):
+            new_op_desc._set_attr('op_device', desc.attr('op_device'))
         result_descs.append(new_op_desc)
     return result_descs
 
@@ -843,6 +847,7 @@ def _append_backward_ops_with_checkpoints_(
     vars_in_memory = vars_should_be_hold + checkpoints_name
 
     max_calculated_op_position = len(ops)
+    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
     if recompute_segments == []:
         gap_ops = ops[0:max_calculated_op_position]
         for op in reversed(gap_ops):
@@ -852,6 +857,11 @@ def _append_backward_ops_with_checkpoints_(
                                 _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -866,6 +876,11 @@ def _append_backward_ops_with_checkpoints_(
                                 _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 76c5a30910344506e1ea16f58c9b98862e8eccfc..27ce44a257e78640d83b74d4890ad79893f76187 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4033,6 +4033,12 @@ class PipelineOptimizer(object):
         """
         Find the post op that has variable named var_name as input.
         """
+        # bugfix for uniform hybrid parallelism
+        if '.cast_fp32' in var_name:
+            var_name = var_name.replace('.cast_fp32', '')
+        if '.cast_fp16' in var_name:
+            var_name = var_name.replace('.cast_fp16', '')
+
         post_ops = self.input_var_to_op[var_name]
         if post_ops == None: return None
         result_op = None
@@ -4114,7 +4120,23 @@ class PipelineOptimizer(object):
             # For LRSched ops, we should put them on all sub-programs to
             # make sure each sub-program update the lr correctly
             op._set_attr(self._op_device_key, "gpu:all")
-        elif op.type == "scale" and self._is_backward_op(op):
+        # bugfix in hybrid parallelism
+        elif op.type == "sum" and self._is_backward_op(op):
+            # For sum ops that compute the sum of @RENAMED@ vars
+            for name in op.desc.input_arg_names():
+                assert '@RENAME@' in name, \
+                    "The op must be sum used to accumulate renamed vars."
+            assert len(op.desc.output_arg_names()) == 1
+            out_name = op.desc.output_arg_names()[0]
+            post_op = self._find_post_op(idx, out_name)
+            assert post_op.has_attr(
+                'op_device'), "{} has no op_device attr for var {}".format(
+                    post_op.type, out_name)
+            device = post_op.attr(self._op_device_key)
+            assert device, "The post op must have op_device set."
+            op._set_attr(self._op_device_key, device)
+        elif (op.type == "cast" or
+              op.type == "scale") and self._is_backward_op(op):
             prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
             op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
         elif op.type == "memcpy" and not self._is_optimize_op(op):
@@ -4249,11 +4271,19 @@ class PipelineOptimizer(object):
         Insert a pair of send and recv ops for every two
         consecutive ops on different devices.
         """
-        extra_index_info = {'index': 0}
-
         # A map from var to device where op takes it as input,
         # avoiding multiple send and recv ops.
         input_var_to_device = dict()
+        # bugfix hybrid parallelism
+        first_optimize_index = None
+        for index, op in enumerate(list(block.ops)):
+            if self._is_optimize_op(op):
+                first_optimize_index = index
+                break
+        extra_index_info = {
+            'index': 0,
+            'first_optimize_index': first_optimize_index
+        }
 
         for index, op in enumerate(list(block.ops)):
             cur_device = op.attr(self._op_device_key)
@@ -4371,17 +4401,26 @@ class PipelineOptimizer(object):
                                 'peer': 1,
                             })
                         extra_index_info['index'] += 1
+                        insert_index = None
+                        if int(op_role) == int(self._op_role.Backward):
+                            insert_index = extra_index_info[
+                                'first_optimize_index']
+                            new_op_role = self._op_role.Optimize
+                        else:
+                            insert_index = index
+                            new_op_role = self._op_role.Backward
                         block._insert_op(
-                            index=index + extra_index_info['index'],
+                            index=insert_index + extra_index_info['index'],
                             type='c_sync_comm_stream',
                             inputs={'X': [var]},
                             outputs={'Out': [var]},
                             attrs={
                                 self._op_device_key: prev_dev,
-                                self._op_role_key: self._op_role.Backward,
+                                self._op_role_key: new_op_role,
                                 'ring_id': ring_id,
                             })
-                        extra_index_info['index'] += 1
+                        if int(op_role) == int(self._op_role.Forward):
+                            extra_index_info['index'] += 1
                         var_shape = list(var.shape)
                         var_shape[0] = self.micro_batch_size if var_shape[
                             0] < 0 else var_shape[0]
@@ -4768,8 +4807,9 @@ class PipelineOptimizer(object):
 
         # Step4: Special Case: process persistable vars that exist in
         # multiple sections
-        self._process_persistable_vars_in_multi_sections(
-            main_program, startup_program, program_list)
+        # FIXME 
+        # self._process_persistable_vars_in_multi_sections(
+        #     main_program, startup_program, program_list)
 
         # Step5: Add sub blocks for section programs
         self._add_sub_blocks(main_block, program_list)
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index 4d6744f2b6fe482f363acaf1b6e93a8c02f2335a..f28bf89ff5c30b86e7af7be1a0dd7d79416d6c98 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -354,6 +354,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
             "segment_broadcast_MB": 0.2,
             "segment_anchors": None,
             "sharding_degree": 2,
+            "dp_degree": 2,
             "hybrid_dp": True,
             "gradient_merge_acc_step": 1,
             "mp_degree": 1
@@ -422,6 +423,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
             "segment_broadcast_MB": 0.2,
             "segment_anchors": None,
             "sharding_degree": 2,
+            "dp_degree": 2,
             "hybrid_dp": True,
             "gradient_merge_acc_step": 4,
             "mp_degree": 1
@@ -458,20 +460,56 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
         fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
         opt_ops = [op.type for op in train_prog.blocks[2].ops]
         self.assertEqual(fw_bw_ops, [
-            'fill_constant', 'fill_constant', 'fill_constant',
-            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
-            'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul',
-            'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
-            'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad',
-            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
-            'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
-            'c_sync_comm_stream', 'elementwise_add', 'elementwise_add',
-            'elementwise_add', 'increment', 'elementwise_mod', 'equal',
-            'conditional_block'
+            'fill_constant',
+            'fill_constant',
+            'fill_constant',
+            'c_sync_calc_stream',
+            'c_broadcast',
+            'c_broadcast',
+            'c_broadcast',
+            'c_broadcast',
+            'c_broadcast',
+            'c_broadcast',
+            'c_sync_comm_stream',
+            'mul',
+            'elementwise_add',
+            'tanh',
+            'mul',
+            'elementwise_add',
+            'tanh',
+            'mul',
+            'elementwise_add',
+            'softmax',
+            'cross_entropy2',
+            'mean',
+            'fill_constant',
+            'scale',
+            'mean_grad',
+            'cross_entropy_grad2',
+            'softmax_grad',
+            'elementwise_add_grad',
+            'mul_grad',
+            'tanh_grad',
+            'elementwise_add_grad',
+            'mul_grad',
+            'tanh_grad',
+            'elementwise_add_grad',
+            'mul_grad',
+            'c_sync_calc_stream',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_sync_comm_stream',
+            'elementwise_add',
+            'elementwise_add',
+            'elementwise_add',
+            'increment',
+            'elementwise_mod',
+            'equal',
+            'conditional_block',
         ])
         self.assertEqual(opt_ops, [
             'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',