diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py index 2205f79ef4633f67f508130db6497394cddad64b..c923624651c6ae2dde27d6cb94add17b9ce5cfe2 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py @@ -217,9 +217,13 @@ class RawProgramOptimizer(MetaOptimizerBase): block = self.main_program.global_block() ring_id = self.global_ring_id param_grads = [] + first_backward_idx = -1 # find all grad params - for op in reversed(block.ops): + for idx, op in enumerate(block.ops): + if first_backward_idx == -1 and \ + is_backward_op(op): + first_backward_idx = idx if is_backward_op(op) and \ OP_ROLE_VAR_KEY in op.attr_names: op_role_var = op.attr(OP_ROLE_VAR_KEY) @@ -234,70 +238,100 @@ class RawProgramOptimizer(MetaOptimizerBase): grad = block.var(grad_name) if param.is_distributed: continue - param_grads.append(grad) + param_grads.append((param, grad)) + + outputs_name_to_idx = self.__get_ouputs_name_to_idx(first_backward_idx, + block) - segments = [] + # structure of grad_param_segments is + # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])] + # each entry of the list is a tuple stores the grads segment list and + # the corresponding params segment list + grad_param_segments = [] last_dtype = None # split the grad based on dtype and fused size - for var in param_grads: - if len(segments) == 0 \ - or len(segments[-1]) == self.fuse_grad_size_in_num \ - or var.dtype != last_dtype: - segments.append([var]) - last_dtype = var.dtype + for param, grad in param_grads: + if len(grad_param_segments) == 0 \ + or len(grad_param_segments[-1][0]) == self.fuse_grad_size_in_num \ + or grad.dtype != last_dtype: + grad_param_segments.append(([grad], [param])) + last_dtype = grad.dtype else: - segments[-1].append(var) + grad_param_segments[-1][0].append(grad) + grad_param_segments[-1][1].append(param) - fused_vars = [] - for idx, op in enumerate(block.ops): - if is_optimizer_op(op): - for segment in segments: - # insert coalesce tensor - tmp_var = block.create_var( - name=unique_name.generate('FusedOutput_{}'.format( - segment[0].name)), - dtype=segment[0].dtype, - persistable=True, - stop_gradient=True) - fused_vars.append(tmp_var) - block._insert_op_without_sync( - idx, - type="coalesce_tensor", - inputs={"Input": segment}, - outputs={"Output": segment, - "FusedOutput": tmp_var}, - attrs={ - "copy_data": True, - "use_align": True, - "dtype": segment[0].dtype, - OP_ROLE_KEY: OpRole.Backward - }) - break + if len(grad_param_segments) == 0: + return - # insert the allreduce_sum op - for idx, op in enumerate(block.ops): - if is_optimizer_op(op): - for fused_var in fused_vars: - block._insert_op_without_sync( - idx, - type='c_allreduce_sum', - inputs={'X': fused_var}, - outputs={'Out': fused_var}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': self.calc_comm_same_stream, - OP_ROLE_KEY: OpRole.Backward - }) - if not self.calc_comm_same_stream: - block._insert_op_without_sync( - idx, - type='c_sync_calc_stream', - inputs={'X': fused_var}, - outputs={'Out': fused_var}, - attrs={OP_ROLE_KEY: OpRole.Backward}) - break + fused_vars = [None] * len(grad_param_segments) + for i in range(len(grad_param_segments) - 1, -1, -1): + # travers the grad_param_segments in backward + # not to use reversed since needs the absolute index value + grad_segment, param_segment = grad_param_segments[i] + # insert coalesce tensor + fused_var = block.create_var( + name=unique_name.generate('FusedOutput_{}'.format(grad_segment[ + 0].name)), + dtype=grad_segment[0].dtype, + persistable=False, + stop_gradient=True) + fused_vars[i] = fused_var + after_idx = outputs_name_to_idx[grad_segment[-1]][1] + block._insert_op_without_sync( + after_idx + 1, + type='c_allreduce_sum', + inputs={'X': fused_var}, + outputs={'Out': fused_var}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': self.calc_comm_same_stream, + OP_ROLE_KEY: OpRole.Backward + }) + if not self.calc_comm_same_stream: + block._insert_op_without_sync( + after_idx + 1, + type='c_sync_calc_stream', + inputs={'X': fused_var}, + outputs={'Out': fused_var}, + attrs={OP_ROLE_KEY: OpRole.Backward}) - if len(fused_vars) == 0: + # update the outputs_name_to_idx after insertion of sync/allreduce ops + outputs_name_to_idx = self.__get_ouputs_name_to_idx(first_backward_idx, + block) + # the before_idx is not guaranteed sorted, therefore we have to find the + # topology to insert the coalesce ops + pos_for_coalesce = {} + for i in range(len(grad_param_segments) - 1, -1, -1): + # We separate the insertion of coalesce op and the insertion of sync/allreduce op, + # since that the coalesce op's insertion may invalidate the outputs_name_to_idx + grad_segment, param_segment = grad_param_segments[i] + before_idx = len(block.ops) + for grad in outputs_name_to_idx: + before_idx = min(before_idx, outputs_name_to_idx[grad][0]) + pos_for_coalesce[i] = before_idx + + # insert the coalesce op based on the sorted before_idx + pos_for_coalesce = sorted( + pos_for_coalesce.items(), + key=lambda kv: (kv[1], kv[0]), + reverse=True) + for i, before_idx in pos_for_coalesce: + grad_segment, param_segment = grad_param_segments[i] + fused_var = fused_vars[i] + block._insert_op_without_sync( + before_idx, + type="coalesce_tensor", + inputs={"Input": param_segment}, + outputs={"Output": grad_segment, + "FusedOutput": fused_var}, + attrs={ + "copy_data": False, + "use_align": True, + "dtype": grad_segment[0].dtype, + OP_ROLE_KEY: OpRole.Backward + }) + + if self.calc_comm_same_stream: block._sync_with_cpp() return @@ -307,9 +341,31 @@ class RawProgramOptimizer(MetaOptimizerBase): block._insert_op_without_sync( idx, type='c_sync_comm_stream', - inputs={'X': fused_vars[0]}, - outputs={'Out': fused_vars[0]}, + inputs={'X': grad_segment[0]}, + outputs={'Out': grad_segment[0]}, attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward}) break block._sync_with_cpp() + + def __get_ouputs_name_to_idx(self, first_backward_idx, block): + # Each item of outputs_name_to_idx is a pair of idx. + # The first entry of this pair is the idx of the first op generates the grad, + # which is used to indicate the position to insert coalesce op. + # The second entry of this pair is the idx of the last op generates the grad, + # which is used to indicate the position to insert sync and allreduce op. + outputs_name_to_idx = {} + for idx in range(first_backward_idx, len(block.ops)): + op = block.ops[idx] + if is_optimizer_op(op): + break + for name in op.output_arg_names: + var = block.var(name) + if not outputs_name_to_idx.get(var): + # if the grad only be generated by one op + # the first idx and the last ids are identical + outputs_name_to_idx[var] = (idx, idx) + else: + outputs_name_to_idx[var] = (outputs_name_to_idx[var][0], + idx) + return outputs_name_to_idx