From b56dbe08a3354173721b2596d5a54a06b6a7c725 Mon Sep 17 00:00:00 2001
From: Yuang Liu <liuyuang@baidu.com>
Date: Thu, 29 Jul 2021 15:16:01 +0800
Subject: [PATCH] fix the fused allreduce bug, test=develop (#34446)

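The fused path in RawProgramOptimizer is rewritten: it now collects the grad
vars of the non-distributed params, packs consecutive grads into segments
bounded by fuse_grad_size_in_num and broken on dtype changes, inserts one
coalesce_tensor and one c_allreduce_sum per segment in front of the optimizer
ops, and finishes with a single c_sync_comm_stream.

A minimal standalone sketch of the grouping rule, working on plain dtype
strings; the helper name split_into_segments is illustrative and not part of
this patch:

    def split_into_segments(grad_dtypes, fuse_grad_size_in_num):
        """Pack consecutive grads into segments; a new segment starts when
        the current one is full or the dtype changes."""
        segments, last_dtype = [], None
        for dtype in grad_dtypes:
            if (not segments
                    or len(segments[-1]) == fuse_grad_size_in_num
                    or dtype != last_dtype):
                segments.append([dtype])    # start a new segment
                last_dtype = dtype
            else:
                segments[-1].append(dtype)  # keep filling the current one
        return segments

    # split_into_segments(["fp16"] * 3 + ["fp32"] * 2, 2)
    # -> [['fp16', 'fp16'], ['fp16'], ['fp32', 'fp32']]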
---
 .../framework/distributed_strategy.proto      |   2 +-
 .../meta_optimizers/raw_program_optimizer.py  | 257 +++++-------------
 2 files changed, 66 insertions(+), 193 deletions(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 0a94b897b9..dabe216068 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -188,7 +188,7 @@ message DistributedStrategy {
   optional bool find_unused_parameters = 28 [ default = false ];
   optional bool tensor_parallel = 29 [ default = false ];
   optional bool without_graph_optimization = 30 [ default = false ];
-  optional int32 fuse_grad_size_in_num = 31 [ default = 1 ];
+  optional int32 fuse_grad_size_in_num = 31 [ default = 8 ];
   optional bool calc_comm_same_stream = 32 [ default = false ];
   optional bool asp = 33 [ default = false ];
 
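For reference, a hedged sketch of how this code path is enabled from user
code; the DistributedStrategy property names below are assumed to mirror the
proto fields above and are not introduced by this patch:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    # assumed public properties mirroring the proto fields
    strategy.without_graph_optimization = True  # route through RawProgramOptimizer
    strategy.fuse_all_reduce_ops = True         # take the fused allreduce branch
    strategy.fuse_grad_size_in_num = 8          # matches the new default above
    fleet.init(is_collective=True, strategy=strategy)
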
diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
index c85242b6a5..2205f79ef4 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
@@ -131,7 +131,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
 
     def _transpile_main_program(self, loss):
         self._insert_loss_grad_ops(loss)
-        if self.fuse_all_reduce_ops:
+        if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1:
             self._allreduce_fusion_program()
         else:
             self._insert_allreduce_ops()
@@ -216,11 +216,10 @@ class RawProgramOptimizer(MetaOptimizerBase):
     def _allreduce_fusion_program(self):
         block = self.main_program.global_block()
         ring_id = self.global_ring_id
-        record_idx, allreduce_input_vars, allreduce_output_vars = [], [], []
-        ops = list(enumerate(block.ops))
+        param_grads = []
 
-        for idx, op in reversed(ops):
-            # we travers the ops reversely
+        # collect the grad vars that need allreduce (skip distributed params)
+        for op in reversed(block.ops):
             if is_backward_op(op) and \
                     OP_ROLE_VAR_KEY in op.attr_names:
                 op_role_var = op.attr(OP_ROLE_VAR_KEY)
@@ -229,214 +228,88 @@ class RawProgramOptimizer(MetaOptimizerBase):
                 assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \
                                                   "but got odd number of vars"
                 for i in range(0, len(op_role_var), 2):
-                    # handle vars in each op, each time handle a param and a grad
                     param_name = op_role_var[i]
                     param = block.var(param_name)
                     grad_name = op_role_var[i + 1]
                     grad = block.var(grad_name)
                     if param.is_distributed:
                         continue
-                    if ".cast_fp16@GRAD" in grad_name:
-                        # when amp=True get the fp16 param
-                        param_name = param_name + ".cast_fp16"
-                        if not block.has_var(param_name):
-                            raise ValueError("op cast name error {}".format(
-                                op.type))
-                        else:
-                            param = block.var(param_name)
-
-                    if len(allreduce_output_vars) == 0 or \
-                            len(allreduce_output_vars[-1]) == \
-                            self.fuse_grad_size_in_num:
-                        # start of the fusion or last group meets the config size
-                        allreduce_output_vars.append([grad])
-                        allreduce_input_vars.append([param])
-                        # add the start and end idx to the record idx
-                        record_idx.append([idx, idx])
-                    else:
-                        # Current group's size is below the config size
-                        # append grad and param to the last group (current group)
-                        # update the start idx to current op's idx
-                        # Since we travers the ops reversely, the idx is descending
-                        # we update the first entry of each entry for record_idx
-                        allreduce_output_vars[-1].append(grad)
-                        allreduce_input_vars[-1].append(param)
-                        record_idx[-1][0] = idx
-
-        assert len(allreduce_output_vars) == len(
-            record_idx
-        ), "It has different lens between the allreduce_output_vars and record_idx."
-
-        if not allreduce_output_vars or not allreduce_input_vars:
-            # nothing needs to be allreduced
-            return
+                    param_grads.append(grad)
+
+        segments = []
+        last_dtype = None
+        # split the grads into segments by dtype and by fuse_grad_size_in_num
+        for var in param_grads:
+            if len(segments) == 0 \
+                    or len(segments[-1]) == self.fuse_grad_size_in_num \
+                    or var.dtype != last_dtype:
+                segments.append([var])
+                last_dtype = var.dtype
+            else:
+                segments[-1].append(var)
 
-        self.vars = collections.OrderedDict()
-        index, pos, offset = 0, 0, 0
-        start, end = record_idx[index]
-        for idx, op in reversed(ops):
-            if idx == start:
-                pos = 0
-                done_output_vars, done_input_vars = self._split_fuction(
-                    allreduce_output_vars[index],  # grad
-                    allreduce_input_vars[index]  # param
-                )
-                for id_, done_output_var in enumerate(done_output_vars):
+        fused_vars = []
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                for segment in segments:
+                    # insert a coalesce_tensor op for this segment
                     tmp_var = block.create_var(
                         name=unique_name.generate('FusedOutput_{}'.format(
-                            done_output_var[0].name)),
-                        dtype=done_output_var[0].dtype,
-                        persistable=False,
+                            segment[0].name)),
+                        dtype=segment[0].dtype,
+                        persistable=True,
                         stop_gradient=True)
-                    self.vars['FusedOutput_{}'.format(done_output_var[0]
-                                                      .name)] = tmp_var
-
-                    block._insert_op(
-                        idx + id_,
+                    fused_vars.append(tmp_var)
+                    block._insert_op_without_sync(
+                        idx,
                         type="coalesce_tensor",
-                        inputs={"Input": done_input_vars[id_]},
-                        outputs={
-                            "Output": done_output_var,
-                            "FusedOutput": tmp_var
-                        },
+                        inputs={"Input": segment},
+                        outputs={"Output": segment,
+                                 "FusedOutput": tmp_var},
                         attrs={
-                            "copy_data": False,
+                            "copy_data": True,
                             "use_align": True,
-                            "dtype": done_output_var[0].dtype,
+                            "dtype": segment[0].dtype,
                             OP_ROLE_KEY: OpRole.Backward
                         })
-                    pos += 1
-
-                for id_ in range(len(done_output_vars)):
-                    x = self.vars['FusedOutput_{}'.format(done_output_vars[id_][
-                        0].name)]
-                    out = x
-
-                    # NOTE: there still some optimize space if use EVENT instead of sync
-                    if not self.calc_comm_same_stream:
-                        # need sync if the calc and comm stream are not the same
-                        block._insert_op(
-                            end + id_ + pos + 1,
-                            type='c_sync_calc_stream',
-                            inputs={'X': x},
-                            outputs={'Out': out},
-                            attrs={OP_ROLE_KEY: OpRole.Backward})
+                break
 
-                    block._insert_op(
-                        end + id_ + pos + 1
-                        if self.calc_comm_same_stream else end + id_ + pos + 2,
+        # insert the allreduce_sum op
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                for fused_var in fused_vars:
+                    block._insert_op_without_sync(
+                        idx,
                         type='c_allreduce_sum',
-                        inputs={'X': x},
-                        outputs={'Out': out},
+                        inputs={'X': fused_var},
+                        outputs={'Out': fused_var},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': self.calc_comm_same_stream,
                             OP_ROLE_KEY: OpRole.Backward
                         })
+                    if not self.calc_comm_same_stream:
+                        block._insert_op_without_sync(
+                            idx,
+                            type='c_sync_calc_stream',
+                            inputs={'X': fused_var},
+                            outputs={'Out': fused_var},
+                            attrs={OP_ROLE_KEY: OpRole.Backward})
+                break
 
-                index += 1
-                if len(record_idx) == index:
-                    break
-                start, end = record_idx[index]
-
-        if not self.calc_comm_same_stream:
-            # need sync if the calc and comm stream are not the same
-            for idx, op in enumerate(block.ops):
-                if is_optimizer_op(op):
-                    block._insert_op(
-                        idx,
-                        type='c_sync_comm_stream',
-                        inputs={'X': block.create_var()},
-                        outputs={'Out': block.create_var()},
-                        attrs={
-                            'ring_id': ring_id,
-                            OP_ROLE_KEY: OpRole.Backward
-                        })
-                    break
-
-    # Integrate grads of the same type to form a combination.
-    # If combination is selected, will return grads of the same type in a groups.
-    # For example:[(fp16, fp16), (fp32), (fp16)] -> [(fp16, fp16, fp16), (fp32)]
-    def _split_fuction(self,
-                       allreduce_output_vars,
-                       allreduce_input_vars,
-                       combination=True):
-        input_vars, final_input_vars, output_vars, final_output_vars = [], [], [], []
-        if len(allreduce_output_vars) == 1:
-            # only have one var to handle
-            final_output_vars.append(allreduce_output_vars)
-            final_input_vars.append(allreduce_input_vars)
-            return final_output_vars, final_input_vars
-
-        for idx in range(len(allreduce_input_vars) - 1):
-            # the last var needs to be handled differently
-            if allreduce_input_vars[idx].dtype == allreduce_input_vars[idx +
-                                                                       1].dtype:
-                # if current var and next var are in same type
-                # append current var to input_vars
-                input_vars.append(allreduce_input_vars[idx])
-                if idx == len(allreduce_input_vars) - 2:
-                    # if current var is the second last var
-                    # append the last var to input_vars
-                    # and update the final_input_vars
-                    input_vars.append(allreduce_input_vars[idx + 1])
-                    final_input_vars.append(input_vars)
-            else:
-                # the current var and next var are in different types
-                # append current var to input_vars
-                # update the final_input_vars
-                # reset input_vars to receive a new type
-                input_vars.append(allreduce_input_vars[idx])
-                final_input_vars.append(input_vars)
-                input_vars = []
-                if idx == len(allreduce_input_vars) - 2:
-                    # if current var is the second last var
-                    # append the last var to a reset input_vars since they are in different types
-                    # and update the final_input_vars
-                    input_vars.append(allreduce_input_vars[idx + 1])
-                    final_input_vars.append(input_vars)
-
-        for idx in range(len(allreduce_output_vars) - 1):
-            # the procedure for the output vars is the same with that for the input vars
-            if allreduce_output_vars[idx].dtype == allreduce_output_vars[
-                    idx + 1].dtype:
-                output_vars.append(allreduce_output_vars[idx])
-                if idx == len(allreduce_output_vars) - 2:
-                    output_vars.append(allreduce_output_vars[idx + 1])
-                    final_output_vars.append(output_vars)
-            else:
-                output_vars.append(allreduce_output_vars[idx])
-                final_output_vars.append(output_vars)
-                output_vars = []
-                if idx == len(allreduce_output_vars) - 2:
-                    output_vars.append(allreduce_output_vars[idx + 1])
-                    final_output_vars.append(output_vars)
-
-        # at this time, all vars in each group in final_input_vars and final_output_vars are in the same type
-
-        if combination:
-            input_fp16_vars, input_fp32_vars, output_fp16_vars, output_fp32_vars = [], [], [], []
-            for final_input_var in final_input_vars:
-                if final_input_var[0].dtype == core.VarDesc.VarType.FP16:
-                    # extend the group
-                    input_fp16_vars.extend(final_input_var)
-                else:
-                    input_fp32_vars.extend(final_input_var)
-
-            for final_output_var in final_output_vars:
-                if final_output_var[0].dtype == core.VarDesc.VarType.FP16:
-                    output_fp16_vars.extend(final_output_var)
-                else:
-                    output_fp32_vars.extend(final_output_var)
-
-            final_output_vars, final_input_vars = [], []
-            if output_fp16_vars:
-                final_output_vars.append(output_fp16_vars)
-            if output_fp32_vars:
-                final_output_vars.append(output_fp32_vars)
-            if input_fp16_vars:
-                final_input_vars.append(input_fp16_vars)
-            if input_fp32_vars:
-                final_input_vars.append(input_fp32_vars)
+        if len(fused_vars) == 0:
+            block._sync_with_cpp()
+            return
 
-        return final_output_vars, final_input_vars
+        # insert the sync comm op
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                block._insert_op_without_sync(
+                    idx,
+                    type='c_sync_comm_stream',
+                    inputs={'X': fused_vars[0]},
+                    outputs={'Out': fused_vars[0]},
+                    attrs={'ring_id': ring_id,
+                           OP_ROLE_KEY: OpRole.Backward})
+                break
+        block._sync_with_cpp()
-- 
GitLab
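
Note on op ordering in the new _allreduce_fusion_program: each loop re-finds
the first optimizer op and passes that fixed index to _insert_op_without_sync,
so ops inserted later end up earlier in the program. That is why, per fused
var, c_allreduce_sum is inserted before c_sync_calc_stream and yet the sync
runs first. A toy, Paddle-free illustration of this insert-at-a-fixed-index
behavior (the op names are just strings here):

    ops = ["backward", "optimizer"]
    idx = ops.index("optimizer")
    for fused_var in ["FusedOutput_a", "FusedOutput_b"]:
        ops.insert(idx, "c_allreduce_sum(%s)" % fused_var)
        ops.insert(idx, "c_sync_calc_stream(%s)" % fused_var)
    print(ops)
    # ['backward',
    #  'c_sync_calc_stream(FusedOutput_b)', 'c_allreduce_sum(FusedOutput_b)',
    #  'c_sync_calc_stream(FusedOutput_a)', 'c_allreduce_sum(FusedOutput_a)',
    #  'optimizer']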