From 229befc80d246dba01af42965b1861d863f42247 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Tue, 23 Aug 2022 12:01:36 +0800
Subject: [PATCH] [Auto Parallel] Data Parallel Comm & Calc Overlap Optimization (#45173)

* bugfix

* remove scaling

* support rescale_grad opt

* add unittest
---
 .../distributed/auto_parallel/engine.py       |  6 +-
 ...uto_parallel_data_parallel_optimization.py | 81 +++++++++++++++++--
 .../contrib/mixed_precision/fp16_utils.py     |  5 +-
 3 files changed, 84 insertions(+), 8 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index e8a57ade0f1..35ff882491a 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -189,8 +189,9 @@ class Engine:
         serial_main_prog = self._orig_main_prog.clone()
         serial_startup_prog = self._orig_startup_prog.clone()
         # FIXME to support grad clip
-        with static.program_guard(serial_main_prog, serial_startup_prog), \
-            utils.unique_name.guard():
+        # with static.program_guard(serial_main_prog, serial_startup_prog), \
+        #     utils.unique_name.guard():
+        with static.program_guard(serial_main_prog, serial_startup_prog):
             inputs_spec = self.inputs_spec
             labels_spec = self.labels_spec if self.labels_spec else []
             inputs = [s._create_feed_layer() for s in inputs_spec]
@@ -440,6 +441,7 @@ class Engine:
         for epoch in range(epochs):
             train_logs = {"epoch: {:d} ": epoch}
             for step, _ in enumerate(train_dataloader):
+
                 outs = self._executor.run(self.main_program,
                                           fetch_list=fetch_list,
                                           use_program_cache=use_cache,
diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
index 9538364bf89..d91fe644c98 100644
--- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
+++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
@@ -16,6 +16,7 @@
 from collections import OrderedDict
 import paddle
 from paddle.fluid.framework import default_main_program
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
 from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
 from .pass_base import PassBase, PassType, register_pass
@@ -26,6 +27,9 @@ __rescale_grad_supported_opts__ = [
     'merge_momentum'
 ]
 
+# a heuristic upper bound for the number of dp comm streams
+__max_stream_num_allow__ = 16
+
 
 @register_pass("auto_parallel_data_parallel_optimization")
 class DataParallelOptimizationPass(PassBase):
@@ -71,7 +75,7 @@ class DataParallelOptimizationPass(PassBase):
         with paddle.static.program_guard(main_program, startup_program):
             self._analyze_program()
             self._prune_grad_scaling()
-            self._overlap_comm()
+            self._calc_comm_overlap()
             self._fuse_allreduce()
 
     def _prune_grad_scaling(self):
@@ -86,14 +90,18 @@ class DataParallelOptimizationPass(PassBase):
 
         self._remove_grad_scaling()
 
-    def _overlap_comm(self):
-        pass
+    def _calc_comm_overlap(self):
+        if not self._could_be_overlap():
+            return
+        self._calc_overlap_comms()
+        self._update_wait_comms()
 
     def _fuse_allreduce(self):
         pass
 
     def _analyze_program(self):
         """
+        build two maps:
         {param_grad_name: data_parallel_group}
         {data_parallel_group: param_grad_name}
         """
@@ -103,8 +111,9 @@ class DataParallelOptimizationPass(PassBase):
         scaled_grads = []
 
         for op in ops:
+            grad_name = op.output_arg_names[0]
+
             if is_data_parallel_reduce_op(op):
-                grad_name = op.output_arg_names[0]
                 if grad_name in self._grad_name_to_group_map:
                     continue
                 assert op.has_attr(
@@ -123,7 +132,6 @@ class DataParallelOptimizationPass(PassBase):
                     self._group_to_grad_name_map[group].append(grad_name)
 
             elif is_data_parallel_scale_op(op):
-                grad_name = op.output_arg_names[0]
                 scaled_grads.append(grad_name)
 
         # TODO support multiple optimizers in one network in future.
@@ -206,3 +214,66 @@ class DataParallelOptimizationPass(PassBase):
         assert scaled_grads == set(self._grad_name_to_group_map.keys(
         )), "Unexpected: gradients [{}] are unscaled.".format(
             set(self._grad_name_to_group_map.keys()) - scaled_grads)
+
+    def _could_be_overlap(self):
+        # NOTE currently each different nccl comm group uses a different cuda stream,
+        # so if there are too many dp groups, too many streams would need to be
+        # created and synchronized.
+        # revise here when the framework supports custom streams in static mode.
+        num_dp_comm_stream = len(set(self._group_to_grad_name_map.keys()))
+        if num_dp_comm_stream > __max_stream_num_allow__:
+            return False
+
+        return True
+
+    def _calc_overlap_comms(self):
+        # TODO support overlapping under the InterpreterCore executor,
+        # which schedules the overlap with its own logic instead of
+        # relying on the use_calc_stream attribute.
+        block = default_main_program().global_block()
+        ops = block.ops
+
+        # make each dp comm op wait for the computation of its gradient to finish
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if is_data_parallel_reduce_op(op):
+                assert op.has_attr('use_calc_stream')
+                assert op.has_attr('ring_id')
+
+                op._set_attr('use_calc_stream', False)
+                ring_id = op.attr("ring_id")
+
+                block._insert_op_without_sync(idx,
+                                              type='c_wait_compute',
+                                              inputs={'X': []},
+                                              outputs={'Out': []},
+                                              attrs={
+                                                  'op_role': OpRole.Backward,
+                                                  'ring_id': ring_id
+                                              })
+
+        block._sync_with_cpp()
+
+    def _update_wait_comms(self):
+
+        block = default_main_program().global_block()
+        ops = block.ops
+
+        # make the first optimizer op wait for all dp comm ops to finish
+        first_optimize_op_idx = -1
+        for idx, op in enumerate(ops):
+            if is_optimize_op(op):
+                first_optimize_op_idx = idx
+                break
+
+        assert first_optimize_op_idx > -1, "Unexpected: no optimizer op found in program"
+
+        for group in self._group_to_grad_name_map.keys():
+            ring_id = group.id
+            block._insert_op_without_sync(first_optimize_op_idx,
+                                          type='c_wait_comm',
+                                          inputs={'X': []},
+                                          outputs={'Out': []},
+                                          attrs={
+                                              'op_role': OpRole.Backward,
+                                              'ring_id': ring_id
+                                          })
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
index b23c94c7e49..e35dc901c83 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -542,9 +542,12 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
     fp16_var_names = to_fp16_var_names if to_fp16_var_names else set()
     var_scope = scope if scope else global_scope()
+    print(
+        "======================cast_parameters_to_fp16=============================="
+    )
     for param in all_parameters:
         if param.name in fp16_var_names:
-            _logger.debug("---- cast {} to fp16 dtype ----".format(param.name))
+            print("---- cast {} to fp16 dtype ----".format(param.name))
             param_t = var_scope.find_var(param.name).get_tensor()
             data = np.array(param_t)
             param_t.set(np.float16(data), place)
--
GitLab
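
Usage note (not part of the patch): the pass registered above is normally driven through the pass registry in paddle.distributed.passes rather than instantiated directly. The sketch below shows roughly how one rank's auto-parallel programs could be run through it; new_pass/PassContext are the public registry interface, while the attribute names "dist_context" and "global_rank" are assumptions about what this particular pass reads from its config and should be checked against the pass's own attribute checks.

# A rough usage sketch, not part of the patch: apply the pass above to one
# rank's auto-parallel programs through the pass registry. The config keys
# "dist_context" and "global_rank" are assumed here and may need adjusting.
import paddle
from paddle.distributed.passes import new_pass, PassContext


def apply_dp_overlap_pass(dist_main_prog, dist_startup_prog, dist_context, rank):
    """Run the data-parallel comm/calc overlap pass on one rank's programs."""
    config = {
        "dist_context": dist_context,  # assumed attribute name
        "global_rank": rank,           # assumed attribute name
    }
    dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)

    # The pass mutates the programs in place: it prunes the data-parallel
    # gradient scaling, moves each dp allreduce off the calc stream
    # (use_calc_stream=False), and inserts c_wait_compute / c_wait_comm ops
    # so communication overlaps backward computation but still completes
    # before the optimizer runs.
    dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())
    return dist_main_prog, dist_startup_prog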