From 201d99d6045a83f8c2cd01da737d7a7d56611a78 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 14 Sep 2022 13:56:56 +0800 Subject: [PATCH] [Auto Parallel] Gradient Fuse Allreduce (#45643) * bugfix (#45332) * dist embedding support lookup table v1 * add unitest * customize wait_comm * group gradients * bugfix * update program --- .../distributed/auto_parallel/engine.py | 34 --- .../paddle/distributed/auto_parallel/utils.py | 12 + .../distributed/passes/auto_parallel_amp.py | 4 +- ...uto_parallel_data_parallel_optimization.py | 272 +++++++++++++++++- .../distributed/passes/auto_parallel_fp16.py | 36 +++ 5 files changed, 319 insertions(+), 39 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 4712634a6c4..5389438d388 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -354,40 +354,6 @@ class Engine: prune_startup_prog = dist_startup_prog._prune(uninitialized) self._executor.run(prune_startup_prog) - if self.strategy.amp and self.strategy.amp_configs['use_pure_fp16']: - # from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16 - def cast_parameters_to_fp16(place, - program, - scope=None, - to_fp16_var_names=None): - """ - Traverse all parameters in the whole model and set them to the FP16 data type. - Whereas, this function will keep parameters of batchnorms in FP32. - Args: - place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors. - program (Program): The used program. - scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values. - Default is None. - to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names` - will be set to FP16. Usually, it is the returned - value of `cast_model_to_fp16` API. - """ - from paddle.framework import core - import numpy as np - all_parameters = [] - for block in program.blocks: - all_parameters.extend(block.all_parameters()) - - var_scope = scope if scope else paddle.static.global_scope() - for param in all_parameters: - if param.dtype == core.VarDesc.VarType.FP16: - param_t = var_scope.find_var( - param.name).get_tensor() - data = np.array(param_t) - param_t.set(np.float16(data), place) - - cast_parameters_to_fp16(place, prune_startup_prog) - def fit(self, train_data, batch_size=1, diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index cfe21612868..8813bbe5449 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1504,3 +1504,15 @@ def ring_id_to_process_group(ring_id): if g.id == ring_id: return g return None + + +def find_higher_order_backward_op(program): + + higher_order_op_suffix = ['_grad_grad', 'triple_grad'] + for block in program.blocks: + for op in block.ops: + for suffix in higher_order_op_suffix: + if suffix in op.type: + return True + + return False diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 3f3448b5008..458cb26ccd4 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -314,7 +314,9 @@ class AMPState(object): consume_op_attr.set_input_dist_attr( cast_name, in_var_dist_attr) else: - assert in_var.dtype == dst_dtype + assert in_var.dtype == dst_dtype, "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format( + grad_op.type, in_name, dst_dtype, in_var.dtype, + str(grad_op)) for out_name in grad_op.output_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output( diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 586ad235fd1..ac923be9a1a 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -13,12 +13,14 @@ # limitations under the License. from collections import OrderedDict +import numpy as np import paddle +from paddle.fluid import core, unique_name from paddle.fluid.framework import default_main_program -from paddle.distributed.fleet.meta_optimizers.common import OpRole +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op -from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group +from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op from .pass_base import PassBase, PassType, register_pass # add new optimizers supporting rescale_grad here @@ -31,6 +33,10 @@ __rescale_grad_supported_opts__ = [ __max_stream_num_allow__ = 16 +def numel(var): + return np.prod(list(var.shape)) + + @register_pass("auto_parallel_data_parallel_optimization") class DataParallelOptimizationPass(PassBase): """ @@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase): self._analyze_program() self._prune_grad_scaling() self._calc_comm_overlap() - self._fuse_allreduce() + grad_group = self._fuse_allreduce() + + # self.summary(grad_group) def _prune_grad_scaling(self): @@ -99,7 +107,19 @@ class DataParallelOptimizationPass(PassBase): self._calc_wait_comms() def _fuse_allreduce(self): - pass + + if not self._could_be_fuse(): + return [] + + with open('./before_program.txt.' + str(paddle.distributed.get_rank()), + 'w') as f: + f.write(str(default_main_program())) + grad_group = self._group_grads() + self._update_program(grad_group) + with open('./after_program.txt.' + str(paddle.distributed.get_rank()), + 'w') as f: + f.write(str(default_main_program())) + return grad_group def _analyze_program(self): """ @@ -316,3 +336,247 @@ class DataParallelOptimizationPass(PassBase): 'op_role': OpRole.Backward, 'ring_id': ring_id }) + + def _could_be_fuse(self): + # TODO support gradient fuse higher order gradient. + # should analyse the dependencies of gradient in backward. + if find_higher_order_backward_op(default_main_program()): + return False + if self.use_sharding: + return False + return True + + def _group_grads(self): + """ + conditions for gradients to be grouped: + 1. group size < max_fuse_numel + 2. same dp group + 3. same dtype + 4. dependency: grad would NOT be used by other ops within group segment + + gradients inside same group would be fuse into one coalesce tensor + """ + + block = default_main_program().global_block() + ops = block.ops + + # group individual grad vars + # TODO consider fuse gradient for sharding reduce + # TODO let user to set fuse_grad_size + # emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h + h = 2048 + ffn_numel = 2 * (4 * h) * h + mha_numel = 3 * h * h + h * h + max_fuse_numel = ffn_numel + mha_numel + grad_groups = [] + cur_group = GradientsGroup(ops, max_fuse_numel) + grouped_grad_names = set() + + def collect_group(cur_group, grad_var, ring_id, i): + if len(cur_group.gradients) == 0: + cur_group = None + elif len(cur_group.gradients) == 1: + grouped_grad_names.remove(cur_group.gradients[0].name) + else: + cur_group.finalize() + grad_groups.append(cur_group) + + new_group = GradientsGroup(ops, max_fuse_numel) + if grad_var: + new_group.add(grad_var, ring_id, i) + grouped_grad_names.add(grad_var.name) + return new_group + + def op_depend_on_group(op, group): + vars_ = set(op.input_arg_names + op.output_arg_names) + grad_names = set([grad.name for grad in group.gradients]) + return len(vars_.intersection(grad_names)) > 0 + + for i, op in enumerate(ops): + if is_data_parallel_reduce_op(op): + ring_id = op.attr("ring_id") + grad_name = op.output_arg_names[0] + grad_var = block.var(grad_name) + grad_numel = numel(grad_var) + + if cur_group.acceptable(grad_var, ring_id): + assert grad_name not in grouped_grad_names + grouped_grad_names.add(grad_name) + cur_group.add(grad_var, ring_id, i) + else: + cur_group = collect_group(cur_group, grad_var, ring_id, i) + else: + if op_depend_on_group(op, cur_group): + cur_group = collect_group(cur_group, None, None, None) + + # collect last group + collect_group(cur_group, None, None, None) + + return grad_groups + + def _update_program(self, grad_groups): + + block = default_main_program().global_block() + + remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute'] + + for i, group in enumerate(grad_groups[::-1]): + + # create coalecse tensor + group.coalesce_var = block.create_var(name=unique_name.generate( + 'coalecse_grad_{}'.format(i)), + dtype=group.dtype, + persistable=False, + stop_gradient=True) + + # update allreduce & scale op + if group.scale_op_idx != -1: + scale_op = block.ops[group.scale_op_idx] + assert scale_op.type == 'scale', "should found scale op but found {}".format( + str(scale_op)) + scale_op._rename_input(scale_op.input_arg_names[0], + group.coalesce_var.name) + scale_op._rename_output(scale_op.output_arg_names[0], + group.coalesce_var.name) + + allreduce_op = block.ops[group.allreduce_op_idx] + assert allreduce_op.type == 'c_allreduce_sum', "should found c_allreduce_sum op but found {}".format( + str(allreduce_op)) + allreduce_op._rename_input(allreduce_op.input_arg_names[0], + group.coalesce_var.name) + allreduce_op._rename_output(allreduce_op.output_arg_names[0], + group.coalesce_var.name) + + # remvoe un-used op + remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices + for idx in sorted(remove_op_indices, reverse=True): + assert block.ops[ + idx].type in remove_op_types, "Unexception: try to remove op {}".format( + str(op)) + block._remove_op(idx) + + # insert coalecse op + concated_shapes = [] + concated_ranks = [] + for grad_ in group.gradients: + shape = grad_.shape + concated_shapes.extend(shape) + concated_ranks.append(len(shape)) + + grad_names = [grad.name for grad in group.gradients] + block._insert_op_without_sync(group.coalesce_op_idx, + type="coalesce_tensor", + inputs={"Input": grad_names}, + outputs={ + "Output": grad_names, + "FusedOutput": group.coalesce_var + }, + attrs={ + "copy_data": False, + "use_align": True, + "dtype": group.dtype, + "concated_shapes": + concated_shapes, + "concated_ranks": concated_ranks, + OP_ROLE_KEY: OpRole.Backward + }) + + block._sync_with_cpp() + # TODO update dist attr + + def summary(self, grad_groups=[]): + # TODO: add logger module + import logging + self._logger = logging.getLogger() + self._logger.propagate = False + if not self._logger.handlers: + self._logger.setLevel(logging.INFO) + log_handler = logging.StreamHandler() + log_format = logging.Formatter( + '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' + ) + log_handler.setFormatter(log_format) + self._logger.addHandler(log_handler) + + if len(grad_groups) > 0: + self._logger.info( + "origin {} allreduce ops are fused into {} coalecse allreduce ops." + .format(len(self._grad_name_to_group_map.keys()), + len(grad_groups))) + self._logger.info("gradient fusing group are following: ") + fused_grads = set() + for i, group in enumerate(grad_groups): + self._logger.info( + "coalecse gradient [{}] is composed by: {}".format( + i, [grad.name for grad in group.gradients])) + fused_grads.update([grad.name for grad in group.gradients]) + individual_grads = set( + self._grad_name_to_group_map.keys()) - set(fused_grads) + self._logger.info( + "the following [{}] gradients are not fused: ".format( + len(individual_grads))) + self._logger.info("individual gradient {}".format(individual_grads)) + + +class GradientsGroup(object): + + def __init__(self, ops, max_group_size): + self.max_group_size = max_group_size + self.ops = ops + + self.gradients = [] + self.numel = 0 + self.dtype = None + self.ring_id = None + self.coalesce_var = None + self.coalesce_op_idx = -1 + self.allreduce_op_idx = -1 + self.scale_op_idx = -1 + self.remove_wait_op_indices = [] + self.remove_allreduce_op_indices = [] + self.remove_scale_op_indices = [] + + def acceptable(self, grad_var, ring_id): + if len(self.gradients) == 0: + return True + if ring_id != self.ring_id: + return False + if numel(grad_var) + self.numel > self.max_group_size: + return False + if grad_var.dtype != self.dtype: + return False + + return True + + def add(self, grad_var, ring_id, i): + self.gradients.append(grad_var) + self.ring_id = ring_id + self.dtype = grad_var.dtype + self.numel += numel(grad_var) + + # remove auxiliary ops in non-fuse dp allreduce + self.remove_allreduce_op_indices.append(i) + + # NOTE this pass rely on the original synchronization add in previous passes + # (same stream or calc_wait_comm & comm_wait_calc) + # to guarantee the correctness of comm_calc execution order. + # so the calc_wait_comm should be keep. + grad_op_idx = i - 1 + if i > 0 and self.ops[i - 1].type == 'c_wait_compute': + self.remove_wait_op_indices.append(i - 1) + grad_op_idx -= 1 + if i + 1 < len(self.ops) and is_data_parallel_scale_op(self.ops[i - 1]): + self.remove_scale_op_indices.append(i + 1) + + if len(self.gradients) == 1: + grad_op = self.ops[grad_op_idx] + assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format( + grad_var.name, str(grad_op)) + self.coalesce_op_idx = grad_op_idx + + def finalize(self): + self.allreduce_op_idx = self.remove_allreduce_op_indices.pop() + if len(self.remove_wait_op_indices) > 1: + self.remove_wait_op_indices.pop() + if len(self.remove_scale_op_indices) > 1: + self.scale_op_idx = self.remove_scale_op_indices.pop() diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 7702de7c01e..07fd1d60043 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -16,6 +16,7 @@ from collections import defaultdict import paddle from paddle.framework import core +from paddle.fluid.framework import default_main_program, default_startup_program from paddle.fluid import unique_name from .pass_base import register_pass from paddle.fluid.data_feeder import check_variable_and_dtype, check_type @@ -536,6 +537,39 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): return output_var +def cast_startup_program(): + main_program = default_main_program() + startup_program = default_startup_program() + + param_to_dtype = {} + for block in main_program.blocks: + for p in block.all_parameters(): + param_to_dtype[p.name] = p.dtype + + def is_initialization_op(op): + comm_op_prefix = "c_" + op_type = op.type + if op_type.startswith(comm_op_prefix): + return False + + if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0: + return False + + return True + + for op in startup_program.global_block().ops: + if is_initialization_op(op): + output_name = op.output_arg_names[0] + if param_to_dtype.get(output_name, + None) == core.VarDesc.VarType.FP16: + assert op.has_attr( + 'dtype' + ), "initialization op is supported to has dtype attribute but got {}.".format( + str(op)) + if op.attr('dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.FP16) + + @register_pass("auto_parallel_fp16") class FP16Pass(AMPPass): @@ -563,6 +597,8 @@ class FP16Pass(AMPPass): input_data_var_names) is_train = fp16_state._build_state() + cast_startup_program() + if is_train: with paddle.static.program_guard(main_program, startup_program): # TODO (JZ-LIANG)support cast forward program only when inference -- GitLab