diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 0a94b897b9b1d594e375d1e9276214d55cfe74fd..dabe21606898783676646d15aaafd9fb4557b928 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -188,7 +188,7 @@ message DistributedStrategy {
   optional bool find_unused_parameters = 28 [ default = false ];
   optional bool tensor_parallel = 29 [ default = false ];
   optional bool without_graph_optimization = 30 [ default = false ];
-  optional int32 fuse_grad_size_in_num = 31 [ default = 1 ];
+  optional int32 fuse_grad_size_in_num = 31 [ default = 8 ];
   optional bool calc_comm_same_stream = 32 [ default = false ];
   optional bool asp = 33 [ default = false ];
diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
index c85242b6a562b14b5de131a7004f701b4e356f85..2205f79ef4633f67f508130db6497394cddad64b 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
@@ -131,7 +131,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
 
     def _transpile_main_program(self, loss):
         self._insert_loss_grad_ops(loss)
-        if self.fuse_all_reduce_ops:
+        if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1:
             self._allreduce_fusion_program()
         else:
             self._insert_allreduce_ops()
@@ -216,11 +216,10 @@ class RawProgramOptimizer(MetaOptimizerBase):
     def _allreduce_fusion_program(self):
         block = self.main_program.global_block()
         ring_id = self.global_ring_id
-        record_idx, allreduce_input_vars, allreduce_output_vars = [], [], []
-        ops = list(enumerate(block.ops))
+        param_grads = []
 
-        for idx, op in reversed(ops):
-            # we travers the ops reversely
+        # find all grad params
+        for op in reversed(block.ops):
             if is_backward_op(op) and \
                     OP_ROLE_VAR_KEY in op.attr_names:
                 op_role_var = op.attr(OP_ROLE_VAR_KEY)
@@ -229,214 +228,88 @@ class RawProgramOptimizer(MetaOptimizerBase):
                 assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \
                     "but got odd number of vars"
                 for i in range(0, len(op_role_var), 2):
-                    # handle vars in each op, each time handle a param and a grad
                     param_name = op_role_var[i]
                     param = block.var(param_name)
                     grad_name = op_role_var[i + 1]
                     grad = block.var(grad_name)
                     if param.is_distributed:
                         continue
-                    if ".cast_fp16@GRAD" in grad_name:
-                        # when amp=True get the fp16 param
-                        param_name = param_name + ".cast_fp16"
-                        if not block.has_var(param_name):
-                            raise ValueError("op cast name error {}".format(
-                                op.type))
-                        else:
-                            param = block.var(param_name)
-
-                    if len(allreduce_output_vars) == 0 or \
-                            len(allreduce_output_vars[-1]) == \
-                            self.fuse_grad_size_in_num:
-                        # start of the fusion or last group meets the config size
-                        allreduce_output_vars.append([grad])
-                        allreduce_input_vars.append([param])
-                        # add the start and end idx to the record idx
-                        record_idx.append([idx, idx])
-                    else:
-                        # Current group's size is below the config size
-                        # append grad and param to the last group (current group)
-                        # update the start idx to current op's idx
-                        # Since we travers the ops reversely, the idx is descending
-                        # we update the first entry of each entry for record_idx
-                        allreduce_output_vars[-1].append(grad)
-                        allreduce_input_vars[-1].append(param)
-                        record_idx[-1][0] = idx
-
-        assert len(allreduce_output_vars) == len(
-            record_idx
-        ), "It has different lens between the allreduce_output_vars and record_idx."
-
-        if not allreduce_output_vars or not allreduce_input_vars:
-            # nothing needs to be allreduced
-            return
+                    param_grads.append(grad)
+
+        segments = []
+        last_dtype = None
+        # split the grad based on dtype and fused size
+        for var in param_grads:
+            if len(segments) == 0 \
+                    or len(segments[-1]) == self.fuse_grad_size_in_num \
+                    or var.dtype != last_dtype:
+                segments.append([var])
+                last_dtype = var.dtype
+            else:
+                segments[-1].append(var)
 
-        self.vars = collections.OrderedDict()
-        index, pos, offset = 0, 0, 0
-        start, end = record_idx[index]
-        for idx, op in reversed(ops):
-            if idx == start:
-                pos = 0
-                done_output_vars, done_input_vars = self._split_fuction(
-                    allreduce_output_vars[index],  # grad
-                    allreduce_input_vars[index]  # param
-                )
-                for id_, done_output_var in enumerate(done_output_vars):
+        fused_vars = []
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                for segment in segments:
+                    # insert coalesce tensor
                     tmp_var = block.create_var(
                         name=unique_name.generate('FusedOutput_{}'.format(
-                            done_output_var[0].name)),
-                        dtype=done_output_var[0].dtype,
-                        persistable=False,
+                            segment[0].name)),
+                        dtype=segment[0].dtype,
+                        persistable=True,
                         stop_gradient=True)
-                    self.vars['FusedOutput_{}'.format(done_output_var[0]
-                                                      .name)] = tmp_var
-
-                    block._insert_op(
-                        idx + id_,
+                    fused_vars.append(tmp_var)
+                    block._insert_op_without_sync(
+                        idx,
                         type="coalesce_tensor",
-                        inputs={"Input": done_input_vars[id_]},
-                        outputs={
-                            "Output": done_output_var,
-                            "FusedOutput": tmp_var
-                        },
+                        inputs={"Input": segment},
+                        outputs={"Output": segment,
+                                 "FusedOutput": tmp_var},
                         attrs={
-                            "copy_data": False,
+                            "copy_data": True,
                             "use_align": True,
-                            "dtype": done_output_var[0].dtype,
+                            "dtype": segment[0].dtype,
                             OP_ROLE_KEY: OpRole.Backward
                         })
-                    pos += 1
-
-                for id_ in range(len(done_output_vars)):
-                    x = self.vars['FusedOutput_{}'.format(done_output_vars[id_][
-                        0].name)]
-                    out = x
-
-                    # NOTE: there still some optimize space if use EVENT instead of sync
-                    if not self.calc_comm_same_stream:
-                        # need sync if the calc and comm stream are not the same
-                        block._insert_op(
-                            end + id_ + pos + 1,
-                            type='c_sync_calc_stream',
-                            inputs={'X': x},
-                            outputs={'Out': out},
-                            attrs={OP_ROLE_KEY: OpRole.Backward})
+                break
 
-                    block._insert_op(
-                        end + id_ + pos + 1
-                        if self.calc_comm_same_stream else end + id_ + pos + 2,
+        # insert the allreduce_sum op
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                for fused_var in fused_vars:
+                    block._insert_op_without_sync(
+                        idx,
                         type='c_allreduce_sum',
-                        inputs={'X': x},
-                        outputs={'Out': out},
+                        inputs={'X': fused_var},
+                        outputs={'Out': fused_var},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': self.calc_comm_same_stream,
                             OP_ROLE_KEY: OpRole.Backward
                         })
+                    if not self.calc_comm_same_stream:
+                        block._insert_op_without_sync(
+                            idx,
+                            type='c_sync_calc_stream',
+                            inputs={'X': fused_var},
+                            outputs={'Out': fused_var},
+                            attrs={OP_ROLE_KEY: OpRole.Backward})
+                break
 
-                index += 1
-                if len(record_idx) == index:
-                    break
-                start, end = record_idx[index]
-
-        if not self.calc_comm_same_stream:
-            # need sync if the calc and comm stream are not the same
-            for idx, op in enumerate(block.ops):
-                if is_optimizer_op(op):
-                    block._insert_op(
-                        idx,
-                        type='c_sync_comm_stream',
-                        inputs={'X': block.create_var()},
-                        outputs={'Out': block.create_var()},
-                        attrs={
-                            'ring_id': ring_id,
-                            OP_ROLE_KEY: OpRole.Backward
-                        })
-                    break
-
-    # Integrate grads of the same type to form a combination.
-    # If combination is selected, will return grads of the same type in a groups.
-    # For example:[(fp16, fp16), (fp32), (fp16)] -> [(fp16, fp16, fp16), (fp32)]
-    def _split_fuction(self,
-                       allreduce_output_vars,
-                       allreduce_input_vars,
-                       combination=True):
-        input_vars, final_input_vars, output_vars, final_output_vars = [], [], [], []
-        if len(allreduce_output_vars) == 1:
-            # only have one var to handle
-            final_output_vars.append(allreduce_output_vars)
-            final_input_vars.append(allreduce_input_vars)
-            return final_output_vars, final_input_vars
-
-        for idx in range(len(allreduce_input_vars) - 1):
-            # the last var needs to be handled differently
-            if allreduce_input_vars[idx].dtype == allreduce_input_vars[idx +
-                                                                       1].dtype:
-                # if current var and next var are in same type
-                # append current var to input_vars
-                input_vars.append(allreduce_input_vars[idx])
-                if idx == len(allreduce_input_vars) - 2:
-                    # if current var is the second last var
-                    # append the last var to input_vars
-                    # and update the final_input_vars
-                    input_vars.append(allreduce_input_vars[idx + 1])
-                    final_input_vars.append(input_vars)
-            else:
-                # the current var and next var are in different types
-                # append current var to input_vars
-                # update the final_input_vars
-                # reset input_vars to receive a new type
-                input_vars.append(allreduce_input_vars[idx])
-                final_input_vars.append(input_vars)
-                input_vars = []
-                if idx == len(allreduce_input_vars) - 2:
-                    # if current var is the second last var
-                    # append the last var to a reset input_vars since they are in different types
-                    # and update the final_input_vars
-                    input_vars.append(allreduce_input_vars[idx + 1])
-                    final_input_vars.append(input_vars)
-
-        for idx in range(len(allreduce_output_vars) - 1):
-            # the procedure for the output vars is the same with that for the input vars
-            if allreduce_output_vars[idx].dtype == allreduce_output_vars[
-                    idx + 1].dtype:
-                output_vars.append(allreduce_output_vars[idx])
-                if idx == len(allreduce_output_vars) - 2:
-                    output_vars.append(allreduce_output_vars[idx + 1])
-                    final_output_vars.append(output_vars)
-            else:
-                output_vars.append(allreduce_output_vars[idx])
-                final_output_vars.append(output_vars)
-                output_vars = []
-                if idx == len(allreduce_output_vars) - 2:
-                    output_vars.append(allreduce_output_vars[idx + 1])
-                    final_output_vars.append(output_vars)
-
-        # at this time, all vars in each group in final_input_vars and final_output_vars are in the same type
-
-        if combination:
-            input_fp16_vars, input_fp32_vars, output_fp16_vars, output_fp32_vars = [], [], [], []
-            for final_input_var in final_input_vars:
-                if final_input_var[0].dtype == core.VarDesc.VarType.FP16:
-                    # extend the group
-                    input_fp16_vars.extend(final_input_var)
-                else:
-                    input_fp32_vars.extend(final_input_var)
-
-            for final_output_var in final_output_vars:
-                if final_output_var[0].dtype == core.VarDesc.VarType.FP16:
-                    output_fp16_vars.extend(final_output_var)
-                else:
-                    output_fp32_vars.extend(final_output_var)
-
-            final_output_vars, final_input_vars = [], []
-            if output_fp16_vars:
-                final_output_vars.append(output_fp16_vars)
-            if output_fp32_vars:
-                final_output_vars.append(output_fp32_vars)
-            if input_fp16_vars:
-                final_input_vars.append(input_fp16_vars)
-            if input_fp32_vars:
-                final_input_vars.append(input_fp32_vars)
+        if len(fused_vars) == 0:
+            block._sync_with_cpp()
+            return
 
-        return final_output_vars, final_input_vars
+        # insert the sync comm op
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                block._insert_op_without_sync(
+                    idx,
+                    type='c_sync_comm_stream',
+                    inputs={'X': fused_vars[0]},
+                    outputs={'Out': fused_vars[0]},
+                    attrs={'ring_id': ring_id,
+                           OP_ROLE_KEY: OpRole.Backward})
+                break
+        block._sync_with_cpp()
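
Reviewer note (not part of the patch): below is a minimal standalone sketch of the grouping rule the new `_allreduce_fusion_program` applies before it inserts `coalesce_tensor` ops. Gradients are packed, in the order they are collected, into segments of at most `fuse_grad_size_in_num` tensors, and a new segment starts whenever the dtype changes. `FakeVar`, `split_into_segments`, and the toy dtypes are illustrative stand-ins, not code from this PR.

```python
from collections import namedtuple

# Illustrative stand-in for a framework variable; only .name and .dtype matter here.
FakeVar = namedtuple('FakeVar', ['name', 'dtype'])


def split_into_segments(param_grads, fuse_grad_size_in_num):
    """Group grads into fusion segments by dtype and maximum segment size,
    mirroring the segmentation loop added in _allreduce_fusion_program."""
    segments = []
    last_dtype = None
    for var in param_grads:
        if (len(segments) == 0
                or len(segments[-1]) == fuse_grad_size_in_num
                or var.dtype != last_dtype):
            segments.append([var])       # start a new segment
            last_dtype = var.dtype
        else:
            segments[-1].append(var)     # keep filling the current segment
    return segments


if __name__ == '__main__':
    grads = [
        FakeVar('a@GRAD', 'fp16'),
        FakeVar('b@GRAD', 'fp16'),
        FakeVar('c@GRAD', 'fp32'),
        FakeVar('d@GRAD', 'fp32'),
        FakeVar('e@GRAD', 'fp32'),
    ]
    # With fuse_grad_size_in_num = 2 this prints:
    # [['a@GRAD', 'b@GRAD'], ['c@GRAD', 'd@GRAD'], ['e@GRAD']]
    print([[v.name for v in seg] for seg in split_into_segments(grads, 2)])
```

Splitting on dtype matters because `coalesce_tensor` produces a single fused buffer of one dtype, so fp16 and fp32 gradients end up in separate fused allreduce buffers.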