diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index c31642a9e2af315ff0ca3f2d4d6a624baae8733e..280868773cdc3c89170ee42a7151bcc22317f697 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1410,6 +1410,9 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( def naive_set_dist_op_attr_for_program_by_mesh( new_op, process_mesh, ctx, is_recompute=False ): + # hack to skip coalesce var for dist attr + if not is_recompute: + return assert process_mesh is not None new_op_dist_attr = OperatorDistributedAttribute() @@ -2129,13 +2132,13 @@ def insert_dependencies_for_two_ops( block, idx, prior_op, - posterior, + posterior_op, dist_context, is_recompute=False, sync=False, ): """ - dependency: prior_op should be run before posterior + dependency: prior_op should be run before posterior_op """ assert ( @@ -2144,15 +2147,15 @@ def insert_dependencies_for_two_ops( str(prior_op) ) assert ( - len(posterior.input_arg_names) >= 1 + len(posterior_op.input_arg_names) >= 1 ), "second op of dependency should at least have one input. [{}]".format( - str(posterior) + str(posterior_op) ) prior_op_mesh = dist_context.get_op_dist_attr_for_program( prior_op ).process_mesh posterior_mesh = dist_context.get_op_dist_attr_for_program( - posterior + posterior_op ).process_mesh assert ( prior_op_mesh == posterior_mesh @@ -2171,25 +2174,72 @@ def insert_dependencies_for_two_ops( [block.var(name) for name in prior_op.output_arg_names] ) second_var = _select_best_depend_var( - [block.var(name) for name in posterior.input_arg_names] + [block.var(name) for name in posterior_op.input_arg_names] ) + return insert_dependencies_for_two_vars( + block, + idx, + first_var, + second_var, + dist_context, + OpRole.Backward, + prior_op_mesh, + is_recompute, + sync, + ) + + +def insert_dependencies_for_two_vars( + block, + idx, + prior_var, + post_var, + dist_context, + oprole, + process_mesh=None, + is_recompute=False, + sync=False, +): + """ + dependency: op that generates prior_var should be run before op that generates post_var + """ + assert block.has_var(prior_var.name) + assert block.has_var(post_var.name) + if process_mesh is None: + process_mesh = dist_context.get_tensor_dist_attr_for_program( + post_var + ).process_mesh + assert process_mesh is not None + depend_op = block._insert_op_without_sync( idx, type='nop', inputs={ - "X": first_var, + "X": prior_var, }, - outputs={"Out": second_var}, + outputs={"Out": post_var}, ) # depend_op.desc.set_type("depend") - depend_op._set_attr(OP_ROLE_KEY, OpRole.Backward) + depend_op._set_attr(OP_ROLE_KEY, oprole) # depend_op.desc.set_input("Dep", [first_var.name]) # self.desc.set_output(out_proto.name, out_arg_names) naive_set_dist_op_attr_for_program_by_mesh( - depend_op, prior_op_mesh, dist_context, is_recompute + depend_op, process_mesh, dist_context, is_recompute ) if sync: block._sync_with_cpp() + + return depend_op + + +def use_standalone_executor(): + return os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [ + 1, + '1', + True, + 'True', + 'true', + ] diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 601cd31948b3fc754ed6417fce9969f22f1212b6..47759484a66ee294f804fc0d402bac6f7dbc3b06 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -27,8 +27,11 @@ from paddle.distributed.auto_parallel.utils import ( find_higher_order_backward_op, is_loss_grad_op, is_optimize_op, + is_forward_op, ring_id_to_process_group, get_var_numel, + use_standalone_executor, + insert_dependencies_for_two_vars, ) # add new optimizers supporting rescale_grad here @@ -87,16 +90,20 @@ class DataParallelOptimizationPass(PassBase): self.dist_context = self.get_attr("dist_context") self.global_rank = int(self.get_attr("global_rank")) self.use_sharding = self.get_attr("use_sharding") + self.coalesce_prefix = 'coalesce_grad' + if use_standalone_executor(): + self.gradient_sync_stream = "gradient_sync_stream" with paddle.static.program_guard(main_program, startup_program): self._analyze_program() + # TODO refactor here to first fuse then overlap if self.is_data_parallel_applied(): self._prune_grad_scaling() self._calc_comm_overlap() grad_group = self._fuse_allreduce() - - # self.summary(grad_group) + self._add_dependencies(grad_group) + self.summary(grad_group) def _prune_grad_scaling(self): @@ -284,7 +291,6 @@ class DataParallelOptimizationPass(PassBase): # InterpreterCore has a different logic for overlapping # which is different from use_calc_stream block = default_main_program().global_block() - ops = block.ops # comm wait calc to finish for idx, op in reversed(list(enumerate(block.ops))): @@ -294,7 +300,6 @@ class DataParallelOptimizationPass(PassBase): op._set_attr('use_calc_stream', False) ring_id = op.attr("ring_id") - block._insert_op_without_sync( idx, type='c_wait_compute', @@ -307,8 +312,10 @@ class DataParallelOptimizationPass(PassBase): def _calc_wait_comms(self): + if use_standalone_executor(): + return + block = default_main_program().global_block() - ops = block.ops # NOTE the naive overlap implement in static hybird parallel only sync comm stream # at the end of Backward phase, based on a strong constraint that @@ -325,7 +332,7 @@ class DataParallelOptimizationPass(PassBase): ring_id_to_un_sync_grad_map[group.id] = [] # analyze the where need to sync - for i, op in enumerate(ops): + for i, op in enumerate(block.ops): if is_data_parallel_reduce_op(op): ring_id = op.attr("ring_id") grad_name = op.output_arg_names[0] @@ -365,6 +372,7 @@ class DataParallelOptimizationPass(PassBase): outputs={'Out': []}, attrs={'op_role': OpRole.Backward, 'ring_id': ring_id}, ) + block._sync_with_cpp() def _could_be_fuse(self): # TODO support gradient fuse higher order gradient. @@ -404,8 +412,6 @@ class DataParallelOptimizationPass(PassBase): def collect_group(cur_group, grad_var, ring_id, i): if len(cur_group.gradients) == 0: cur_group = None - elif len(cur_group.gradients) == 1: - grouped_grad_names.remove(cur_group.gradients[0].name) else: cur_group.finalize() grad_groups.append(cur_group) @@ -451,9 +457,16 @@ class DataParallelOptimizationPass(PassBase): for i, group in enumerate(grad_groups[::-1]): + # skip unfused big tensor + if len(group.gradients) <= 1: + group.coalesce_var = group.gradients[0] + continue + # create coalecse tensor group.coalesce_var = block.create_var( - name=unique_name.generate('coalecse_grad_{}'.format(i)), + name=unique_name.generate( + self.coalesce_prefix + '_{}'.format(i) + ), dtype=group.dtype, persistable=False, stop_gradient=True, @@ -497,7 +510,7 @@ class DataParallelOptimizationPass(PassBase): ), "Unexception: try to remove op {}".format( str(block.ops[idx]) ) - block._remove_op(idx) + block._remove_op(idx, False) # insert coalecse op concated_shapes = [] @@ -529,6 +542,141 @@ class DataParallelOptimizationPass(PassBase): block._sync_with_cpp() # TODO update dist attr + def _add_dependencies(self, grad_groups): + # NOTE Currently, auto_parallel need to adopt for two executors: Sequential executor (old exe) and Graph based + # multiple stream executor(standalone exe). This function just for standalone exe. Refactor here + # in future when only one executor stay. + + if not use_standalone_executor() or len(grad_groups) == 0: + return + block = default_main_program().global_block() + + # Build maps + vars_to_coalesce_map = {} + coalesce_to_vars_map = {} + + for group in grad_groups: + grad_names = [] + coalesce_name = group.coalesce_var.name + for grad in group.gradients: + vars_to_coalesce_map[grad.name] = coalesce_name + grad_names.append(grad.name) + coalesce_to_vars_map[coalesce_name] = grad_names + + # analyze dependencies + # Record ONLY the last grad that generated before allreduce + # NOTE need to be update when we allow multiple calc stream for backward calc + not_sync_coalesces = [] + prior_allreduce_deps = {} + for idx, op in reversed(list(enumerate(block.ops))): + if is_forward_op(op): + break + if is_optimize_op(op): + continue + + if is_data_parallel_reduce_op(op): + coalesce_var_name = op.output_arg_names[0] + + # NOTE only add extra deps for fused tensor, other tensor rely on + # data flow analysis of executor. + if self.coalesce_prefix in coalesce_var_name: + prior_allreduce_deps[coalesce_var_name] = [ + idx, + None, + coalesce_var_name, + ] + not_sync_coalesces.append(coalesce_var_name) + continue + + for out_name in op.output_arg_names: + var_name = vars_to_coalesce_map.get(out_name, None) + if var_name in not_sync_coalesces: + prior_allreduce_deps[var_name][1] = out_name + not_sync_coalesces.remove(var_name) + assert ( + len(not_sync_coalesces) == 0 + ), "Unexception: {} has NOT been add prior Dep before allreduce.".format( + not_sync_coalesces + ) + + # Record ONLY the first grad that used after allreduce + # NOTE need to be update when we allow multiple calc stream for backward calc + not_sync_coalesces = [] + post_allreduce_deps = {} + for idx, op in enumerate(block.ops): + if is_forward_op(op): + continue + + if is_data_parallel_reduce_op(op): + coalesce_var_name = op.input_arg_names[0] + if self.coalesce_prefix in coalesce_var_name: + post_allreduce_deps[coalesce_var_name] = [ + None, + coalesce_var_name, + None, + ] + not_sync_coalesces.append(coalesce_var_name) + continue + + for out_name in op.input_arg_names: + var_name = vars_to_coalesce_map.get(out_name, None) + if var_name in not_sync_coalesces: + post_allreduce_deps[var_name][0] = idx + post_allreduce_deps[var_name][2] = out_name + not_sync_coalesces.remove(var_name) + + assert ( + len(not_sync_coalesces) == 0 + ), "Unexception: {} has NOT been add post Dep after allreduce.".format( + not_sync_coalesces + ) + + # Update program IR insert dependencise op + dep_var_pairs = [] + for deps in [prior_allreduce_deps, post_allreduce_deps]: + for pair in deps.values(): + dep_var_pairs.append(pair) + + dep_var_pairs.sort(key=lambda x: x[0], reverse=True) + for idx, prior_name, post_name in dep_var_pairs: + prior_var = block.var(prior_name) + post_var = block.var(post_name) + depend_op = insert_dependencies_for_two_vars( + block, + idx, + prior_var, + post_var, + self.dist_context, + OpRole.Backward, + process_mesh=[ + -1 + ], # hack to avoid initialize the dist attr for coalesc var + is_recompute=False, + sync=False, + ) + depend_op.dist_attr.execution_stream = self.gradient_sync_stream + block._sync_with_cpp() + + # remove naive synchronization & assign allreduce stream + def remove_cond(op): + if op.type != "c_wait_compute": + return False + if len(op.input_arg_names) != 0: + return False + if len(op.output_arg_names) != 0: + return False + return True + + for idx, op in reversed(list(enumerate(block.ops))): + if is_data_parallel_reduce_op(op): + op._set_attr('use_calc_stream', True) + op.dist_attr.execution_stream = self.gradient_sync_stream + + if remove_cond(op): + block._remove_op(idx, sync=False) + + block._sync_with_cpp() + def summary(self, grad_groups=[]): # TODO: add logger module import logging diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index 73432baa1d3c3d17d7aa8f9f22b3b9a9b3cb5309..a475f8e0ac317e3ac82d731bd1b50ee8f54a386c 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -26,6 +26,8 @@ from ..auto_parallel.utils import ( OP_ROLE_KEY, OpRole, _get_comm_group, + insert_dependencies_for_two_vars, + use_standalone_executor, ) from ..auto_parallel.dist_attribute import ( TensorDistributedAttribute, @@ -334,6 +336,7 @@ class ClipGradByGloblNormPass(PassBase): if op.type == 'sqrt': input_name = op.input("X")[0] input_var = block.vars[input_name] + insert_leaf_fill_constant_node = False if paddle.distributed.get_world_size() > 1: offset = 0 if input_name in removed_tmp_var: @@ -356,6 +359,7 @@ class ClipGradByGloblNormPass(PassBase): ) offset += 1 self.clip_helper._init_dist_attr(fill_constant_op) + insert_leaf_fill_constant_node = True allreduce_op = block._insert_op( idx + offset, @@ -373,6 +377,45 @@ class ClipGradByGloblNormPass(PassBase): ) self.clip_helper._init_dist_attr(allreduce_op) + if ( + use_standalone_executor + and insert_leaf_fill_constant_node + ): + + # NOTE add naive deps for global norm sync in graph exe + j = idx - 1 + prior_op = None + while j > 0: + op_type = block.ops[j].type + if op_type in [ + 'update_loss_scaling', + 'check_finite_and_unscale', + ] or op_type.endswith("_grad"): + prior_op = block.ops[j] + break + j -= 1 + print("here: ", block.ops[j]) + assert ( + prior_op is not None + ), "Unexception: ClipByGlobalNorm could not find priory depend op" + prior_var = block.vars[prior_op.output_arg_names[0]] + assert ( + prior_var is not None + ), "Unexception: ClipByGlobalNorm could not find priory depend var" + insert_dependencies_for_two_vars( + block, + idx, + prior_var, + input_var, + self.clip_helper.dist_context, + OpRole.Optimize, + process_mesh=[ + -1 + ], # hack to avoid initialize the dist attr for coalesc var + is_recompute=False, + sync=False, + ) + for varname in removed_tmp_var: block._remove_var(varname, sync=False) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 72a116a5eb3afeea1afbb988ad6f9cedd174fc70..23fb73f10eff716bcc08fd3e17656c5727893e2f 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -203,6 +203,7 @@ class RecomputeState(ProgramStats): if cur_op.attr("fix_seed") is False else int(cur_op.attr("seed")) ) + # TODO add dependency for seed op to ensure it be issued just before recompute. seed_op = self._block._insert_op_without_sync( index=cur_op.idx, type="seed", @@ -490,6 +491,7 @@ class RecomputePass(PassBase): prior_op, posterior_op, self._dist_context, + is_recompute=True, sync=False, ) main_program._sync_with_cpp() diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py index ee61a156757ba416c1e151d37aab183fc405e1e3..5a6486991dc9da4818d15420349510058bed9981 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import random import sys import unittest @@ -24,6 +25,9 @@ import paddle.distributed.fleet as fleet from paddle.distributed.auto_parallel.dist_context import ( get_default_distributed_context, ) +from paddle.distributed.auto_parallel.operators.common import ( + is_data_parallel_reduce_op, +) from paddle.distributed.passes import PassContext, new_pass sys.path.append("..") @@ -116,5 +120,63 @@ class TestDataParallelPassWithScale2(TestDataParallelPassWithScale1): return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data +class TestDataParallelPassWithStandaloneEXE(TestDataParallelPassWithScale1): + def init(self): + if paddle.is_compiled_with_cuda(): + os.environ['FLAGS_CONVERT_GRAPH_TO_PROGRAM'] = "1" + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + self.rtol = 1e-5 + self.atol = 1e-8 + # NOTE a hack to compare pass apply or not, since there is no + # setting of this pass in dist_strategy + self._apply_pass = False + + rank = paddle.distributed.get_rank() + paddle.seed(rank + 2021) + random.seed(rank + 2021) + np.random.seed(rank + 2021) + + # test scaling with optimizer rescale_grad + def get_model(self, place, batch_size, sequence_len, vocab_size): + + ( + dist_main_prog, + dist_startup_prog, + data_holder, + [loss], + gen_data, + ) = self.get_gpt_model( + 'dp', + place, + batch_size, + sequence_len, + vocab_size, + optimizer='LarsMomentum', + ) + if self._apply_pass: + config = {} + config["dist_context"] = get_default_distributed_context() + config["global_rank"] = paddle.distributed.get_rank() + dp_pass = new_pass( + "auto_parallel_data_parallel_optimization", config + ) + dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext()) + + ops = dist_main_prog.global_block().ops + allreduce_op_idx = -1 + for idx in range(len(ops)): + if is_data_parallel_reduce_op(ops[idx]): + allreduce_op_idx = idx + break + assert allreduce_op_idx > 0 + allreduce_op = ops[allreduce_op_idx] + assert allreduce_op.attr('use_calc_stream') is True + assert allreduce_op.dist_attr.execution_stream is not None + assert ops[allreduce_op_idx - 1].type == "nop" + assert ops[allreduce_op_idx + 1].type == "nop" + + return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data + + if __name__ == "__main__": unittest.main()