From 5592f8ad9e3ebd582795c60693684828c95f347d Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 4 Jan 2023 19:06:14 +0800 Subject: [PATCH] [Auto Parallel-Performance] Sharding Comm Optimization (#48604) * remove deps and prior comm * grad comm fuse * add deps for amp&global norm * stage2 broadcast prior deps * stage2 grad overlap * stream_analyzer bugfix * overlap enable * dep op namescope * depend support multiple inputs * check finite deps * stage2 param comm overlap * Set kD2HStream * grad comm hierarchical * grad comm hierarchical * new unitest Co-authored-by: chenruibiao --- .../distributed/auto_parallel/constants.py | 8 +- .../auto_parallel/operators/common.py | 27 +- .../dist_check_finite_and_unscale.py | 2 + .../auto_parallel/parallelizer_v2.py | 10 + .../auto_parallel/process_group.py | 17 +- .../paddle/distributed/auto_parallel/utils.py | 78 +- python/paddle/distributed/passes/__init__.py | 1 + ...uto_parallel_data_parallel_optimization.py | 46 +- .../passes/auto_parallel_grad_clip.py | 14 +- .../passes/auto_parallel_recompute.py | 1 + .../passes/auto_parallel_sharding.py | 930 +++++++++++++++--- ...rallel_supplement_explicit_dependencies.py | 159 +++ .../auto_parallel/sharding_newexe.py | 189 ++++ .../test_sharding_with_newexe.py | 58 ++ .../unittests/auto_parallel/test_strategy.py | 8 +- 15 files changed, 1350 insertions(+), 198 deletions(-) create mode 100644 python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/sharding_newexe.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_sharding_with_newexe.py diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index f0c9655c81e..ee42d4df420 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -90,8 +90,12 @@ SHARDING = "sharding" set_field_default_config(SHARDING, "enable", False) set_field_default_config(SHARDING, "stage", 1) set_field_default_config(SHARDING, "degree", 8) -set_field_default_config(SHARDING, "overlap_grad_comm", False) -set_field_default_config(SHARDING, "bucket_size_numel", -1) +set_field_default_config(SHARDING, "enable_overlap", False) +set_field_default_config(SHARDING, "param_comm_stream_num", 1) +set_field_default_config(SHARDING, "grad_comm_stream_num", 1) +set_field_default_config(SHARDING, "param_bucket_size_numel", 1) +set_field_default_config(SHARDING, "grad_bucket_size_numel", 1) +set_field_default_config(SHARDING, "enable_hierarchical_comm", False) set_field_default_config(SHARDING, "partition_algor", "greedy_even") set_field_default_config(SHARDING, "enable_tuning", False) set_field_default_config(SHARDING, "tuning_range", []) diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index ef865dc13bb..0d6456dc9d3 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -45,6 +45,15 @@ class ParallelMode: MoEParallel = "auto_parallel/moe_parallel" +class SyncMode: + """ + the synchorization mode for communication or auxiliary operator + """ + + AmpFlagSync = "auto_parallel/amp_flag_synchorization" + GlobalNormSync = "auto_parallel/global_norm_synchorization" + + def is_elementwise_op(op_type): if op_type in _g_elementwise_ops: return True @@ -441,7 +450,7 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names): dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name) assert ( dims_mapping is not None - ), "Unexception: dims_mapping of output [{}] of op [{}] is None".format( + ), "Unexpected: dims_mapping of output [{}] of op [{}] is None".format( grad_var.name, op_dist_attr.op_type ) # NOTE auxiliary op's dist attr should follow dist_op not dist_tensor @@ -502,6 +511,22 @@ def is_data_parallel_reduce_op(op): ) +def is_amp_flag_sync_op(op): + return ( + op.type == "c_allreduce_max" + and op.desc.has_attr("op_namescope") + and SyncMode.AmpFlagSync in op.desc.attr("op_namescope") + ) + + +def is_global_norm_sync_op(op): + return ( + op.type == "c_allreduce_sum" + and op.desc.has_attr("op_namescope") + and SyncMode.GlobalNormSync in op.desc.attr("op_namescope") + ) + + def is_in_backward_phase(dist_ctx): # NOTE currently high-order differential in Paddle dose NOT distinguish gradient computation operators # in Forward phase and operators in Backward phase (both with op_role=1), which will mislead diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py index 6a681be1a37..544c02fabad 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py @@ -24,6 +24,7 @@ from ..utils import set_dist_op_desc_original_id, set_var_dist_attr from .common import ( DistributedOperatorImpl, DistributedOperatorImplContainer, + SyncMode, register_distributed_operator_impl, register_distributed_operator_impl_container, ) @@ -166,6 +167,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): OP_ROLE_KEY: OpRole.Optimize, }, ) + allreduce_op._set_attr('op_namescope', str('/') + SyncMode.AmpFlagSync) cast_op2 = main_block.append_op( type='cast', inputs={'X': inf_var_int32}, diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 2ff8f0ee7d1..bccda52cfa4 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -318,6 +318,16 @@ class Parallelizer: [main_program], [startup_program], self._pass_context ) + # deps for newexe + config = {} + config["dist_context"] = self._dist_context + APSED_pass = new_pass( + "auto_parallel_supplement_explicit_dependencies", config + ) + APSED_pass.apply( + [main_program], [startup_program], self._pass_context + ) + # gradient_merge is then train-only optimization if self._mode == "train" and self._strategy.gradient_merge.enable: config = copy.deepcopy(self._strategy.gradient_merge.to_dict()) diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py index d7e2638f102..2a07bbe4bb9 100644 --- a/python/paddle/distributed/auto_parallel/process_group.py +++ b/python/paddle/distributed/auto_parallel/process_group.py @@ -48,14 +48,16 @@ def clear_all_process_groups(): _g_process_group_map[0] = ProcessGroup(0, []) -def new_process_group(ranks, group_id=None): +def new_process_group(ranks, group_id=None, force_new_group=False): + global _g_process_group_map - # A key constructed from ranks is used for avoiding duplication - new_key = ''.join(map(str, sorted(ranks))) - for pg_id, pg in _g_process_group_map.items(): - cur_key = ''.join(map(str, sorted(pg.ranks))) - if pg_id != 0 and new_key == cur_key: - return pg + if not force_new_group: + # A key constructed from ranks is used for avoiding duplication + new_key = ''.join(map(str, sorted(ranks))) + for pg_id, pg in _g_process_group_map.items(): + cur_key = ''.join(map(str, sorted(pg.ranks))) + if pg_id != 0 and new_key == cur_key: + return pg # If not matching the existing one, construt a new process group num_groups = len(_g_process_group_map) # Note: our process group may interfere with the original implementation @@ -137,7 +139,6 @@ class ProcessGroup: ] strategy.current_endpoint = genv.current_endpoint strategy.nrings = 1 - if core.is_compiled_with_cuda(): place = core.CUDAPlace(genv.device_id) core.NCCLParallelContext(strategy, place).init_with_ring_id( diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 358ccb66857..fef7b168a36 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1184,6 +1184,8 @@ def _get_split_indices( def set_grad_var_shape(program, dist_context): + from paddle.distributed.fleet.meta_optimizers.common import OpRole + from .operators.common import infer_shape block = program.global_block() @@ -1955,6 +1957,9 @@ def set_recompute_segments(model, losses, strategy, program): and hasattr(model.gpt, "checkpoints") ): ckpts = model.gpt.checkpoints + # last recompute segment is not need to recompute + if len(ckpts) > 2: + ckpts.pop() else: ckpts = recompute.checkpoints else: @@ -2189,6 +2194,7 @@ def insert_dependencies_for_two_ops( dist_context, is_recompute=False, sync=False, + op_namescope=None, ): """ dependency: prior_op should be run before posterior_op @@ -2233,49 +2239,74 @@ def insert_dependencies_for_two_ops( [block.var(name) for name in posterior_op.input_arg_names] ) - return insert_dependencies_for_two_vars( + return insert_dependencies_for_vars( block, idx, first_var, second_var, dist_context, OpRole.Backward, - prior_op_mesh, - is_recompute, - sync, + process_mesh=prior_op_mesh, + is_recompute=is_recompute, + sync=sync, + op_namescope=op_namescope, + use_nop=False, ) -def insert_dependencies_for_two_vars( +def insert_dependencies_for_vars( block, idx, - prior_var, - post_var, + prior_vars, + post_vars, dist_context, oprole, process_mesh=None, is_recompute=False, sync=False, + op_namescope=None, + use_nop=False, ): """ - dependency: op that generates prior_var should be run before op that generates post_var + dependency: op that generates prior_vars should be run before op that generates post_vars """ - assert block.has_var(prior_var.name) - assert block.has_var(post_var.name) + + if isinstance(prior_vars, Variable): + prior_vars = [prior_vars] + if isinstance(post_vars, Variable): + post_vars = [post_vars] + for prior_var in prior_vars: + assert block.has_var(prior_var.name) + for post_var in post_vars: + assert block.has_var(post_var.name) + if process_mesh is None: process_mesh = dist_context.get_tensor_dist_attr_for_program( - post_var + post_vars[0] ).process_mesh assert process_mesh is not None - depend_op = block._insert_op_without_sync( - idx, - type='nop', - inputs={ - "X": prior_var, - }, - outputs={"Out": post_var}, - ) + use_nop = True + if use_nop: + depend_op = block._insert_op_without_sync( + idx, + type='nop', + inputs={ + "X": prior_vars, + }, + outputs={"Out": post_vars}, + ) + else: + depend_op = block._insert_op_without_sync( + idx, + type='depend', + inputs={ + "X": post_vars, + "Dep": prior_vars, + }, + outputs={"Out": post_vars}, + ) + # depend_op.desc.set_type("depend") depend_op._set_attr(OP_ROLE_KEY, oprole) # depend_op.desc.set_input("Dep", [first_var.name]) @@ -2284,6 +2315,8 @@ def insert_dependencies_for_two_vars( naive_set_dist_op_attr_for_program_by_mesh( depend_op, process_mesh, dist_context, is_recompute ) + if op_namescope is not None: + depend_op._set_attr('op_namescope', "/{}".format(op_namescope)) if sync: block._sync_with_cpp() @@ -2291,6 +2324,13 @@ def insert_dependencies_for_two_vars( return depend_op +def is_dep_skip_op(op): + if "c_" in op.type: + return True + + return False + + def use_standalone_executor(): return os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [ 1, diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 886d29a30b4..00a98908109 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -23,6 +23,7 @@ from .auto_parallel_recompute import * # noqa: F403 from .auto_parallel_quantization import * # noqa: F403 from .auto_parallel_data_parallel_optimization import * # noqa: F403 from .auto_parallel_grad_clip import * # noqa: F403 +from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403 from .cpp_pass import * # noqa: F403 from .ps_trainer_pass import * # noqa: F403 from .ps_server_pass import * # noqa: F403 diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 8cb11270b12..2cfa113a7ae 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -22,7 +22,7 @@ from paddle.distributed.auto_parallel.operators.common import ( from paddle.distributed.auto_parallel.utils import ( find_higher_order_backward_op, get_var_numel, - insert_dependencies_for_two_vars, + insert_dependencies_for_vars, is_forward_op, is_loss_grad_op, is_optimize_op, @@ -153,12 +153,12 @@ class DataParallelOptimizationPass(PassBase): continue assert op.has_attr( "ring_id" - ), "Unexception: comm op [{}] has NOT ring id.".format(str(op)) + ), "Unexpected: comm op [{}] has NOT ring id.".format(str(op)) group = ring_id_to_process_group(op.attr("ring_id")) assert ( group is not None - ), "Unexception: data parallel group of [{}] from op [{}] is None".format( + ), "Unexpected: data parallel group of [{}] from op [{}] is None".format( grad_name, str(op) ) @@ -187,7 +187,7 @@ class DataParallelOptimizationPass(PassBase): not_synchronized_grads.append(grad_name) assert ( len(not_synchronized_grads) == 0 - ), "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format( + ), "Unexpected: gradients [{}] is scaled BUT NOT synchronized.".format( not_synchronized_grads ) @@ -251,12 +251,12 @@ class DataParallelOptimizationPass(PassBase): ): assert op.has_attr( 'rescale_grad' - ), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format( + ), "Unexpected: op [{}] is supported to have [rescale_grad] attribute.".format( str(op) ) assert ( len(op.input("Grad")) == 1 - ), "Unexception: op [{}] is supported to have only one input grad var.".format( + ), "Unexpected: op [{}] is supported to have only one input grad var.".format( str(op) ) @@ -271,7 +271,7 @@ class DataParallelOptimizationPass(PassBase): assert scaled_grads == set( self._grad_name_to_group_map.keys() - ), "Unexception: gradients [{}] are unscaled.".format( + ), "Unexpected: gradients [{}] are unscaled.".format( set(self._grad_name_to_group_map.keys()) - scaled_grads ) @@ -463,7 +463,7 @@ class DataParallelOptimizationPass(PassBase): group.coalesce_var = group.gradients[0] continue - # create coalecse tensor + # create coalesce tensor group.coalesce_var = block.create_var( name=unique_name.generate( self.coalesce_prefix + '_{}'.format(i) @@ -508,12 +508,10 @@ class DataParallelOptimizationPass(PassBase): for idx in sorted(remove_op_indices, reverse=True): assert ( block.ops[idx].type in remove_op_types - ), "Unexception: try to remove op {}".format( - str(block.ops[idx]) - ) + ), "Unexpected: try to remove op {}".format(str(block.ops[idx])) block._remove_op(idx, False) - # insert coalecse op + # insert coalesce op concated_shapes = [] concated_ranks = [] for grad_ in group.gradients: @@ -596,7 +594,7 @@ class DataParallelOptimizationPass(PassBase): not_sync_coalesces.remove(var_name) assert ( len(not_sync_coalesces) == 0 - ), "Unexception: {} has NOT been add prior Dep before allreduce.".format( + ), "Unexpected: {} has NOT been add prior Dep before allreduce.".format( not_sync_coalesces ) @@ -628,7 +626,7 @@ class DataParallelOptimizationPass(PassBase): assert ( len(not_sync_coalesces) == 0 - ), "Unexception: {} has NOT been add post Dep after allreduce.".format( + ), "Unexpected: {} has NOT been add post Dep after allreduce.".format( not_sync_coalesces ) @@ -642,7 +640,7 @@ class DataParallelOptimizationPass(PassBase): for idx, prior_name, post_name in dep_var_pairs: prior_var = block.var(prior_name) post_var = block.var(post_name) - depend_op = insert_dependencies_for_two_vars( + depend_op = insert_dependencies_for_vars( block, idx, prior_var, @@ -651,9 +649,10 @@ class DataParallelOptimizationPass(PassBase): OpRole.Backward, process_mesh=[ -1 - ], # hack to avoid initialize the dist attr for coalesc var + ], # hack to avoid initialize the dist attr for coalesce var is_recompute=False, sync=False, + op_namescope="data_parallel_overlap_dep", ) depend_op.dist_attr.execution_stream = self.gradient_sync_stream block._sync_with_cpp() @@ -694,16 +693,17 @@ class DataParallelOptimizationPass(PassBase): self._logger.addHandler(log_handler) if len(grad_groups) > 0: + self._logger.info("Data Parallel Optimization: ") self._logger.info( - "origin {} allreduce ops are fused into {} coalecse allreduce ops.".format( + " {} Allreduce ops are fused into {} coalesce allreduce ops.".format( len(self._grad_name_to_group_map.keys()), len(grad_groups) ) ) - self._logger.info("gradient fusing group are following: ") + self._logger.debug("gradient fusing group are following: ") fused_grads = set() for i, group in enumerate(grad_groups): - self._logger.info( - "coalecse gradient [{}] is composed by: {}".format( + self._logger.debug( + "coalesce gradient [{}] is composed by: {}".format( i, [grad.name for grad in group.gradients] ) ) @@ -711,12 +711,14 @@ class DataParallelOptimizationPass(PassBase): individual_grads = set(self._grad_name_to_group_map.keys()) - set( fused_grads ) - self._logger.info( + self._logger.debug( "the following [{}] gradients are not fused: ".format( len(individual_grads) ) ) - self._logger.info("individual gradient {}".format(individual_grads)) + self._logger.debug( + "individual gradient {}".format(individual_grads) + ) class GradientsGroup: diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index d209d13eefd..18b407e1d2c 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -23,11 +23,12 @@ from ..auto_parallel.dist_attribute import ( OperatorDistributedAttribute, TensorDistributedAttribute, ) +from ..auto_parallel.operators.common import SyncMode from ..auto_parallel.process_group import get_world_process_group from ..auto_parallel.reshard import Resharder from ..auto_parallel.utils import ( _get_comm_group, - insert_dependencies_for_two_vars, + insert_dependencies_for_vars, is_gradient_clip_op, is_optimize_op, use_standalone_executor, @@ -372,8 +373,9 @@ class ClipGradByGloblNormPass(PassBase): OP_ROLE_KEY: OpRole.Optimize, }, ) + # TODO better regular the usage of op namescope allreduce_op._set_attr( - 'op_namescope', "/gradient_clip_pass" + 'op_namescope', str('/') + SyncMode.GlobalNormSync ) self.clip_helper._init_dist_attr(allreduce_op) @@ -394,15 +396,14 @@ class ClipGradByGloblNormPass(PassBase): prior_op = block.ops[j] break j -= 1 - print("here: ", block.ops[j]) assert ( prior_op is not None - ), "Unexception: ClipByGlobalNorm could not find priory depend op" + ), "Unexpected: ClipByGlobalNorm could not find priory depend op" prior_var = block.vars[prior_op.output_arg_names[0]] assert ( prior_var is not None - ), "Unexception: ClipByGlobalNorm could not find priory depend var" - insert_dependencies_for_two_vars( + ), "Unexpected: ClipByGlobalNorm could not find priory depend var" + insert_dependencies_for_vars( block, idx, prior_var, @@ -414,6 +415,7 @@ class ClipGradByGloblNormPass(PassBase): ], # hack to avoid initialize the dist attr for coalesc var is_recompute=False, sync=False, + op_namescope="grad_clip_fill_constant_dep", ) for varname in removed_tmp_var: diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index d99f335517a..9b32355a4cd 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -474,6 +474,7 @@ class RecomputePass(PassBase): self._dist_context, is_recompute=True, sync=False, + op_namescope="recompute_segment_dep", ) main_program._sync_with_cpp() diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 84992aa903b..5f68131d042 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -17,6 +17,7 @@ from functools import reduce import paddle from paddle.distributed.auto_parallel.operators.common import ( + ParallelMode, is_data_parallel_reduce_op, is_parameter_related, ) @@ -25,12 +26,14 @@ from paddle.distributed.auto_parallel.utils import ( _get_comm_group, get_logger, get_var_numel, + insert_dependencies_for_vars, + is_backward_op, + is_dep_skip_op, + is_loss_grad_op, + is_optimize_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, -) -from paddle.distributed.fleet.meta_optimizers.common import ( - is_backward_op, - is_optimizer_op, + use_standalone_executor, ) from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size from paddle.fluid import unique_name @@ -85,15 +88,19 @@ class ShardingPass(PassBase): self.set_attr("stage", None) self.set_attr("sharding_degree", None) # for parallelizer self.set_attr("degree", None) # for parallelizer_v2 - self.set_attr("overlap_grad_comm", None) - self.set_attr("bucket_size_numel", None) + self.set_attr("enable_overlap", None) + self.set_attr("param_comm_stream_num", None) + self.set_attr("grad_comm_stream_num", None) + self.set_attr("param_bucket_size_numel", None) + self.set_attr("grad_bucket_size_numel", None) self.set_attr("partition_algor", None) + self.set_attr("enable_hierarchical_comm", None) self.set_attr("params_grads", []) self.set_attr("global_rank", -1) self.dp_groups = set() self.sharding_infos = [] self.varname_to_sharding_info = {} - self.partial_sharding = False + self.sharding_hybrid_dp = False self.outer_dp_group = None self.shared_params_grads = [] @@ -121,13 +128,20 @@ class ShardingPass(PassBase): "global_rank" ) < 0: return False - if self.get_attr("overlap_grad_comm") is None: + if self.get_attr("enable_overlap") is None: + return False + if self.get_attr("param_comm_stream_num") is None: return False - if self.get_attr("bucket_size_numel") is None: + if self.get_attr("grad_comm_stream_num") is None: + return False + if self.get_attr("param_bucket_size_numel") is None: + return False + if self.get_attr("grad_bucket_size_numel") is None: return False if self.get_attr("partition_algor") is None: return False - + if self.get_attr("enable_hierarchical_comm") is None: + return False return True def _check_conflict(self, other_pass): @@ -140,9 +154,24 @@ class ShardingPass(PassBase): ) self.stage = int(self.get_attr("stage")) self.global_rank = int(self.get_attr("global_rank")) - self.overlap_grad_comm = self.get_attr("overlap_grad_comm") - self.bucket_size_numel = int(self.get_attr("bucket_size_numel")) + self.enable_overlap = self.get_attr("enable_overlap") + self.param_comm_stream_num = int(self.get_attr("param_comm_stream_num")) + self.grad_comm_stream_num = int(self.get_attr("grad_comm_stream_num")) + self.enable_hierarchical_comm = self.get_attr( + "enable_hierarchical_comm" + ) + if self.param_comm_stream_num > 1 or self.grad_comm_stream_num > 1: + assert ( + self.enable_overlap + ), "multiple comm stream need enable_overlap to be True" + self.param_bucket_size_numel = int( + self.get_attr("param_bucket_size_numel") + ) + self.grad_bucket_size_numel = int( + self.get_attr("grad_bucket_size_numel") + ) self.partition_algor = self.get_attr("partition_algor") + params_grads = self.get_attr("params_grads") main_block, startup_block = ( main_program.global_block(), @@ -226,7 +255,9 @@ class ShardingPass(PassBase): # sharding hybrid data parallel: partial sharding param within if dp_group.nranks > self.sharding_world_size: - self.partial_sharding = True + self.sharding_hybrid_dp = True + assert self.param_comm_stream_num < 2 + assert self.grad_comm_stream_num < 2 assert ( len(self.dp_groups) == 1 ), "hybrid sharding and data parallelism are supported only when there is excatly one data parallel group in the network" @@ -402,7 +433,7 @@ class ShardingPass(PassBase): should_removed_optimizer_states = [] for idx, op in reversed(list(enumerate(main_block.ops))): - if not is_optimizer_op(op): + if not is_optimize_op(op): break if op.type in _supported_optimizer_type: @@ -441,7 +472,7 @@ class ShardingPass(PassBase): def _insert_optimizer_broadcasts(self, main_block, startup_block): - if self.stage > 2 or self.bucket_size_numel > 1: + if self.stage > 2 or self.param_bucket_size_numel > 1: return for sharding_info in self.sharding_infos: @@ -460,6 +491,9 @@ class ShardingPass(PassBase): OP_ROLE_KEY: OpRole.Optimize, }, ) + new_op._set_attr( + 'op_namescope', str('/') + ParallelMode.DataParallel + ) param_dist_attr = ( self._dist_context.get_tensor_dist_attr_for_program(param) ) @@ -495,7 +529,7 @@ class ShardingPass(PassBase): input_name = op.input_arg_names[0] base_name = _get_base_name_from_grad_name(input_name) sharding_info = self.varname_to_sharding_info[base_name] - _insert_reduce_op( + reduce_op = _insert_reduce_op( main_block, idx, input_name, @@ -504,12 +538,15 @@ class ShardingPass(PassBase): self._dist_context, ) if ( - not self.partial_sharding + not self.sharding_hybrid_dp or not sharding_info.is_in_local_shard(base_name) ): main_block._remove_op(idx + 1, sync=False) else: op._set_attr("ring_id", self.outer_dp_group.id) + op._set_attr( + 'op_namescope', str('/') + ParallelMode.DataParallel + ) # NOTE: # var@GRAD = sum(var@GRAD@RENAME@0, var@GRAD@RENAME@1) @@ -545,7 +582,7 @@ class ShardingPass(PassBase): not_used_param_nane.append(param_name) for idx, op in reversed(list(enumerate(main_block.ops))): - if is_optimizer_op(op): + if is_optimize_op(op): continue for input_name in op.input_arg_names: @@ -643,22 +680,719 @@ class ShardingPass(PassBase): def _optimization_pass(self, main_program, startup_program): + if self.stage <= 1: + return + + self.grad_coalesce_prefix = 'sharding_coalesce_grad_' + self.param_coalesce_prefix = 'sharding_coalesce_param_' + # NOTE PR#49275 for detail + self.comm_op_scheduling_priority = -1 + + # TODO support multiple sub_blocks + assert ( + len(self.sharding_infos) == 1 + ), "gradient synchronization optimization only support one sharding group right now, but got [{}].".format( + len(self.sharding_infos) + ) + sharding_info = self.sharding_infos[0] + with paddle.static.program_guard(main_program, startup_program): - if self.overlap_grad_comm: - _fuse_overlap_gradient_comm() - # TODO support multiple sub_blocks - if self.bucket_size_numel > 1: + self._gradient_sync_optimization(sharding_info) + # TODO independent the logic of fuse and overlap + # support overlap when no fuse + if self.param_bucket_size_numel > 1: if self.stage == 2: - _fuse_overlap_parameter_comm_stage_two( - self.sharding_infos, - self._dist_context, - fuse_size=self.bucket_size_numel, - ) + self._fuse_overlap_parameter_comm_stage_two(sharding_info) elif self.stage == 3: - _fuse_overlap_parameter_comm_stage_three( - self.sharding_infos, fuse_size=self.bucket_size_numel + self._fuse_overlap_parameter_comm_stage_three(sharding_info) + + def _gradient_sync_optimization(self, sharding_info): + + if self.grad_bucket_size_numel <= 1 and (not self.enable_overlap): + return + + main_block = default_main_program().global_block() + startup_block = default_startup_program().global_block() + coalesce_to_group_map, grad_name_to_group_map = self._group_grads( + main_block, + sharding_info, + ) + self._overlap_grad_comm( + main_block, + sharding_info, + coalesce_to_group_map, + grad_name_to_group_map, + ) + + def _fuse_overlap_parameter_comm_stage_two(self, sharding_info): + + main_block = default_main_program().global_block() + startup_block = default_startup_program().global_block() + + group_to_param_map, param_to_group_map = group_param( + sharding_info, self.param_bucket_size_numel + ) + _logger.info("Sharding Stage2 Optimization:") + _logger.info( + "Param Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets".format( + self.param_bucket_size_numel, + len(param_to_group_map.keys()), + len(group_to_param_map.keys()), + ) + ) + broadcast_var_to_group_map = {} + + if self.enable_overlap: + # if the communication is cross node, comm will be slow and calc will therefore + # wait for comm. enable multi-comm-stream + # TODO revise me in future + # 1. manager the comm and corresponding stream + # 2. allow more than two streams and open to be config + self.param_comm_group_stream_pairs = [] + ranks = sharding_info.group.ranks + for i in range(self.param_comm_stream_num): + if i == 0: + group = sharding_info.group + else: + group = new_process_group(ranks, force_new_group=True) + # NOTE here stream is just a presentation with different name, + # it is up to executor to create the exact streams given the name. + stream = "sharding_param_comm_stream{}".format(i) + self.param_comm_group_stream_pairs.append( + { + "comm_group": group, + "comm_stream": stream, + } + ) + _logger.info( + "Parameter Communication would use [{}] streams.".format( + self.param_comm_stream_num + ) + ) + self.op_to_stream_idx = {} + + for i, param_group in enumerate(group_to_param_map.keys()): + + assert len(param_group) >= 1 + if len(param_group) > 1: + coalesce_var_name = unique_name.generate( + self.param_coalesce_prefix + str(i) + ) + startup_block.create_var( + name=coalesce_var_name, + dtype=param_group.dtype, + persistable=True, + stop_gradient=True, + ) + param_group.coalesce_var = main_block.create_var( + name=coalesce_var_name, + dtype=param_group.dtype, + persistable=True, + stop_gradient=True, + ) + startup_block.append_op( + type="coalesce_tensor", + inputs={"Input": param_group.vars}, + outputs={ + "Output": param_group.vars, + "FusedOutput": param_group.coalesce_var, + }, + attrs={ + "copy_data": True, + "use_align": True, + "dtype": param_group.dtype, + OP_ROLE_KEY: OpRole.Forward, + }, + ) + else: + param_group.coalesce_var = param_group.vars[0] + _logger.info( + "Bucket[{}] size [{}]MB.".format( + i, + sum([get_var_size(p) for p in param_group.vars]), + ) + ) + _logger.debug( + "Bucket[{}] parameters: {}.".format( + i, + [p.name for p in param_group.vars], + ) + ) + + broadcast_var_to_group_map[ + param_group.coalesce_var.name + ] = param_group + + # TODO revise me to manager stream and comm + comm_stream_idx = i % self.param_comm_stream_num + comm_group = self.param_comm_group_stream_pairs[comm_stream_idx][ + 'comm_group' + ] + comm_stream = self.param_comm_group_stream_pairs[comm_stream_idx][ + 'comm_stream' + ] + new_op = main_block.append_op( + type='c_broadcast', + inputs={'X': param_group.coalesce_var}, + outputs={'Out': param_group.coalesce_var}, + attrs={ + 'ring_id': comm_group.id, + 'root': param_group.rank, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, + }, + ) + self.op_to_stream_idx[new_op] = comm_stream_idx + new_op._set_attr( + 'op_namescope', str('/') + ParallelMode.DataParallel + ) + if self.enable_overlap: + new_op.dist_attr.execution_stream = comm_stream + new_op.dist_attr.scheduling_priority = ( + self.comm_op_scheduling_priority + ) + + # NOTE the current dist context lack the presentation for bucket tensor which + # composes many tensor with different dims_mapping. we DO NOT assign dist attr + # for it currently. + + # add dependencies: + # 1. all broadcast depend on its pre collective + # 2. coalesce broadcast add nop to resolute data flow dependencies + dep_map = {} + for i, op in enumerate(main_block.ops): + if is_sharding_param_broadcast_op(op): + broadcast_varname = op.output("Out")[0] + broadcast_var = main_block.vars[broadcast_varname] + param_group = broadcast_var_to_group_map[broadcast_varname] + comm_stream = None + if self.enable_overlap: + comm_stream = op.dist_attr.execution_stream + + # FIXME remove me when upgrade to multi-comm version + if len(dep_map.keys()) < self.param_comm_stream_num: + op = _get_broadcast_first_depend_op(main_block) + prior_var = main_block.vars[op.output("ParamOut")[0]] + else: + pre_op = main_block.ops[i - self.param_comm_stream_num] + assert is_sharding_param_broadcast_op( + pre_op + ), "Unexpected: sharding broadcast pre op should be broadcast." + prior_var = main_block.vars[pre_op.output("Out")[0]] + # broadcast order dependencies + dep_map[i] = [(i, [prior_var], [broadcast_var], comm_stream)] + + if len(param_group.vars) > 1: + # in shard coalesce depend to optimizer + if param_group.is_in_local_shard: + last_grad = param_group.vars[-1] + dep_map[i].append( + (i, [last_grad], [broadcast_var], comm_stream) + ) + # coalesce resolution post deps + dep_map[i].append( + (i + 1, [broadcast_var], param_group.vars, comm_stream) ) + # insert deps + indice = sorted(list(dep_map.keys()), reverse=True) + for i in indice: + for idx, prior_vars, post_vars, comm_stream in dep_map[i][::-1]: + depend_op = insert_dependencies_for_vars( + main_block, + idx, + prior_vars, + post_vars, + self._dist_context, + OpRole.Optimize, + process_mesh=[ + -1 + ], # hack to avoid initialize the dist attr for coalesce var + is_recompute=False, + sync=False, + op_namescope="sharding_stage2_broadcast_dep", + ) + if self.enable_overlap: + depend_op.dist_attr.execution_stream = comm_stream + depend_op.dist_attr.scheduling_priority = ( + self.comm_op_scheduling_priority + ) + + main_block._sync_with_cpp() + + def _fuse_overlap_parameter_comm_stage_three(self, sharding_info): + pass + + def _group_grads( + self, + block, + sharding_info, + ): + """ + conditions for gradients to be grouped: + 1. group size < grad_bucket_size_numel + 2. same dp group (TODO) + 3. same src rank + 4. same dtype + 5. dependency: grad would NOT be used by other ops within group segment + + main logic: + 1. record coalesce group + 2. record all dp allreduce/reduce op idx + + 3. insert coalesce op + 4. insert coalesce dependency (avoid allocate memory too early) + 5. modify and remove allreduce/reduce op + 6. ensure sharding-dp hybrid parallel logic + + gradients inside same group would be fuse into one coalesce tensor + """ + ops = block.ops + if self.grad_bucket_size_numel < 1: + # numel for transformer layer + # h = 4096 + 1 + # ffn_numel = 2 * (4 * h) * h + # mha_numel = 3 * h * h + h * h + # max_fuse_numel = ffn_numel + mha_numel + self.grad_bucket_size_numel = 1 + + first_backward_op = None + for op in ops: + if is_loss_grad_op(op): + first_backward_op = op + # not backward op, sharding for inference + if first_backward_op is None: + return + first_backward_varname = first_backward_op.output_arg_names[0] + + cur_group = VarGroup(self.grad_bucket_size_numel) + grad_groups = [] + grouped_grad_names = set() + + def op_depend_on_group(op, group): + vars_ = set(op.input_arg_names + op.output_arg_names) + var_names = set([var.name for var in group.vars]) + return len(vars_.intersection(var_names)) > 0 + + # analyze groups + i = 0 + while i < len(ops): + op = ops[i] + if is_data_parallel_reduce_op(op): + assert ( + op.type == "c_reduce_sum" + ), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel" + + grad_name = op.output_arg_names[0] + param_name = _get_base_name_from_grad_name(grad_name) + rank = sharding_info.get_var_rank(param_name) + grad_var = block.var(grad_name) + + if cur_group.acceptable(grad_var, rank): + assert grad_name not in grouped_grad_names + cur_group.collect(grad_var, rank) + else: + grad_groups.append(cur_group) + cur_group = VarGroup(self.grad_bucket_size_numel) + cur_group.collect(grad_var, rank) + + if len(cur_group.vars) == 1: + cur_group.coalesce_op_idx = i - 1 + # NOTE coalesce dependency: control when allocate memory for gradients + # too early would increase the peak memory requirement, too later would hurt the performance + j = 2 + while is_dep_skip_op(ops[i - j]): + j += 1 + dep_op = ops[i - j] + dep_varname = dep_op.output_arg_names[0] + cur_group.coalesce_dep_varname = dep_varname + + grouped_grad_names.add(grad_name) + cur_group.reduce_op_indices.append(i) + + if self.sharding_hybrid_dp and sharding_info.is_in_local_shard( + param_name + ): + cur_group.is_in_local_shard = True + assert ( + ops[i + 1].type == "c_allreduce_sum" + ), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel" + assert ( + ops[i + 1].output_arg_names[0] == grad_name + ), "Hybrid Sharding with Data-Parallel should sync same gradient var" + cur_group.allreduce_op_indices.append(i + 1) + i += 1 + elif op_depend_on_group(op, cur_group): + grad_groups.append(cur_group) + cur_group = VarGroup(self.grad_bucket_size_numel) + + i += 1 + # some grad not in this rank may not be used after dp reduced + if len(cur_group.vars) >= 1: + grad_groups.append(cur_group) + + _logger.info("Sharding Gradient Communication Optimization:") + _logger.info( + "Gradient Bucket size is [{}], [{}] Gradients are fused into [{}] Buckets.".format( + self.grad_bucket_size_numel, + len(grouped_grad_names), + len(grad_groups), + ) + ) + + # create coalesce tesnor and record op idx + grad_name_to_group_map = {} + coalesce_to_group_map = {} + modify_reduce_op_map = {} + coalesce_op_map = {} + remove_reduce_op_indices = [] + + for i, group in enumerate(grad_groups): + if len(group.vars) > 1: + group.coalesce_var = block.create_var( + name=unique_name.generate( + self.grad_coalesce_prefix + str(i) + ), + dtype=group.dtype, + persistable=False, + stop_gradient=True, + ) + coalesce_op_map[group.coalesce_op_idx] = group + last_reduce_op_idx = group.reduce_op_indices.pop() + modify_reduce_op_map[last_reduce_op_idx] = group + remove_reduce_op_indices.extend(group.reduce_op_indices) + if group.is_in_local_shard: + last_allreduce_op_idx = group.allreduce_op_indices.pop() + modify_reduce_op_map[last_allreduce_op_idx] = group + remove_reduce_op_indices.extend(group.allreduce_op_indices) + else: + group.coalesce_var = group.vars[0] + for grad in group.vars: + grad_name_to_group_map[grad.name] = group + coalesce_to_group_map[group.coalesce_var.name] = group + + coalesce_op_set = set(coalesce_op_map.keys()) + modify_op_set = set(modify_reduce_op_map.keys()) + remove_op_set = set(remove_reduce_op_indices) + confilct = coalesce_op_set.intersection(modify_op_set) + + assert len(confilct) == 0 + confilct = coalesce_op_set.intersection(remove_op_set) + assert len(confilct) == 0 + confilct = modify_op_set.intersection(remove_op_set) + assert len(confilct) == 0 + + # update block + for idx, op in reversed(list(enumerate(block.ops))): + + if idx in modify_reduce_op_map: + group = modify_reduce_op_map[idx] + grad_name = op.output_arg_names[0] + assert ( + grad_name == group.vars[-1].name + ), "Unexpected: it is supposed to sync [{}] but got [{}]".format( + group.vars[-1].name, grad_name + ) + op._rename_input(grad_name, group.coalesce_var.name) + op._rename_output(grad_name, group.coalesce_var.name) + + if idx in remove_reduce_op_indices: + block._remove_op(idx, sync=False) + + if idx in coalesce_op_map: + group = coalesce_op_map[idx] + first_grad_name = group.vars[0].name + assert ( + first_grad_name in op.output_arg_names + ), "Unexpected: op is supposed to generate grad [{}] but got [{}]".format( + first_grad_name, str(op) + ) + grad_names = [grad.name for grad in group.vars] + + concated_shapes = [] + concated_ranks = [] + for grad_ in group.vars: + shape = grad_.shape + concated_shapes.extend(shape) + concated_ranks.append(len(shape)) + + coalesce_op = block._insert_op_without_sync( + idx, + type="coalesce_tensor", + inputs={"Input": grad_names}, + outputs={ + "Output": grad_names, + "FusedOutput": group.coalesce_var, + }, + attrs={ + "copy_data": False, + "use_align": True, + "dtype": group.dtype, + "concated_shapes": concated_shapes, + "concated_ranks": concated_ranks, + OP_ROLE_KEY: OpRole.Backward, + }, + ) + depend_op = insert_dependencies_for_vars( + block, + idx, + block.var(group.coalesce_dep_varname), + group.coalesce_var, + self._dist_context, + OpRole.Backward, + process_mesh=[ + -1 + ], # hack to avoid initialize the dist attr for coalesce var + is_recompute=False, + sync=False, + op_namescope="sharding_grad_coalesce_dep", + ) + block._sync_with_cpp() + + return coalesce_to_group_map, grad_name_to_group_map + + def _overlap_grad_comm( + self, + block, + sharding_info, + coalesce_to_group_map, + grad_name_to_group_map, + ): + """ + overlap gradient communication with backward & optimizer computation. + + 1. assign gradient communications to grad comm stream + 2. for coalesce gradient communication: + 2.1 insert before communication dependencies + 2.2 insert after communication dependencies only when need + 3. there is not need to add explicit dependencies for non-coalesce gradient communication + + P.S. this overlap pass is ONLY adapted for standalone executor (graph based) and stream awared allocator. + """ + + if not use_standalone_executor() or (not self.enable_overlap): + return + + self.grad_comm_group_stream_pairs = [] + ranks = sharding_info.group.ranks + # NOTE since the gradient synchronization has calculation, there would be computation + # competition between backward calculation. therefore should limit the number of stream used. + for i in range(self.grad_comm_stream_num): + if i == 0: + group = sharding_info.group + else: + group = new_process_group(ranks, force_new_group=True) + # NOTE here stream is just a presentation with different name, + # it is up to executor to create the exact streams given the name. + stream = "sharding_grad_comm_stream{}".format(i) + self.grad_comm_group_stream_pairs.append( + { + "comm_group": group, + "comm_stream": stream, + } + ) + + ops = block.ops + # analyze dependencies + dep_map = {} + reduce_op_count = 0 + grad_comm_op_to_stream_idx = {} + for idx, op in enumerate(ops): + if is_data_parallel_reduce_op(op): + + if op.type == "c_allreduce_sum": + continue + stream_idx = reduce_op_count % self.grad_comm_stream_num + grad_comm_op_to_stream_idx[op] = stream_idx + comm_group = self.grad_comm_group_stream_pairs[stream_idx][ + "comm_group" + ] + comm_stream = self.grad_comm_group_stream_pairs[stream_idx][ + "comm_stream" + ] + + reduce_varname = op.output("Out")[0] + grad_group = coalesce_to_group_map[reduce_varname] + assert grad_group.coalesce_var.name == reduce_varname + + # coalesce deps + if len(grad_group.vars) > 1: + # NOTE should prior vars to be all grads ? + # when the grad_ops' order is random + # prior dep + dep_map[idx] = [ + ( + idx, + grad_group.vars[-1], + grad_group.coalesce_var, + comm_stream, + ) + ] + # post dep + post_idx = idx + 1 + if self.sharding_hybrid_dp and grad_group.is_in_local_shard: + post_idx += 1 + dep_map[idx].append( + ( + post_idx, + grad_group.coalesce_var, + grad_group.vars, + comm_stream, + ) + ) + + # assign stream + op.dist_attr.execution_stream = comm_stream + op.dist_attr.scheduling_priority = ( + self.comm_op_scheduling_priority + ) + + op._set_attr("ring_id", comm_group.id) + if self.sharding_hybrid_dp and grad_group.is_in_local_shard: + next_op = ops[idx + 1] + assert next_op.type == "c_allreduce_sum" + assert next_op.output("Out")[0] == reduce_varname + # FIXME hybrid sharding-dp support multi comm & stream in feature + # next_op._set_attr("ring_id", comm_group.id) + next_op.dist_attr.execution_stream = comm_stream + next_op.dist_attr.scheduling_priority = ( + self.comm_op_scheduling_priority + ) + idx += 1 + + reduce_op_count += 1 + + idx += 1 + + # insert deps + indice = sorted(list(dep_map.keys()), reverse=True) + for i in indice: + for idx, prior_vars, post_vars, comm_stream in dep_map[i][::-1]: + depend_op = insert_dependencies_for_vars( + block, + idx, + prior_vars, + post_vars, + self._dist_context, + OpRole.Backward, + process_mesh=[ + -1 + ], # hack to avoid initialize the dist attr for coalesce var + is_recompute=False, + sync=False, + op_namescope="sharding_grad_comm_dep", + ) + depend_op.dist_attr.execution_stream = comm_stream + depend_op.dist_attr.scheduling_priority = ( + self.comm_op_scheduling_priority + ) + + # hierarchical grad comm + if self.enable_hierarchical_comm: + # NOTE so far we only support Isomorphic cluster with 8 ranks per node + # TODO unifiy here create communicators + # create communicators + nranks_per_node = 8 + assert self.sharding_world_size % nranks_per_node == 0 + global_group = sharding_info.group + global_ranks = global_group.ranks + relative_idx_in_node = self.global_rank % nranks_per_node + node_idx = self.global_rank // nranks_per_node + inter_node_ranks = [ + rank + for rank in global_ranks + if rank % nranks_per_node == relative_idx_in_node + ] + _logger.info( + "Sharding Gradient Hierarchical Communication Optimization." + ) + _logger.info( + "current global rank idx: {}.".format(self.global_rank) + ) + _logger.info( + "local inter node ranks idx: {}.".format(inter_node_ranks) + ) + assert ( + len(inter_node_ranks) + == self.sharding_world_size // nranks_per_node + ) + intra_node_ranks = [ + rank + for rank in global_ranks + if rank // nranks_per_node == node_idx + ] + assert len(intra_node_ranks) == nranks_per_node + _logger.info( + "local intra node ranks idx: {}.".format(intra_node_ranks) + ) + inter_node_groups = [] + intra_node_groups = [] + for _ in range(self.grad_comm_stream_num): + # TODO re-use one origin communicator + inter_node_groups.append( + new_process_group(inter_node_ranks, force_new_group=True) + ) + intra_node_groups.append( + new_process_group(intra_node_ranks, force_new_group=True) + ) + + # update program + for idx, op in reversed(list(enumerate(block.ops))): + if is_data_parallel_reduce_op(op): + assert op.type == "c_reduce_sum" + grad_comm_stream_idx = grad_comm_op_to_stream_idx[op] + inter_node_group = inter_node_groups[grad_comm_stream_idx] + intra_node_group = intra_node_groups[grad_comm_stream_idx] + + reduce_varname = op.output("Out")[0] + if self.enable_overlap: + comm_stream = op.dist_attr.execution_stream + dst_rank = int(op.attr("root_id")) + + in_peer = False + if dst_rank % nranks_per_node == relative_idx_in_node: + in_peer = True + intra_node_dst = dst_rank % nranks_per_node + + op._set_attr('ring_id', intra_node_group.id) + op._set_attr('root_id', intra_node_dst) + + if in_peer: + inter_node_dst = dst_rank // nranks_per_node + new_op = block._insert_op_without_sync( + idx + 1, + type='c_reduce_sum', + inputs={"X": reduce_varname}, + outputs={ + "Out": reduce_varname, + }, + attrs={ + 'ring_id': inter_node_group.id, + 'root_id': inter_node_dst, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward, + }, + ) + new_op._set_attr( + 'op_namescope', str('/') + ParallelMode.DataParallel + ) + + if self.enable_overlap: + new_op.dist_attr.execution_stream = comm_stream + new_op.dist_attr.scheduling_priority = ( + self.comm_op_scheduling_priority + ) + + block._sync_with_cpp() + + +def _get_broadcast_first_depend_op(block): + for op in block.ops: + if op.type in _supported_optimizer_type: + return op + + raise Exception("Could not find optimizer op.") + def _insert_init_and_broadcast_op( block, @@ -690,6 +1424,7 @@ def _insert_init_and_broadcast_op( OP_ROLE_KEY: op_role, }, ) + new_op._set_attr('op_namescope', str('/') + ParallelMode.DataParallel) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( new_op, broadcast_var_dist_attr.process_mesh, @@ -749,6 +1484,8 @@ def _insert_reduce_op( naive_set_dist_op_attr_for_program_by_mesh_and_mapping( new_op, dist_attr.process_mesh, dist_attr.dims_mapping, dist_context ) + new_op._set_attr('op_namescope', str('/') + ParallelMode.DataParallel) + return new_op def _get_dp_and_sharding_groups(origin_group, sharding_group_size, rank): @@ -790,7 +1527,7 @@ def _is_param_grad_fp32_cast_op(block, op): def _is_param_fp16_cast_op(block, op, params): - if is_optimizer_op(op): + if is_optimize_op(op): return False if not _is_desired_cast_op(block, op): return False @@ -862,6 +1599,14 @@ def _is_forward_op(op): return op.attr("op_role") == 0 +def is_sharding_param_broadcast_op(op): + return ( + op.type == "c_broadcast" + and op.desc.has_attr("op_namescope") + and ParallelMode.DataParallel in op.desc.attr("op_namescope") + ) + + def _inference_data_parallel_group_for_operator(rank_id, op, dist_context): dp_group = None @@ -975,7 +1720,7 @@ def re_order_program(block, param_grads, dist_context): num_ops = len(block.ops) remove_op_indices = [] # TODO support case when optimizer is not the last op - if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type: + if is_optimize_op(last_op) and last_op.type in _supported_optimizer_type: # record optimizer for idx, op in reversed(list(enumerate(block.ops))): if op.type not in _supported_optimizer_type: @@ -1018,16 +1763,20 @@ def group_param(sharding_info, fuse_size): group_to_param_map = {} param_to_group_map = {} bucket = [] - cur_group = ParameterGroup(fuse_size) + cur_group = VarGroup(fuse_size) for param in sharding_info.params: rank = sharding_info.get_var_rank(param.name) if cur_group.acceptable(param, rank): cur_group.collect(param, rank) else: - cur_group = ParameterGroup(fuse_size) + cur_group = VarGroup(fuse_size) cur_group.collect(param, rank) + cur_group.is_in_local_shard = sharding_info.is_in_local_shard( + param.name + ) + if cur_group in group_to_param_map: group_to_param_map[cur_group].append(param.name) else: @@ -1038,106 +1787,6 @@ def group_param(sharding_info, fuse_size): return group_to_param_map, param_to_group_map -def _fuse_overlap_gradient_comm(): - pass - - -def _fuse_overlap_parameter_comm_stage_two( - sharding_infos, dist_context, fuse_size -): - - assert ( - len(sharding_infos) == 1 - ), "fuse overlap optimization only support one sharding group right now, but got [{}].".format( - len(sharding_infos) - ) - sharding_info = sharding_infos[0] - - main_block = default_main_program().global_block() - startup_block = default_startup_program().global_block() - - group_to_param_map, param_to_group_map = group_param( - sharding_info, fuse_size - ) - _logger.info("Sharding Stage2 Optimization:") - _logger.info( - "Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets".format( - fuse_size, - len(param_to_group_map.keys()), - len(group_to_param_map.keys()), - ) - ) - for i, group in enumerate(group_to_param_map.keys()): - - assert len(group) >= 1 - if len(group) > 1: - coalesce_var_name = unique_name.generate( - 'coalecse_param_{}'.format(i) - ) - startup_block.create_var( - name=coalesce_var_name, - dtype=group.dtype, - persistable=True, - stop_gradient=True, - ) - group.coalesce_var = main_block.create_var( - name=coalesce_var_name, - dtype=group.dtype, - persistable=True, - stop_gradient=True, - ) - startup_block.append_op( - type="coalesce_tensor", - inputs={"Input": group.params}, - outputs={ - "Output": group.params, - "FusedOutput": group.coalesce_var, - }, - attrs={ - "copy_data": True, - "use_align": True, - "dtype": group.dtype, - OP_ROLE_KEY: OpRole.Forward, - }, - ) - else: - group.coalesce_var = group.params[0] - _logger.info( - "Bucket[{}] size [{}]MB : {}".format( - i, - sum([get_var_size(p) for p in group.params]), - [p.name for p in group.params], - ) - ) - - # TODO Overlap broadcast with opt and next forward - new_op = main_block.append_op( - type='c_broadcast', - inputs={'X': group.coalesce_var}, - outputs={'Out': group.coalesce_var}, - attrs={ - 'ring_id': sharding_info.group.id, - 'root': group.rank, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize, - }, - ) - - # NOTE the current dist context lack the presentation for bucket tensor which - # composes many tensor with different dims_mapping. we assign a fake dist attr - # for it currently. - - -def _fuse_overlap_parameter_comm_stage_three(sharding_infos, fuse_size): - - assert ( - len(sharding_infos) == 1 - ), "fuse overlap optimization only support one sharding group right now, but got [{}].".format( - len(sharding_infos) - ) - sharding_info = sharding_infos[0] - - class ShardingInfo(object): def __init__(self, group, rank, params_grads, partition_algor): self.group = group @@ -1188,7 +1837,7 @@ class ShardingInfo(object): param_usage = {x: 0 for x in self.param_names} for op in block.ops: - if is_optimizer_op(op): + if is_optimize_op(op): continue for input_name in op.input_arg_names: if input_name in self.param_names: @@ -1220,14 +1869,19 @@ class ShardingInfo(object): return self.params_grads.get(param_name, None) -class ParameterGroup(object): +class VarGroup(object): def __init__(self, max_size): self.max_siez = max_size self.dtype = None self.rank = -1 self.numel = 0 - self.params = [] + self.vars = [] self.coalesce_var = None + self.coalesce_dep_varname = None + self.coalesce_op_idx = None + self.reduce_op_indices = [] + self.allreduce_op_indices = [] + self.is_in_local_shard = False def acceptable(self, param, rank): if self.numel == 0: @@ -1245,7 +1899,7 @@ class ParameterGroup(object): self.dtype = param.dtype self.rank = rank self.numel += get_var_numel(param) - self.params.append(param) + self.vars.append(param) def __len__(self): - return len(self.params) + return len(self.vars) diff --git a/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py b/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py new file mode 100644 index 00000000000..0650c4c577e --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py @@ -0,0 +1,159 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.auto_parallel.operators.common import ( + is_amp_flag_sync_op, + is_data_parallel_reduce_op, + is_global_norm_sync_op, +) +from paddle.distributed.auto_parallel.utils import ( + OpRole, + insert_dependencies_for_vars, + use_standalone_executor, +) + +from .auto_parallel_sharding import ShardingPass, _supported_optimizer_type +from .pass_base import PassBase, register_pass + + +def _sharding_pass_applied(pass_ctx): + for applied_pass in pass_ctx.passes: + if isinstance(applied_pass, ShardingPass): + return True + return False + + +# NOTE we add the "auto_parallel" prefix to the pass in order to +# indicate that this pass should obey some constrains by auto_parallel +# for example all ops and vars should has dist attr before and after pass +# should use dist op instead of custom comm op +@register_pass("auto_parallel_supplement_explicit_dependencies") +class AutoParalSupplementDepPass(PassBase): + """ + Functional Concern. + for strategies like amp & global norm, there is a collective communication to sync gradient inforation in every rank. + after partition the gradients to each rank, the order of that collective communication is different in each rank + and might cause hang problem in graph based random order executor. here supplement explicit dependencies for those cases. + + TODO Performance Concern. + global collective will introduce global synchronization which forces the fast workers to wait for slow ones. + therefore we should conduct this collective when all the ranks reach a same stage. + BUT the depend API offered by executor could only ensure "conduct-not-before" but not "conduct-right-after". + Some ranks might call the colletives first than other ranks while they still some local could be performed to wait for slow peers. + IR Pass currently could not have the fully control of time the to perform these global collectives. + """ + + def __init__(self): + super().__init__() + self.set_attr("dist_context", None) + + def _check_self(self): + if self.get_attr("dist_context") is None: + return False + + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + + # TODO general this pass for all case. + if not use_standalone_executor or not _sharding_pass_applied(context): + return + + self._dist_context = self.get_attr("dist_context", None) + self.flags_sync_stream = "flags_sync_stream" + main_block = main_program.global_block() + startup_block = startup_program.global_block() + + # last dp grad communication + last_dp_reduce_op_idx = -1 + last_dp_reduce_varname = None + for idx, op in reversed(list(enumerate(main_block.ops))): + if is_data_parallel_reduce_op(op): + last_dp_reduce_op_idx = idx + last_dp_reduce_varname = op.output_arg_names[0] + break + assert last_dp_reduce_op_idx > 0 + assert last_dp_reduce_varname is not None + + # analyze deps for amp & global norm + deps_map = {} + prior_varname = last_dp_reduce_varname + for idx, op in enumerate(main_block.ops): + if is_amp_flag_sync_op(op) or is_global_norm_sync_op(op): + op_namescope = None + if is_amp_flag_sync_op(op): + op_namescope = "amp_flag_sync_dep" + op.dist_attr.execution_stream = self.flags_sync_stream + + elif is_global_norm_sync_op(op): + op_namescope = "global_norm_sync_dep" + deps_map[idx] = (prior_varname, op.input("X")[0], op_namescope) + prior_varname = op.output("Out")[0] + + # analyze deps for check_finite_and_unscale + # ensure it is performed after last backward computation, therefore reduce the + # straggling of the amp-flag-sync + first_check_op = True + for idx, op in enumerate(main_block.ops): + if op.type == "check_finite_and_unscale": + if first_check_op: + last_backward_op = main_block.ops[idx - 1] + prior_varname = last_backward_op.output_arg_names[0] + first_check_op = False + deps_map[idx] = ( + prior_varname, + op.input("Scale")[0], + "check_finite_dep", + ) + + # analyze deps for optimizer + # optimizers order should be fixed to allow broadcast to overlap with optimizer + first_optimizer_op = True + for idx, op in enumerate(main_block.ops): + if op.type in _supported_optimizer_type: + if first_optimizer_op: + first_optimizer_op = False + else: + deps_map[idx] = ( + prior_varname, + op.input("Param")[0], + "optimizer_order_dep", + ) + prior_varname = op.output("ParamOut")[0] + + # insert deps + indice = sorted(list(deps_map.keys()), reverse=True) + for idx in indice: + prior_var = main_block.var(deps_map[idx][0]) + post_var = main_block.var(deps_map[idx][1]) + op_namescope = deps_map[idx][2] + depend_op = insert_dependencies_for_vars( + main_block, + idx, + prior_var, + post_var, + self._dist_context, + OpRole.Optimize, + process_mesh=[ + -1 + ], # hack to avoid initialize the dist attr for coalesc var + is_recompute=False, + sync=False, + op_namescope=op_namescope, + ) + + main_block._sync_with_cpp() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/sharding_newexe.py b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_newexe.py new file mode 100644 index 00000000000..ca76daada57 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_newexe.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np +from get_gpt_model import FakeDataset, generate_model + +import paddle +from paddle.distributed.fleet import auto +from paddle.fluid.dygraph.parallel import ParallelEnv + +paddle.enable_static() + + +def apply_pass(use_sharding=False, use_amp=False, use_recompute=False): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + + if use_sharding: + sharding = strategy.sharding + sharding.enable = True + sharding.degree = 2 + sharding.stage = 2 + sharding.enable_overlap = True + sharding.param_comm_stream_num = 2 + sharding.grad_comm_stream_num = 2 + sharding.param_bucket_size_numel = 512 * 512 + sharding.grad_bucket_size_numel = 128 * 128 + sharding.partition_algor = 'use_order' + if use_recompute: + recompute = strategy.recompute + recompute.enable = True + if use_amp: + amp = strategy.amp + amp.enable = True + amp.custom_white_list = [ + 'lookup_table_v2', + 'lookup_table', + 'softmax', + 'layer_norm', + 'gelu', + ] + amp.custom_black_list = [ + 'c_softmax_with_cross_entropy', + 'elementwise_div', + 'reduce_sum', + ] + amp.init_loss_scaling = 32768 + amp.use_fp16_guard = False + amp.use_pure_fp16 = True + amp.use_optimizer_fp16 = False + + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestShardingStage2WithNewEXE(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2022) + np.random.seed(2022) + random.seed(2022) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine( + self, use_sharding=False, use_amp=False, use_recompute=False + ): + reset_prog() + + strategy = apply_pass(use_sharding, use_amp, use_recompute) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("dp") + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_param_grad_fuse_overlap(self, program): + num_op = 0 + num_coalesce = 0 + num_reduce = 0 + num_broadcast = 0 + for op in program.global_block().ops: + if op.type == "nop" or op.type == "depend": + num_op += 1 + elif op.type == "coalesce_tensor": + num_coalesce += 1 + elif op.type == "c_reduce_sum": + num_reduce += 1 + elif op.type == "c_broadcast": + num_broadcast += 1 + + if paddle.distributed.get_rank() == 0: + self.assertEqual(num_op, 22) + else: + self.assertEqual(num_op, 54) + + self.assertEqual(num_coalesce, 5) + self.assertEqual(num_reduce, 14) + self.assertEqual(num_broadcast, 2) + + def test_param_grad_fuse_overlap(self): + # dp2 + dp_engine = self.get_engine() + dp_history = dp_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + dp_loss = dp_history.history['loss'][0] + + # sharding2 + sharding_engine = self.get_engine(use_sharding=True) + sharding_history = sharding_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + sharding_loss = sharding_history.history['loss'][0] + + # amp, recompute + amp_recompute_engine = self.get_engine( + use_sharding=False, use_amp=True, use_recompute=True + ) + amp_recompute_history = amp_recompute_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + amp_recompute_loss = amp_recompute_history.history['loss'][0] + + # sharding2, amp, recompute + all_engine = self.get_engine( + use_sharding=True, use_amp=True, use_recompute=True + ) + all_history = all_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + all_loss = all_history.history['loss'][0] + + self.check_param_grad_fuse_overlap(sharding_engine.main_program) + np.testing.assert_allclose( + dp_loss, sharding_loss, rtol=1e-05, atol=1e-08 + ) + np.testing.assert_allclose( + amp_recompute_loss, all_loss, rtol=1e-05, atol=1e-08 + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_sharding_with_newexe.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_sharding_with_newexe.py new file mode 100644 index 00000000000..91ffae423c3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_sharding_with_newexe.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys +import tempfile +import unittest + +os.environ["FLAGS_CONVERT_GRAPH_TO_PROGRAM"] = str(1) +os.environ["FLAGS_add_dependency_for_communication_op"] = 'false' + + +class TestShardingWithNewEXE(unittest.TestCase): + def test_stage2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "sharding_newexe.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = ( + [sys.executable, "-u"] + + coverage_args + + [ + "-m", + "paddle.distributed.launch", + "--devices", + "0,1", + "--log_dir", + tmp_dir.name, + launch_model_path, + ] + ) + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index 0b41e323ffd..0cc83160908 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -52,9 +52,13 @@ class TestStrategy(unittest.TestCase): self.assertEqual(sharding.enable, False) self.assertEqual(sharding.stage, 1) self.assertEqual(sharding.degree, 8) - self.assertAlmostEqual(sharding.overlap_grad_comm, False) - self.assertAlmostEqual(sharding.bucket_size_numel, -1) + self.assertAlmostEqual(sharding.enable_overlap, False) + self.assertAlmostEqual(sharding.param_comm_stream_num, 1) + self.assertAlmostEqual(sharding.grad_comm_stream_num, 1) self.assertAlmostEqual(sharding.partition_algor, "greedy_even") + self.assertAlmostEqual(sharding.param_bucket_size_numel, 1) + self.assertAlmostEqual(sharding.grad_bucket_size_numel, 1) + self.assertAlmostEqual(sharding.enable_hierarchical_comm, False) self.assertEqual(sharding.enable_tuning, False) self.assertEqual(sharding.tuning_range, []) -- GitLab