From 07f33da94722c9ddbae4f85a2004b0f3b79968d4 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Wed, 13 Jul 2022 17:00:26 +0800
Subject: [PATCH] [Auto parallel] Accelerate procedure of partitioning and generating dist graphs (#44224)

* avoid sync with cpp in partition op

* delay eval & predict mode

* bugfix for gradient merge pass
---
 .../distributed/auto_parallel/engine.py       | 33 +++++++++++++++----
 .../dist_check_finite_and_unscale.py          |  4 +--
 .../auto_parallel/operators/dist_default.py   | 12 ++-----
 .../auto_parallel/operators/dist_embedding.py |  5 +--
 .../dist_fill_constant_batch_size_like.py     |  1 -
 .../auto_parallel/operators/dist_matmul.py    |  6 +---
 .../auto_parallel/operators/dist_pnorm.py     |  8 ++---
 .../auto_parallel/operators/dist_reduce_p.py  |  3 +-
 .../auto_parallel/operators/dist_reshape.py   | 12 ++-----
 .../operators/dist_update_loss_scaling.py     |  3 +-
 .../auto_parallel/parallelizer_v2.py          | 33 ++++++++++++++++++-
 .../passes/auto_parallel_gradient_merge.py    |  2 +-
 12 files changed, 71 insertions(+), 51 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index 4fd1ca3114a..1e1e37b4435 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -85,6 +85,11 @@ class Engine:
         self._feed_vars = {}
         self._fetch_vars = {}
         self._planners = {}
+        self._mode_init_states = {
+            "train": False,
+            "eval": False,
+            "predict": False
+        }
         self._dygraph_mode = False
 
     def prepare(self,
@@ -101,6 +106,7 @@ class Engine:
                 " or `paddle.fluid.optimizer.Optimizer`."
             )
         self._optimizer = optimizer
+        self._all_ranks = all_ranks
 
         if loss and not isinstance(loss,
                                    paddle.nn.Layer) and not callable(loss):
@@ -116,22 +122,23 @@ class Engine:
                         metric.__class__.__name__)
         self._metrics = to_list(metrics)
         self._gradient_scale = gradient_scale
-
         self._planned_mode = None
-        self._modes = ['train', 'eval', 'predict']
+        self._prepare_single_mode("train")
 
-        # Build program and do auto parallel process
-        for mode in self._modes:
-            # Build forward program
-            self._build(mode)
+    def _prepare_single_mode(self, mode):
+        self._modes = [mode]
+        self._build(self._modes[0])
 
+        # Do auto parallel process
         for mode in self._modes:
             # Do the planning process
             self._plan(mode)
 
         for mode in self._modes:
             # Do the parallel process
-            self._parallel(mode, all_ranks)
+            self._parallel(mode, self._all_ranks)
+            # Init comm and startup program
             self._initialize(mode)
+            self._mode_init_states[mode] = True
 
     def _build(self, mode):
@@ -432,6 +439,12 @@ class Engine:
                 return_numpy=True):
         # TODO: callbacks
         # TODO: evaluate after training
+
+        if not self._mode_init_states['train']:
+            raise Exception(
+                "train program is not initialized yet, please call engine.prepare() before calling fit() function."
+            )
+
         self.mode = 'train'
         assert self.mode in self._dist_main_progs, \
             "train model is not ready, please call `engine.prepare()` first."
@@ -467,6 +480,9 @@ class Engine:
                  use_program_cache=False,
                  return_numpy=True):
         self.mode = 'eval'
+        if not self._mode_init_states[self.mode]:
+            self._prepare_single_mode(self.mode)
+
         assert self.mode in self._dist_main_progs, \
             "eval model is not ready, please call `engine.prepare()` first."
         eval_dataloader = self._create_dataloader(eval_data, batch_size)
@@ -509,6 +525,9 @@ class Engine:
                 use_program_cache=False,
                 return_numpy=True):
         self.mode = 'predict'
+        if not self._mode_init_states[self.mode]:
+            self._prepare_single_mode(self.mode)
+
         assert self.mode in self._dist_main_progs, \
             "predict model is not ready, please call `engine.prepare()` first."
         test_dataloader = self._create_dataloader(test_data, batch_size)
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
index b00f1a589e3..108b99fdce6 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
@@ -113,12 +113,11 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
                 filter_vars.append(varname)
 
         # replicate op in dist program
-        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc = main_block.append_op(type='nop').desc
         dist_op_desc.copy_from(backward_op.desc)
         set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
         dist_op_desc.set_input('X', filter_vars)
         dist_op_desc.set_output('Out', filter_vars)
-        main_block._sync_with_cpp()
 
         # sync result
         group = new_process_group(world_process_group.ranks)
@@ -155,7 +154,6 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
                              "out_dtype": inf_var.dtype,
                              OP_ROLE_KEY: OpRole.Optimize
                          })
-        main_block._sync_with_cpp()
 
         for op in [cast_op1, allreduce_op, cast_op2]:
             new_op_dist_attr = OperatorDistributedAttribute()
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index a2b1b7826d5..9d9d5371aca 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -363,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                 output_name)
 
         # replicate op in dist program
-        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc = main_block.append_op(type='nop').desc
         dist_op_desc.copy_from(src_op.desc)
         set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
         for input_name in src_op.desc.input_names():
@@ -371,8 +371,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         for output_name in src_op.desc.output_names():
             dist_op_desc.set_output(output_name, kwargs[output_name])
 
-        main_block._sync_with_cpp()
-
         # data parallel synchronization for primtive operators
         from paddle.incubate.autograd import prim_enabled
         if prim_enabled():
@@ -431,8 +429,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                 op_attr.set_input_dims_mapping(param.name, dims_mapping)
                 ctx.set_op_dist_attr_for_program(new_op, op_attr)
 
-        startup_block._sync_with_cpp()
-
     @staticmethod
     def backward(ctx, *args, **kwargs):
 
@@ -461,7 +457,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                 output_name)
 
         # replicate op in dist program
-        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc = main_block.append_op(type='nop').desc
         dist_op_desc.copy_from(backward_op.desc)  # Refer to the related dist op
         set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
         for input_name in backward_op.desc.input_names():
@@ -470,8 +466,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         for output_name in backward_op.desc.output_names():
             dist_op_desc.set_output(output_name, kwargs[output_name])
 
-        main_block._sync_with_cpp()
-
         # check if need gradient allreduce
         # if there is a non-gradient & non-parameter input and its batch dimension is splited,
         # we need insert gradient allreduce for the gradient of parameter in its output
@@ -552,8 +546,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                                                   dims_mapping)
                 ctx.set_op_dist_attr_for_program(op, op_attr)
 
-        main_block._sync_with_cpp()
-
 
 register_distributed_operator_impl(
     "default", DistributedDefaultImpl0("replicate_parallel"))
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
index 80c9b8641ba..aa463398139 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -312,7 +312,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                                      'use_calc_stream': True,
                                      OP_ROLE_KEY: OpRole.Forward
                                  })
-        startup_block._sync_with_cpp()
 
     @staticmethod
     def backward(ctx, *args, **kwargs):
@@ -412,8 +411,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh,
                                           out_grad_dist_attr, ctx)
 
-        main_block._sync_with_cpp()
-        c_embedding_grad_op_desc = main_block.desc.append_op()
+        c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
         c_embedding_grad_op_desc.set_type("c_embedding_grad")
         c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
         c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
@@ -422,7 +420,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
         c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
         c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
-        main_block._sync_with_cpp()
 
         c_embedding_grad_op = main_block.ops[-1]
         assert c_embedding_grad_op.type == "c_embedding_grad"
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
index 763e47802b3..27e8983707b 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
@@ -118,7 +118,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
                 shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
 
         op._set_attr("shape", shape_list)
-        main_block._sync_with_cpp()
 
     @staticmethod
     def backward(ctx, *args, **kwargs):
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
index 0826148208e..4e9aefd168c 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -38,7 +38,7 @@ from .dist_default import DistributedDefaultImpl0
 
 
 def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
-    dist_op_desc = block.desc.append_op()
+    dist_op_desc = block.append_op(type='nop').desc
     dist_op_desc.copy_from(src_op.desc)
     set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
     for input_name in src_op.desc.input_names():
@@ -48,7 +48,6 @@ def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
         assert input_name in kwargs
         dist_op_desc.set_output(output_name, kwargs[output_name])
 
-    block._sync_with_cpp()
     return dist_op_desc
 
 
@@ -387,8 +386,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
     matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
                                                    backward_op, **kwargs)
 
-    main_block._sync_with_cpp()
-
     # check if need gradient allreduce
     need_gradient_allreduce = False
 
@@ -468,7 +465,6 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
                                  'use_calc_stream': True,
                                  OP_ROLE_KEY: OpRole.Forward
                              })
-    startup_block._sync_with_cpp()
 
 
 class DistributedMatmul(DistributedOperatorImplContainer):
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
index 4629e4bef93..7eea4bea49f 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
@@ -248,7 +248,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
         # rename input
         kwargs['X'] = [allgather_out.name]
         # replicate op in dist program
-        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc = main_block.append_op(type='nop').desc
         dist_op_desc.copy_from(src_op.desc)
         set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
         for input_name in src_op.desc.input_names():
@@ -260,8 +260,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
             allgather_out.name, allgather_out_dist_attr.dims_mapping)
         ctx.set_op_dist_attr_for_program(pnorm_op, op_dist_attr)
 
-        main_block._sync_with_cpp()
-
     @staticmethod
     def backward(ctx, *args, **kwargs):
 
@@ -305,7 +303,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
         new_X_var_dist_attr = ctx.get_tensor_dist_attr_for_program(new_X_var)
         ctx.set_tensor_dist_attr_for_program(new_X_grad, new_X_var_dist_attr)
         # replicate op in dist program with new kwargs
-        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc = main_block.append_op(type='nop').desc
         dist_op_desc.copy_from(backward_op.desc)  # Refer to the related dist op
         set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
         for input_name in backward_op.desc.input_names():
@@ -319,7 +317,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
         op_dist_attr.set_output_dims_mapping(new_X_grad.name,
                                              new_X_var_dist_attr.dims_mapping)
         ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr)
-        main_block._sync_with_cpp()
 
         # 2. insert slice op
         process_mesh_shape = op_dist_attr.process_mesh.topology
@@ -359,7 +356,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
         slice_op_dist_attr.set_output_dims_mapping(X_grad_var.name,
                                                    X_grad_var_dims_mapping)
         ctx.set_op_dist_attr_for_program(slice_op, slice_op_dist_attr)
-        main_block._sync_with_cpp()
 
 
 register_distributed_operator_impl("p_norm",
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py b/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py
index 6d750562c96..bdd105ef64c 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py
@@ -109,14 +109,13 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
                 output_name)
 
         # replicate op in dist program
-        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc = main_block.append_op(type='nop').desc
         dist_op_desc.copy_from(src_op.desc)
         set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
         for input_name in src_op.desc.input_names():
             dist_op_desc.set_input(input_name, kwargs[input_name])
         for output_name in src_op.desc.output_names():
             dist_op_desc.set_output(output_name, kwargs[output_name])
-        main_block._sync_with_cpp()
 
         # batch dimension synchronization
         var_name = src_op.output_arg_names[0]
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
index 47a783a5f6d..790e97cf4e1 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
@@ -177,7 +177,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
                     idx] = shape_list[idx] // process_mesh_shape[axis]
 
         # create op
-        new_op_desc = main_block.desc.append_op()
+        new_op_desc = main_block.append_op(type='nop').desc
         new_op_desc.copy_from(src_op.desc)
         set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
         new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
@@ -187,8 +187,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
         new_op_desc.set_output('Out', [Out_var.name])
         new_op_desc._set_attr('shape', shape_list)
 
-        main_block._sync_with_cpp()
-
     @staticmethod
     def backward(ctx, *args, **kwargs):
         DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
@@ -335,7 +333,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
                     idx] = shape_list[idx] // process_mesh_shape[axis]
 
         # create op
-        new_op_desc = main_block.desc.append_op()
+        new_op_desc = main_block.append_op(type='nop').desc
         new_op_desc.copy_from(src_op.desc)
         set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
         new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
@@ -345,8 +343,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
         new_op_desc.set_output('Out', [Out_var.name])
         new_op_desc._set_attr('shape', shape_list)
 
-        main_block._sync_with_cpp()
-
     @staticmethod
     def backward(ctx, *args, **kwargs):
         DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
@@ -486,7 +482,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
                     idx] = shape_list[idx] // process_mesh_shape[axis]
 
         # create op
-        new_op_desc = main_block.desc.append_op()
+        new_op_desc = main_block.append_op(type='nop').desc
         new_op_desc.copy_from(src_op.desc)
         set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
         new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
@@ -496,8 +492,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
         new_op_desc.set_output('Out', [Out_var.name])
         new_op_desc._set_attr('shape', shape_list)
 
-        main_block._sync_with_cpp()
-
     @staticmethod
     def backward(ctx, *args, **kwargs):
         DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
index 9666f882200..cbbcaef5ee4 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
@@ -127,12 +127,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
                 filter_vars.append(varname)
 
         # replicate op in dist program
-        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc = main_block.append_op(type='nop').desc
         dist_op_desc.copy_from(backward_op.desc)
         set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
         dist_op_desc.set_input('X', filter_vars)
         dist_op_desc.set_output('Out', filter_vars)
-        main_block._sync_with_cpp()
 
 
 register_distributed_operator_impl(
diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
index d8c0da9e270..005e51dfce7 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import copy
+import time
+import logging
 from collections import defaultdict
 
 import paddle
@@ -20,6 +22,7 @@ from paddle.fluid import program_guard
 from paddle.fluid.backward import append_backward
 from paddle.fluid.framework import _non_static_mode
 from paddle.distributed.passes import new_pass
+from paddle.distributed.utils import get_logger
 
 from .reshard import Resharder
 from .partitioner import Partitioner
@@ -41,6 +44,7 @@ class Parallelizer:
         assert self._dist_context._is_initialized
         self._pass_context = self._dist_context.pass_context
         self._strategy = self._dist_context.strategy
+        self._logger = get_logger(logging.INFO)
 
     def parallel_all(self):
         world_process_group = get_world_process_group()
@@ -61,38 +65,65 @@ class Parallelizer:
                                                serial_startup_program,
                                                serial_loss)
             # Apply pre optimization passes
+            time0 = time.time()
             self._apply_pre_optimization(serial_main_program,
                                          serial_startup_program, serial_loss,
                                          serial_optimizer, params_grads)
-
+            self._logger.info(
+                "within parallel apply_pre_optimization time: {}, mode {}".
+                format(time.time() - time0, self._mode))
             # Do logical partition
+            time0 = time.time()
             partitioner = Partitioner(self._dist_context, rank)
             dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
                 serial_main_program, serial_startup_program, params_grads)
+            self._logger.info(
+                "within parallel partitioner time: {}, mode {}".format(
+                    time.time() - time0, self._mode))
             # Generate optimizer
+            time0 = time.time()
             self._generate_optimizer(dist_main_prog, dist_startup_prog,
                                      serial_optimizer, dist_params_grads)
+            self._logger.info(
+                "within parallel optimizer time: {}, mode {}".format(
+                    time.time() - time0, self._mode))
             # Do reshard process
+            time0 = time.time()
             set_grad_var_shape(dist_main_prog, self._dist_context)
             resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                                   self._dist_context, dist_params_grads)
             resharder.reshard()
+            self._logger.info(
+                "within parallel reshard time: {}, mode {}".format(
+                    time.time() - time0, self._mode))
             # Apply post optimization passes
+            time0 = time.time()
             self._apply_post_optimization(dist_main_prog, dist_startup_prog,
                                           rank, dist_params_grads)
+            self._logger.info(
+                "within parallel apply_post_optimization time: {}, mode {}".
+                format(time.time() - time0, self._mode))
         else:
             # Apply pre optimization passes
             # self._apply_pre_optimization(serial_main_program,
             #                              serial_startup_program, None, None,
             #                              None)
             # Do logical partition
+            time0 = time.time()
             partitioner = Partitioner(self._dist_context, rank)
             dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
                 serial_main_program, serial_startup_program, [])
             # Do reshard process
+            self._logger.info(
+                "within parallel partitioner time: {}, mode {}".format(
+                    time.time() - time0, self._mode))
+            time0 = time.time()
             resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                                   self._dist_context, [], 1)
             resharder.reshard()
+            self._logger.info(
+                "within parallel reshard time: {}, mode {}".format(
+                    time.time() - time0, self._mode))
         # Clone program for test
         if self._mode != 'train':
             dist_main_prog = dist_main_prog.clone(for_test=True)
diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
index 66cce97533e..717f8fa27f2 100644
--- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
+++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -58,7 +58,7 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
 
 
 def _remove_op_role_var(param, grad):
     op_maker = core.op_proto_and_checker_maker
     op = grad.op
-    if op.has_attr(op_maker.kOpRoleVarAttrName()):
+    if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
         op._remove_attr(op_maker.kOpRoleVarAttrName())
--
GitLab
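
Note on the change repeated across the operator implementations above: the old code appended a bare C++ OpDesc (main_block.desc.append_op()) and then had to call main_block._sync_with_cpp() to rebuild the Python-side operator list, which is costly when it runs once per replicated op. The patch instead appends a placeholder Python op with append_op(type='nop') and overwrites its desc in place, so no sync pass is needed. The sketch below restates that pattern as a standalone helper; it simply mirrors copy_op_with_new_input_output from dist_matmul.py as patched here, the helper name replicate_op_desc is made up for illustration, and the import path is assumed to match the Paddle revision this patch targets.

# A minimal sketch of the op-replication pattern, assuming the Paddle
# revision this patch targets; `replicate_op_desc` is an illustrative name.
from paddle.distributed.auto_parallel.utils import set_dist_op_desc_original_id


def replicate_op_desc(ctx, block, src_op, **kwargs):
    # Append a placeholder Python operator; this keeps block.ops and
    # block.desc consistent, unlike block.desc.append_op(), so the costly
    # block._sync_with_cpp() call afterwards is no longer required.
    dist_op_desc = block.append_op(type='nop').desc
    # Overwrite the placeholder's desc with a copy of the serial op and
    # rebind its inputs/outputs to the dist-program variable names.
    dist_op_desc.copy_from(src_op.desc)
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        dist_op_desc.set_output(output_name, kwargs[output_name])
    return dist_op_desc

The engine-side changes follow the same lazy idea: prepare() now builds and parallelizes only the 'train' programs, while evaluate() and predict() call _prepare_single_mode() on first use, guarded by the _mode_init_states flags.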