From 70770d0d5e74b2a34713866df6a81c91acc37de2 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 3 Aug 2022 13:49:18 +0800 Subject: [PATCH] [Auto Parallel] Unify gradient synchronization procedure of data parallel (#44815) --- .../distributed/auto_parallel/engine.py | 14 +- .../auto_parallel/operators/common.py | 132 ++++++++++++++++++ .../auto_parallel/operators/dist_default.py | 92 +++--------- .../auto_parallel/operators/dist_embedding.py | 55 +------- .../auto_parallel/operators/dist_matmul.py | 59 ++------ 5 files changed, 171 insertions(+), 181 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 58778042b13..cd76b3dfcd3 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -223,8 +223,8 @@ class Engine: assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset." batch_size = self._user_tuning_config["batch_size"] dataset = self._user_tuning_config["dataset"] - dataset.dp_world_size = self._dp_world_size - dataset.dp_rank = self._dp_rank + dataset.dp_world_size = self._input_split_size + dataset.dp_rank = self._input_split_rank from .tuner.optimization_tuner import OptimizationTuner self._optimization_tuner = OptimizationTuner(self._user_tuning_config, @@ -262,7 +262,7 @@ class Engine: if var.name in block.vars: feed_list.append(block.vars[var.name]) - self._dp_world_size, self._dp_rank = self._get_data_parallel_info( + self._input_split_size, self._input_split_rank = self._get_input_split_info( feed_list[0], self._dist_contexts[mode]) def _parallel(self, mode, all_ranks): @@ -554,8 +554,8 @@ class Engine: batch_size, epochs, steps_per_epoch, - data_parallel_world_size=self._dp_world_size, - data_parallel_rank=self._dp_rank) + data_parallel_world_size=self._input_split_size, + data_parallel_rank=self._input_split_rank) # move read op from the end of program to the start of program new_op_size = len(dist_main_block.ops) @@ -615,8 +615,8 @@ class Engine: fetches = dict(inner_fetch, **usr_fetch) return list(fetches.keys()), fetches - def _get_data_parallel_info(self, var, dist_context): - # get data parallel world size and current data parallel rank + def _get_input_split_info(self, var, dist_context): + # deduce how the input data is split among the cluster from .utils import _get_comm_group, _get_corresponding_rank tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 9f328e4fab6..75002ae4ce1 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -13,7 +13,11 @@ # limitations under the License import abc +import paddle +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from ..dist_attribute import OperatorDistributedAttribute +from ..utils import _get_comm_group, _get_corresponding_rank +from ..process_group import new_process_group _g_distributed_operator_impl_containers = {} @@ -24,6 +28,16 @@ _g_elementwise_ops = [ BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} +class ParallelMode(): + """ + the parallel mode for communication or auxiliary operator + """ + DataParallel = "auto_parallel/data_parallel" + ModelParallel = "auto_parallel/model_parallel" + PipelineParalel = 
"auto_parallel/pipeline_paralel" + MoEParallel = "auto_parallel/moe_parallel" + + def is_elementwise_op(op_type): if op_type in _g_elementwise_ops: return True @@ -303,3 +317,121 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx): new_op.output(output_name)[0], ref_tensor_dist_attr) ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) + + +def get_data_parallel_group(dist_ctx, op, act_grad_names, rank): + """ + deduce the data parallel communication group for current operator. + + Args: + dist_ctx (DistributedContext): dist context. + op (Operator): the current (backward) operator which might need. + act_grad_names (list): list of input activation grads variable name to the current operator. + out_grad_names (list): list of the output parameter's grads variable name of the current operator. + rank (int): global ranks index for current process. + """ + dp_group = None + + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + process_mesh = op_dist_attr.process_mesh + mesh_shape = process_mesh.topology + # FIXME Hack for Pipeline Parallelism where the current operator + # not belong to the mesh the current rank belong to. + if rank not in process_mesh.processes: + rank = _get_corresponding_rank(dist_ctx, process_mesh, rank) + + for var_name in act_grad_names: + var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name) + # consider that the variable's shape is None + # TODO utilize the batch_dim attr instead of "0" in future + batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1 + + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + group_ranks = _get_comm_group(process_mesh.processes, + process_mesh.topology, + batch_size_axis, rank) + dp_group = new_process_group(group_ranks) + break + + return dp_group + + +def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names): + """ + insert the allreudce and scale ops for gradients of model + parameters for operator in data parallelism. + + Args: + dist_ctx (DistributedContext): dist context. + op (Operator): the current (backward) operator which might need. + allreduce_var_names (list): list of the parameter's grads variable name in the current operator output. 
+ """ + + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + process_mesh = op_dist_attr.process_mesh + dist_op_context = dist_ctx.dist_op_context + main_block = dist_op_context.work_block + dp_degree = len(dp_group.ranks) + + for var_name in allreduce_var_names: + added_ops = [] + grad_var = main_block.var(var_name) + allreduce_op = main_block.append_op(type='c_allreduce_sum', + inputs={'X': [grad_var]}, + outputs={'Out': [grad_var]}, + attrs={ + 'ring_id': dp_group.id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward + }) + allreduce_op._set_attr('op_namescope', + str('/') + ParallelMode.DataParallel) + added_ops.append(allreduce_op) + + if dist_ctx.gradient_scale: + scale_op = main_block.append_op(type='scale', + inputs={'X': grad_var}, + outputs={'Out': grad_var}, + attrs={ + 'scale': 1.0 / dp_degree, + OP_ROLE_KEY: OpRole.Backward + }) + scale_op._set_attr('op_namescope', + str('/') + ParallelMode.DataParallel) + added_ops.append(scale_op) + + dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name) + assert dims_mapping is not None, "Unexception: dims_mapping of output [{}] of op [{}] is None".format( + grad_var.name, op_dist_attr.op_type) + # NOTE auxiliary op's dist attr should follow dist_op not dist_tensor + for new_op in added_ops: + new_op_attr = OperatorDistributedAttribute() + new_op_attr.process_mesh = process_mesh + new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) + new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) + dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr) + + +def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names, + rank): + """ + conduct the allreudce and scaling(dp size)for gradients of model + parameters for operator in data parallelism. + + Args: + dist_ctx (DistributedContext): dist context. + op (Operator): the current (backward) operator which might need. + act_grad_names (list): list of input activation grads variable name to the current operator. + out_grad_names (list): list of the output parameter's grads variable name of the current operator. + rank (int): global ranks index for current process. 
+ """ + + if len(act_grad_names) == 0 or len(out_grad_names) == 0: + return + + dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank) + + if not dp_group: + return + + sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index d0eba355e7b..08c81c4a306 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -15,6 +15,7 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container +from .common import gradient_synchronization from .common import register_distributed_operator_impl, is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate @@ -537,87 +538,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for output_name in backward_op.desc.output_names(): dist_op_desc.set_output(output_name, kwargs[output_name]) - # check if need gradient allreduce - # if there is a non-gradient & non-parameter input and its batch dimension is splited, - # we need insert gradient allreduce for the gradient of parameter in its output - need_gradient_allreduce = False + # data parallel gradient synchronization + act_grad_names = [] for input_name in backward_op.desc.input_names(): for varname in backward_op.desc.input(input_name): if "@GRAD" not in varname and not is_parameter_related( varname, main_block): + act_grad_names.append(varname) - # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op - process_mesh = dist_attr.process_mesh - var_dim_mapping = dist_attr.get_input_dims_mapping(varname) - - # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in process_mesh.processes: - rank_id = _get_corresponding_rank( - ctx, process_mesh, rank_id) - - # NOTE: consider that the variable's shape is None - mesh_shape = process_mesh.topology - batch_size_axis = var_dim_mapping[0] if len( - var_dim_mapping) > 0 else -1 - if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: - need_gradient_allreduce = True - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, - batch_size_axis, rank_id) - dp_degree = len(group_ranks) - dp_group = new_process_group(group_ranks) - break - - if need_gradient_allreduce: - allreduce_vars = [] - for output_name in backward_op.desc.output_names(): - for varname in backward_op.desc.output(output_name): - if varname in kwargs["grad_var_to_var"]: - fwd_name = kwargs["grad_var_to_var"][varname] - if fwd_name not in main_block.vars: - continue - if is_parameter_related(fwd_name, main_block): - allreduce_vars.append(varname) - - if len(allreduce_vars) > 0: - - for varname in allreduce_vars: - added_ops = [] - - grad_var = main_block.var(varname) - allreduce_op = main_block.append_op( - type='c_allreduce_sum', - inputs={'X': [grad_var]}, - outputs={'Out': [grad_var]}, - attrs={ - 'ring_id': dp_group.id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Backward - }) - added_ops.append(allreduce_op) - - if ctx.gradient_scale: - scale_op = main_block.append_op( - type='scale', - inputs={'X': grad_var}, - outputs={'Out': grad_var}, - attrs={ - 'scale': 1.0 / dp_degree, - OP_ROLE_KEY: OpRole.Backward - }) - added_ops.append(scale_op) - 
- dims_mapping = ctx.get_tensor_dist_attr_for_program( - grad_var).dims_mapping - process_mesh = dist_attr.process_mesh - for op in added_ops: - op_attr = OperatorDistributedAttribute() - op_attr.process_mesh = process_mesh - op_attr.set_output_dims_mapping(grad_var.name, - dims_mapping) - op_attr.set_input_dims_mapping(grad_var.name, - dims_mapping) - ctx.set_op_dist_attr_for_program(op, op_attr) + out_grad_names = [] + for output_name in backward_op.desc.output_names(): + for varname in backward_op.desc.output(output_name): + if varname in kwargs["grad_var_to_var"]: + fwd_name = kwargs["grad_var_to_var"][varname] + if fwd_name not in main_block.vars: + continue + if is_parameter_related(fwd_name, main_block): + out_grad_names.append(varname) + + gradient_synchronization(ctx, backward_op, act_grad_names, + out_grad_names, rank_id) register_distributed_operator_impl( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 85b8c469aa4..bf12ebb4589 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -16,6 +16,7 @@ from .common import infer_shape from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container +from .common import gradient_synchronization from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate @@ -518,56 +519,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op, ctx) - # check if need gradient allreduce - need_gradient_allreduce = False + # data parallel gradient synchronization + act_grad_names = [Ids_var.name] + out_grad_names = [kwargs['W@GRAD'][0]] - process_mesh = dist_attr.process_mesh - var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name) - mesh_shape = process_mesh.topology - batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: - need_gradient_allreduce = True - - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, - batch_size_axis, rank_id) - dp_degree = len(group_ranks) - dp_group = new_process_group(group_ranks) - - if need_gradient_allreduce: - added_ops = [] - W_Grad_var = main_block.var(kwargs['W@GRAD'][0]) - allreduce_op = main_block.append_op(type='c_allreduce_sum', - inputs={'X': [W_Grad_var]}, - outputs={'Out': [W_Grad_var]}, - attrs={ - 'ring_id': dp_group.id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Backward - }) - added_ops.append(allreduce_op) - - if ctx.gradient_scale: - scale_op = main_block.append_op(type='scale', - inputs={'X': W_Grad_var}, - outputs={'Out': W_Grad_var}, - attrs={ - 'scale': 1.0 / dp_degree, - OP_ROLE_KEY: OpRole.Backward - }) - added_ops.append(scale_op) - - main_block._sync_with_cpp() - - dims_mapping = ctx.get_tensor_dist_attr_for_program( - W_Grad_var).dims_mapping - process_mesh = dist_attr.process_mesh - for op in added_ops: - op_attr = OperatorDistributedAttribute() - op_attr.process_mesh = process_mesh - op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping) - op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping) - ctx.set_op_dist_attr_for_program(op, op_attr) + 
gradient_synchronization(ctx, backward_op, act_grad_names, + out_grad_names, rank_id) register_distributed_operator_impl("lookup_table_v2", diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 5ca6366d6b5..f9b5b9a5323 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -19,6 +19,7 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl +from .common import gradient_synchronization from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate @@ -422,55 +423,15 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): matmul_op_desc = copy_op_with_new_input_output(ctx, main_block, backward_op, **kwargs) - # check if need gradient allreduce - need_gradient_allreduce = False - - process_mesh = dist_attr.process_mesh - var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name) - mesh_shape = process_mesh.topology - batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: - need_gradient_allreduce = True - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, batch_size_axis, - rank_id) - dp_degree = len(group_ranks) - dp_group = new_process_group(group_ranks) - - if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block): - added_ops = [] - Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0]) - allreduce_op = main_block.append_op(type='c_allreduce_sum', - inputs={'X': [Y_Grad_var]}, - outputs={'Out': [Y_Grad_var]}, - attrs={ - 'ring_id': dp_group.id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Backward - }) - added_ops.append(allreduce_op) - - if ctx.gradient_scale: - scale_op = main_block.append_op(type='scale', - inputs={'X': Y_Grad_var}, - outputs={'Out': Y_Grad_var}, - attrs={ - 'scale': 1.0 / dp_degree, - OP_ROLE_KEY: OpRole.Backward - }) - added_ops.append(scale_op) - - main_block._sync_with_cpp() - - dims_mapping = ctx.get_tensor_dist_attr_for_program( - Y_Grad_var).dims_mapping - process_mesh = dist_attr.process_mesh - for op in added_ops: - op_attr = OperatorDistributedAttribute() - op_attr.process_mesh = process_mesh - op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping) - op_attr.set_input_dims_mapping(Y_Grad_var.name, dims_mapping) - ctx.set_op_dist_attr_for_program(op, op_attr) + # data parallel gradient synchronization + act_grad_names = [X_var.name] + + out_grad_names = [] + if is_parameter_related(Y_var.name, main_block): + out_grad_names = [kwargs['Y@GRAD'][0]] + + gradient_synchronization(ctx, backward_op, act_grad_names, out_grad_names, + rank_id) def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): -- GitLab
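
For readers following the new helpers, the standalone sketch below (not part of the patch) mimics the deduction performed by get_data_parallel_group: the first entry of an activation gradient's dims_mapping is treated as the batch axis, and a data parallel group exists only when that axis maps to a mesh dimension larger than one. The names comm_group_along_axis and deduce_dp_group, and the 2x2 mesh used in the checks, are illustrative assumptions; in the patch itself the resulting ranks are wrapped by new_process_group, and sync_and_scale_gradients then appends the c_allreduce_sum and scale ops.

# Illustrative sketch only: plain-Python re-implementation of the mesh-group
# deduction that get_data_parallel_group performs via _get_comm_group.
from itertools import product


def comm_group_along_axis(processes, topology, axis, rank):
    """Ranks sharing every mesh coordinate with `rank` except along `axis`."""
    coords = list(product(*[range(d) for d in topology]))  # row-major mesh coords
    rank_coord = coords[processes.index(rank)]
    return [
        processes[i] for i, c in enumerate(coords)
        if all(a == b for j, (a, b) in enumerate(zip(c, rank_coord)) if j != axis)
    ]


def deduce_dp_group(act_grad_dims_mappings, topology, processes, rank):
    """Treat dims_mapping[0] of an activation grad as the batch axis (as the
    patch does) and return the data parallel ranks, or None if not sharded."""
    for dims_mapping in act_grad_dims_mappings:
        batch_size_axis = dims_mapping[0] if len(dims_mapping) > 0 else -1
        if batch_size_axis > -1 and topology[batch_size_axis] > 1:
            return comm_group_along_axis(processes, topology, batch_size_axis,
                                         rank)
    return None  # batch dim not sharded: no data parallel allreduce is needed


# 2x2 mesh [[0, 1], [2, 3]]: axis 0 is data parallel, axis 1 is model parallel.
assert deduce_dp_group([[0, -1]], [2, 2], [0, 1, 2, 3], rank=1) == [1, 3]
assert deduce_dp_group([[-1, 1]], [2, 2], [0, 1, 2, 3], rank=1) is None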