From bdd0b0f18f62ba101351c0cd68792259a7d80bb9 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Thu, 25 Aug 2022 13:14:39 +0800 Subject: [PATCH] [Auto Parallel] Support High Order Differential with Data Parallel Calc-Comm Overlaping (#45388) * support high order differential with data parallel overlap * update unitest --- .../distributed/auto_parallel/dist_context.py | 12 +++++ .../auto_parallel/operators/common.py | 12 +++++ .../auto_parallel/operators/dist_pnorm.py | 49 ------------------- ...uto_parallel_data_parallel_optimization.py | 3 +- .../auto_parallel/test_dist_pnorm.py | 4 +- 5 files changed, 27 insertions(+), 53 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 52065ff4927..92a50365904 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -24,6 +24,7 @@ from .dist_attribute import OperatorDistributedAttribute from .dist_tensor import DistributedTensor from .dist_op import DistributedOperator from .process_mesh import ProcessMesh +from .utils import is_loss_grad_op, is_loss_op # There always exists a default context for user. And user can set it to another one. _g_default_distributed_context = None @@ -895,6 +896,11 @@ class DistributedOperatorContext: self.already_init_sync_vars = set() self.varname_mapping = None self.rank_id = None + # NOTE Support correct parallelism for high-order differential model. + # by default exceed_backward_init_op is False and it means we are in Forward phase; After exceed_backward_init_op = True, + # it means we are in Backward phase. + # And the final sulotion should be revise high-order differential logic for these two phases in future. + self._exceed_backward_init_op = False def __deepcopy__(self, memo): cls = self.__class__ @@ -951,10 +957,16 @@ class DistributedOperatorContext: assert self._cur_src_op is not None return self._cur_src_op + def in_backward_phase(self): + return self._exceed_backward_init_op + def prepare_context(self, src_op): self._cur_src_op = src_op + if is_loss_grad_op(src_op): + self._exceed_backward_init_op = True + # build input varname mapping kinputs = {} for input_name in src_op.desc.input_names(): diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index b34749b09df..e7e7ad1e0ea 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -428,6 +428,9 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names, rank (int): global ranks index for current process. """ + if not is_in_backward_phase(dist_ctx): + return + if is_optimize_op(op) or len(act_grad_names) == 0 or len( out_grad_names) == 0: return @@ -448,3 +451,12 @@ def is_data_parallel_scale_op(op): def is_data_parallel_reduce_op(op): return op.type in ["c_reduce_sum", "c_allreduce_sum"] and op.desc.has_attr("op_namescope") \ and ParallelMode.DataParallel in op.desc.attr("op_namescope") + + +def is_in_backward_phase(dist_ctx): + # NOTE currently high-order differential in Paddle dose NOT distinguish gradient computation operators + # in Forward phase and operators in Backward phase (both with op_role=1), which will mislead + # auto parallel to add gradient synchronization for gradient computation operators in Forward phase. + # we use this FLAG to distinguish these two phases temporarily. + + return dist_ctx.dist_op_context.in_backward_phase() diff --git a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py index 7eea4bea49f..77efa7fe67d 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py @@ -42,29 +42,6 @@ class DistributedPNorm(DistributedOperatorImplContainer): register_distributed_operator_impl_container(DistributedPNorm("p_norm")) -def _insert_fill_constant_op(block, op_role): - """Insert fill constant op into block at the given index.""" - helper = LayerHelper("fill_constant", **locals()) - with paddle.static.program_guard(block.program): - out = helper.create_variable_for_type_inference(dtype="int32") - inputs = {} - attrs = {'force_cpu': False} - attrs['str_value'] = str(int("1")) - attrs['value'] = int("1") - attrs['dtype'] = out.dtype - attrs['op_role'] = op_role - utils.get_shape_tensor_inputs(inputs=inputs, - attrs=attrs, - shape=[0], - op_type='fill_constant') - fill_constant_op = block.append_op(type='fill_constant', - inputs=inputs, - outputs={'Out': [out]}, - attrs=attrs) - out.stop_gradient = True - return out, fill_constant_op - - # Row Parallel class DistributedPNormImpl(DistributedOperatorImpl): @@ -182,32 +159,6 @@ class DistributedPNormImpl(DistributedOperatorImpl): check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'norm') - # 1. insert barrier op - ref_process_mesh = op_dist_attr.process_mesh - constant_out_dims_mapping = [-1] - fill_constant_out, fill_constant_op = _insert_fill_constant_op( - main_block, src_op.attr('op_role')) - # set fill_constant_out tensor dist_attr - constant_out_dist_attr = TensorDistributedAttribute() - constant_out_dist_attr.process_mesh = ref_process_mesh - constant_out_dist_attr.dims_mapping = constant_out_dims_mapping - ctx.set_tensor_dist_attr_for_program(fill_constant_out, - constant_out_dist_attr) - # set fill_constant op dist_attr - constant_op_dist_attr = OperatorDistributedAttribute() - constant_op_dist_attr.process_mesh = ref_process_mesh - constant_op_dist_attr.set_output_dims_mapping( - fill_constant_out.name, constant_out_dims_mapping) - ctx.set_op_dist_attr_for_program(fill_constant_op, - constant_op_dist_attr) - barrier_op = main_block.append_op(type='barrier', - inputs={'X': [fill_constant_out]}, - outputs={'Out': [fill_constant_out]}, - attrs={'ring_id': group.id}) - # set barrier op dist attr - set_comm_op_dist_attr_for_program(barrier_op, ref_process_mesh, - constant_out_dist_attr, ctx) - # 2. insert c_allgather op # create c_allgather output var allgather_out = main_block.create_var( diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index d91fe644c98..98fed754fa2 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -111,9 +111,9 @@ class DataParallelOptimizationPass(PassBase): scaled_grads = [] for op in ops: - grad_name = op.output_arg_names[0] if is_data_parallel_reduce_op(op): + grad_name = op.output_arg_names[0] if grad_name in self._grad_name_to_group_map: continue assert op.has_attr( @@ -132,6 +132,7 @@ class DataParallelOptimizationPass(PassBase): self._group_to_grad_name_map[group].append(grad_name) elif is_data_parallel_scale_op(op): + grad_name = op.output_arg_names[0] scaled_grads.append(grad_name) # TODO support multiple optimizers in on network in future. diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py index 74664062303..dfddba3dda1 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py @@ -108,10 +108,8 @@ class TestDistPNorm(unittest.TestCase): for output_attr in op_dist_attr.outputs_dist_attrs.values(): assert output_attr.dims_mapping[0] == 0 assert set(output_attr.dims_mapping[1:]) == set([-1]) - assert op_types == [ - "fill_constant", "barrier", "c_allgather", "p_norm", - "fill_constant", "p_norm_grad", "slice" + "c_allgather", "p_norm", "fill_constant", "p_norm_grad", "slice" ] def test_dist_pnorm_serial(self): -- GitLab