From 926c4bd2fe030528aedc413950b5a14c3ea09495 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Mon, 9 Jan 2023 10:32:48 +0800 Subject: [PATCH] [AutoParalle] balancing the calculation of global_norm in data parallel (#49510) * [AutoParalle] balancing the calculation of global_norm in data parallel * fix unittest * update cond pure_data_parallel --- .../passes/auto_parallel_grad_clip.py | 157 +++++++++++++----- .../auto_parallel/clip_grad_by_global_norm.py | 1 + 2 files changed, 120 insertions(+), 38 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index 0ff3fbcf95..25a768e94d 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -21,8 +21,14 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.fluid.executor import _is_enable_standalone_executor from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr -from ..auto_parallel.operators.common import SyncMode -from ..auto_parallel.process_group import get_world_process_group +from ..auto_parallel.operators.common import ( + SyncMode, + is_data_parallel_reduce_op, +) +from ..auto_parallel.process_group import ( + get_all_process_groups, + get_world_process_group, +) from ..auto_parallel.process_mesh import ProcessMesh from ..auto_parallel.reshard import Resharder from ..auto_parallel.utils import ( @@ -31,6 +37,7 @@ from ..auto_parallel.utils import ( is_gradient_clip_op, is_optimize_op, ) +from .auto_parallel_sharding import ShardingPass from .pass_base import PassBase, register_pass @@ -145,46 +152,65 @@ def _is_about_global_norm( class ClipHelper: - def __init__(self, params_grads, rank_id, block, dist_context): + def __init__( + self, params_grads, rank_id, block, dist_context, pass_context + ): params, _ = zip(*params_grads) self.params = list(params) self.params_name = [p.name for p in self.params] self.rank_id = rank_id self.block = block self.dist_context = dist_context + self.pass_context = pass_context self.sharding_group = None self.world_ranks = get_world_process_group().ranks if hasattr(dist_context, '_sharding_group'): self.sharding_group = dist_context._sharding_group - def _is_calcuate_norm(self, name): - if not self._is_local_param(name): - return False, [] + self.world_nranks = len(self.world_ranks) + self.pure_data_parallel = self._is_pure_data_parallel() + self.rank_to_params = self._partition_parameters(params) - param = self.params[self.params_name.index(name)] - dist_attr = self._get_dist_attr(name) - topology = dist_attr.process_mesh.shape - processes = dist_attr.process_mesh.process_ids - dims_mapping = dist_attr.dims_mapping - return _is_about_global_norm( - self.rank_id, - param.shape, - topology, - processes, - dims_mapping, - self.sharding_group, - ) + def is_calcuate_norm(self, name): + """ + whether the param_name@GRAD paticipate in the calculation of global_norm + """ + if not self.is_local_param(name): + return False - def _get_dist_attr(self, name): - var = self.block.vars[name] - return self.dist_context.get_tensor_dist_attr_for_program(var) + param = self.params[self.params_name.index(name)] + if not self.pure_data_parallel: + dist_attr = self._get_dist_attr(name) + topology = dist_attr.process_mesh.shape + processes = dist_attr.process_mesh.process_ids + dims_mapping = dist_attr.dims_mapping + return _is_about_global_norm( + self.rank_id, + param.shape, + topology, + processes, + dims_mapping, + self.sharding_group, + ) + else: + return param.name in self.rank_to_params[self.rank_id] - def _is_local_param(self, name): + def is_local_param(self, name): + """ + whether the param_name is updated with opt in cur_rank + """ if name not in self.params_name: return False return True - def _is_local_var(self, name): + def _get_dist_attr(self, name): + var = self.block.vars[name] + return self.dist_context.get_tensor_dist_attr_for_program(var) + + def is_local_var_with_dist_attr(self, name): + """ + whether the var_name is belong to cur_rank + """ dist_attr = self._get_dist_attr(name) assert dist_attr is not None return self.rank_id in dist_attr.process_mesh.process_ids @@ -212,6 +238,50 @@ class ClipHelper: op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr) + def _is_pure_data_parallel(self): + for applied_pass in self.pass_context.passes: + if isinstance(applied_pass, ShardingPass): + return False + + groups = get_all_process_groups() + for g in groups: + if g.nranks != self.world_nranks: + return False + + for op in self.block.ops: + if op.type in [ + "c_reduce_sum", + "c_allreduce_sum", + ] and not is_data_parallel_reduce_op(op): + return False + + return True + + def _partition_parameters(self, params): + """ + build rank_id_to_params by the param's numel + to guarantee params in every rank of dp_group as even as possible. + """ + mapping = {} + if not self.pure_data_parallel: + for rank_ in range(self.world_nranks): + mapping[rank_] = [p.name for p in params] + else: + for rank_ in range(self.world_nranks): + mapping[rank_] = [] + sizes = [0] * self.world_nranks + for param in params: + rank = sizes.index(min(sizes)) + mapping[rank].append(param.name) + numel = reduce(lambda x, y: x * y, param.shape) + assert ( + numel > 0 + ), "param [{}] should larger than 0, but it is [{}]".format( + param.name, numel + ) + sizes[rank] += numel + return mapping + @register_pass("auto_parallel_grad_clip") class ClipGradByGloblNormPass(PassBase): @@ -248,14 +318,13 @@ class ClipGradByGloblNormPass(PassBase): # dist_params_grads = _get_params_grads(block) self.clip_helper = ClipHelper( - dist_params_grads, rank_id, block, dist_context + dist_params_grads, rank_id, block, dist_context, context ) self._remove_no_need_ops_vars(block) def _remove_no_need_ops_vars(self, block): removed_op_out_type = [ - 'clip_by_norm', 'squared_l2_norm', 'square', 'reduce_sum', @@ -267,31 +336,40 @@ class ClipGradByGloblNormPass(PassBase): if not is_gradient_clip_op(op): continue - if op.type in removed_op_out_type: + if op.type == 'clip_by_norm': + # remove 'clip_by_norm' op if the param is not updated with opt in current rank input_name = op.input("X")[0] if input_name.find("@GRAD") != -1: - # 'clip_by_norm', 'squared_l2_norm', 'square' param_name = input_name[: input_name.find("@GRAD")] - is_local = self.clip_helper._is_local_param(param_name) - is_calculate = self.clip_helper._is_calcuate_norm( - param_name - ) - if not is_local or ( - not is_calculate and op.type != 'clip_by_norm' - ): + is_local = self.clip_helper.is_local_param(param_name) + if not is_local: + removed_op_idx.add(idx) + removed_tmp_var.update(set(op.output_arg_names)) + + elif op.type in removed_op_out_type: + input_name = op.input("X")[0] + if input_name.find("@GRAD") != -1: + # remove 'squared_l2_norm' and 'square' ops, + # if the param@GRAD in cur_rank does not participate in the calculation of global_norm + param_name = input_name[: input_name.find("@GRAD")] + is_local = self.clip_helper.is_local_param(param_name) + is_calculate = self.clip_helper.is_calcuate_norm(param_name) + if not is_local or not is_calculate: removed_op_idx.add(idx) removed_tmp_var.update(set(op.output_arg_names)) else: - # 'reduce_sum' + # 'reduce_sum' must be behind 'square' if idx - 1 in removed_op_idx: removed_op_idx.add(idx) removed_tmp_var.update(set(op.output_arg_names)) elif op.type == 'elementwise_mul': + # 'elementwise_mul' scale the param@GRAD with global_norm + # remove 'elementwise_mul' op if the param is not updated with opt in current rank input_name = op.input("X")[0] if input_name.find("@GRAD") != -1: param_name = input_name[: input_name.find("@GRAD")] - is_local = self.clip_helper._is_local_param(param_name) + is_local = self.clip_helper.is_local_param(param_name) if not is_local: removed_op_idx.add(idx) if block.ops[idx - 1].type == 'cast': @@ -301,11 +379,14 @@ class ClipGradByGloblNormPass(PassBase): ) elif op.type == 'sum': + # 'sum' op is used to calculate global_norm, and need to filter inputs which is not in cur_rank reserved_vars = [] for input_name in op.input_arg_names: if ( input_name not in removed_tmp_var - and self.clip_helper._is_local_var(input_name) + and self.clip_helper.is_local_var_with_dist_attr( + input_name + ) ): reserved_vars.append(input_name) if not reserved_vars: diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py index 52c209b5bc..baae57b84a 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py @@ -31,6 +31,7 @@ def apply_pass(use_sharding=False): strategy.reinit = True if use_sharding: sharding = strategy.sharding + sharding.enable = True sharding.degree = 2 sharding.stage = 2 return strategy -- GitLab