From df944772be935edcd140776da71188e6a5bd2fb6 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Wed, 28 Dec 2022 11:08:46 +0800
Subject: [PATCH] [AutoParallel] adapt for clip (#49249)

* [AutoParallel] adapt for clip

* fix unittest

* enable_static

* fix dist_fill_constant_batch_size_like

* fix process_mesh.shape

* update cond of modifying shape_list
---
 .../auto_parallel/operators/dist_default.py   |  21 +++
 .../dist_fill_constant_batch_size_like.py     |  18 ---
 .../auto_parallel/operators/dist_pnorm.py     |  51 +++++--
 .../paddle/distributed/auto_parallel/utils.py |   4 +
 .../auto_parallel/test_dist_pnorm.py          | 134 ++++++++++++------
 5 files changed, 157 insertions(+), 71 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index fe607356339..71581705e74 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -39,6 +39,7 @@ from .common import (
 )

 __op_not_need_param_init__ = ["while", "cond"]
+__op_has_shape_attr__ = ["fill_constant_batch_size_like", "fill_constant"]


 def prim_operator_data_parallel_functor(ctx, src_op):
@@ -476,6 +477,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         for output_name in src_op.desc.output_names():
             dist_op_desc.set_output(output_name, kwargs[output_name])

+        if (
+            src_op.has_attr('shape')
+            and src_op.attr('shape')
+            and src_op.type in __op_has_shape_attr__
+        ):
+            shape_list = src_op.attr('shape')
+            Out_var = main_block._var_recursive(kwargs['Out'][0])
+            op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+            dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
+            process_mesh_shape = op_dist_attr.process_mesh.shape
+            assert len(shape_list) == len(dim_mapping)
+            # modify target shape
+            for idx, axis in enumerate(dim_mapping):
+                if axis >= 0:
+                    if len(shape_list) > idx:
+                        shape_list[idx] = (
+                            shape_list[idx] // process_mesh_shape[axis]
+                        )
+            dist_op_desc._set_attr('shape', shape_list)
+
         # data parallel synchronization for primtive operators
         from paddle.incubate.autograd import prim_enabled

diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
index 7aa3fecfc05..9cadbf40b4b 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
@@ -129,24 +129,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         kwargs: inputname_mapping & outputname_mapping
         """
         DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
-        dist_op_context = ctx.dist_op_context
-        src_op = dist_op_context.cur_src_op
-        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
-        main_block = dist_op_context.work_block
-        op = main_block.ops[-1]
-        assert op.type == "fill_constant_batch_size_like"
-
-        # modify shape attr according to how output are partitioned
-        out_name = op.output('Out')[0]
-        dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
-        process_mesh_shape = op_dist_attr.process_mesh.shape
-        shape_list = op.attr("shape")
-        # modify target shape
-        for idx, axis in enumerate(dims_mapping):
-            if axis >= 0:
-                shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
-
-        op._set_attr("shape", shape_list)

     @staticmethod
     def backward(ctx, *args, **kwargs):
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
index 53e2278bd6c..ec9bb1d6396 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
@@ -47,8 +47,23 @@ class DistributedPNorm(DistributedOperatorImplContainer):
 register_distributed_operator_impl_container(DistributedPNorm("p_norm"))


-# Row Parallel
-class DistributedPNormImpl(DistributedOperatorImpl):
+# Data Parallel
+class DistributedPNormImpl0(DistributedOperatorImpl):
+    """
+    TODO: p_norm scenarios
+
+    1. axis == None, isinstance(p, (int, float)), asvector = True
+        1.1 x_dims_mapping == [0, -1, -1]
+            allgather input if it is split by dp group
+        1.2 x_dims_mapping == [-1, 0, -1]
+            allgather, split and concat input if it is split by mp group
+    2. isinstance(axis, int), asvector = False
+        2.1 axis == 0 and x_dims_mapping == [0, -1, -1]
+            allgather input if its dim 0 is split by dp group
+        2.2 axis == 1 and x_dims_mapping == [-1, 0, -1]
+            allgather, split and concat input if its dim 1 is split by mp group
+    """
+
     def __init__(self, name):
         super().__init__(name)
         self._forward_implemented = True
@@ -57,6 +72,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
     def is_input_compatible(self, dist_op):
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
+        axis = op_desc.attr('axis')
+        asvector = op_desc.attr('asvector')
         x_name = op_desc.input('X')[0]
         x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
         if is_dim_replicate(x_dims_mapping[0]):
@@ -65,6 +82,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
         for mapping in x_dims_mapping[1:]:
             if is_dim_shard(mapping):
                 return False
+        if not (axis == -1 and asvector) and not (axis == 0 and not asvector):
+            return False
         return True

     def is_output_compatible(self, dist_op):
@@ -90,6 +109,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
         changed = False
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
+        axis = op_desc.attr('axis')
+        keepdim = op_desc.attr('keepdim')

         batch_dim_mappings = []
         for arg_name in op_desc.input_arg_names():
@@ -115,14 +136,22 @@ class DistributedPNormImpl(DistributedOperatorImpl):
             ):
                 dims_mapping[0] = compatible_dim_mapping
                 changed = True
-        for arg_name in op_desc.output_arg_names():
-            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
-            if (
-                len(dims_mapping) >= 1
-                and compatible_dim_mapping != dims_mapping[0]
-            ):
-                dims_mapping[0] = compatible_dim_mapping
-                changed = True
+
+        if axis == 0 and not keepdim:
+            for arg_name in op_desc.output_arg_names():
+                dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+                if len(dims_mapping) >= 1 and dims_mapping[0] != -1:
+                    dims_mapping[0] = -1
+                    changed = True
+        else:
+            for arg_name in op_desc.output_arg_names():
+                dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+                if (
+                    len(dims_mapping) >= 1
+                    and compatible_dim_mapping != dims_mapping[0]
+                ):
+                    dims_mapping[0] = compatible_dim_mapping
+                    changed = True

         return changed

@@ -350,5 +379,5 @@


 register_distributed_operator_impl(
-    "p_norm", DistributedPNormImpl("row_parallel")
+    "p_norm", DistributedPNormImpl0("data_parallel")
 )
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index 80721b0c7fb..358ccb66857 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -1261,6 +1261,8 @@
                 "fused_softmax_mask_upper_triangle_grad",
                 "flatten_contiguous_range_grad",
                 "relu_grad",
+                "exp_grad",
+                "sigmoid_grad",
             ]
             forward_list = [
                 "reshape2",
@@ -1279,6 +1281,8 @@
                 "fused_softmax_mask_upper_triangle",
                 "flatten_contiguous_range",
                 "relu",
+                "exp",
+                "sigmoid",
             ]
             if op.type in need_set_shape_list:
                 for forward_op in block.ops:
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py
index 969bab0a69f..c2286c0d693 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py
@@ -22,7 +22,7 @@ from paddle.fluid.backward import append_backward
 paddle.enable_static()


-def make_program_dp2():
+def make_program_dp2_axis_None():
     main_program = paddle.fluid.Program()
     start_program = paddle.fluid.Program()
     with paddle.static.program_guard(main_program, start_program):
@@ -35,6 +35,32 @@
     return main_program, start_program, tmp_0


+def make_program_dp2_axis_0():
+    main_program = paddle.fluid.Program()
+    start_program = paddle.fluid.Program()
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
+        x.stop_gradient = False
+        auto.shard_tensor(
+            x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
+        )
+        tmp_0 = paddle.norm(x, p=2, axis=0)
+    return main_program, start_program, tmp_0
+
+
+def make_program_dp2_axis_1():
+    main_program = paddle.fluid.Program()
+    start_program = paddle.fluid.Program()
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
+        x.stop_gradient = False
+        auto.shard_tensor(
+            x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
+        )
+        tmp_0 = paddle.norm(x, p=2, axis=1)
+    return main_program, start_program, tmp_0
+
+
 def make_program_serial():
     main_program = paddle.fluid.Program()
     start_program = paddle.fluid.Program()
@@ -76,47 +102,71 @@ def parallelizer(program_func, rank):


 class TestDistPNorm(unittest.TestCase):
-    def test_dist_pnorm_dp2(self):
-
-        for rank in range(2):
-            dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
-            ops = dist_main_prog.global_block().ops
-            op_types = []
-            for op in ops:
-                op_types.append(op.type)
-                op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
-                if op.type == "p_norm":
-                    assert op_dist_attr.impl_type == "p_norm"
-                if op.type in ["p_norm", "p_norm_grad"]:
-                    for input_attr in op_dist_attr.inputs_dist_attrs.values():
-                        assert set(input_attr.dims_mapping) == set([-1])
-                    for output_attr in op_dist_attr.outputs_dist_attrs.values():
-                        assert set(output_attr.dims_mapping) == set([-1])
-                if op.type == 'c_allgather':
-                    for input_attr in op_dist_attr.inputs_dist_attrs.values():
-                        assert input_attr.dims_mapping[0] == 0
-                        assert set(input_attr.dims_mapping[1:]) == set([-1])
-                    for output_attr in op_dist_attr.outputs_dist_attrs.values():
-                        assert set(output_attr.dims_mapping) == set([-1])
-                if op.type == 'slice':
-                    for input_attr in op_dist_attr.inputs_dist_attrs.values():
-                        assert set(input_attr.dims_mapping) == set([-1])
-                    for output_attr in op_dist_attr.outputs_dist_attrs.values():
-                        assert output_attr.dims_mapping[0] == 0
-                        assert set(output_attr.dims_mapping[1:]) == set([-1])
-            assert op_types == [
-                "c_allgather",
-                "p_norm",
-                "fill_constant",
-                "p_norm_grad",
-                "slice",
-            ]
-
-    def test_dist_pnorm_serial(self):
-        dist_main_prog, dist_context = parallelizer(make_program_serial, 0)
-        ops = dist_main_prog.global_block().ops
-        for op in ops:
-            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
+    def prepare(self, func):
+        self.dist_main_prog, self.dist_context = parallelizer(func, 0)
+        self.ops = self.dist_main_prog.global_block().ops
+
+    def test_dist_pnorm(self):
+        pass
+
+
+class TestDistPNormDP(TestDistPNorm):
+    def test_dist_pnorm(self):
+        self.prepare(make_program_dp2_axis_None)
+        self.check_program()
+
+    def check_program(self):
+        op_types = []
+        for op in self.ops:
+            op_types.append(op.type)
+            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
+            if op.type == "p_norm":
+                assert op_dist_attr.impl_type == "p_norm"
+            if op.type in ["p_norm", "p_norm_grad"]:
+                for input_attr in op_dist_attr.inputs_dist_attrs.values():
+                    assert set(input_attr.dims_mapping) == set([-1])
+                for output_attr in op_dist_attr.outputs_dist_attrs.values():
+                    assert set(output_attr.dims_mapping) == set([-1])
+            if op.type == 'c_allgather':
+                for input_attr in op_dist_attr.inputs_dist_attrs.values():
+                    assert input_attr.dims_mapping[0] == 0
+                    assert set(input_attr.dims_mapping[1:]) == set([-1])
+                for output_attr in op_dist_attr.outputs_dist_attrs.values():
+                    assert set(output_attr.dims_mapping) == set([-1])
+            if op.type == 'slice':
+                for input_attr in op_dist_attr.inputs_dist_attrs.values():
+                    assert set(input_attr.dims_mapping) == set([-1])
+                for output_attr in op_dist_attr.outputs_dist_attrs.values():
+                    assert output_attr.dims_mapping[0] == 0
+                    assert set(output_attr.dims_mapping[1:]) == set([-1])
+        assert op_types == [
+            "c_allgather",
+            "p_norm",
+            "fill_constant",
+            "p_norm_grad",
+            "slice",
+        ]
+
+
+class TestDistPNormDP1(TestDistPNormDP):
+    def test_dist_pnorm(self):
+        self.prepare(make_program_dp2_axis_0)
+        self.check_program()
+
+
+class TestDistPNormSerial(TestDistPNorm):
+    def test_dist_pnorm(self):
+        self.prepare(make_program_serial)
+        for op in self.ops:
+            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
+            assert op_dist_attr.impl_type == "default"
+
+
+class TestDistPNormDPAxis1(TestDistPNorm):
+    def test_dist_pnorm(self):
+        self.prepare(make_program_dp2_axis_1)
+        for op in self.ops:
+            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
             assert op_dist_attr.impl_type == "default"


--
GitLab
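
For illustration only, not part of the patch itself: the shape rewrite that this change moves from dist_fill_constant_batch_size_like.py into DistributedDefaultImpl0.forward comes down to one rule. For every output dimension whose dims_mapping entry names a process-mesh axis (a value >= 0), the serial target shape is divided by the size of that mesh axis, so each rank only materializes its local shard; a value of -1 leaves the dimension untouched. A minimal standalone sketch of that rule in plain Python follows; the helper name and the example values are illustrative, not Paddle APIs.

def shard_target_shape(shape_list, dims_mapping, process_mesh_shape):
    """Divide each sharded axis of a serial target shape by its mesh dimension."""
    assert len(shape_list) == len(dims_mapping)
    sharded = list(shape_list)
    for idx, axis in enumerate(dims_mapping):
        # axis >= 0: this dimension is split across mesh axis `axis`.
        # axis == -1: this dimension is replicated and keeps its full size.
        if axis >= 0:
            sharded[idx] = sharded[idx] // process_mesh_shape[axis]
    return sharded

# A target shape of [4, 5, 6] whose first dimension is split across a
# 2-process mesh becomes a per-rank shape of [2, 5, 6].
print(shard_target_shape([4, 5, 6], [0, -1, -1], [2]))  # [2, 5, 6]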
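
Also for illustration, and not taken from the Paddle implementation: the new DistributedPNormImpl0 docstring describes a data-parallel strategy in which an input whose dim 0 is split by the dp group is allgathered before the norm is computed, which is why TestDistPNormDP expects the op sequence c_allgather, p_norm, fill_constant, p_norm_grad, slice. The NumPy sketch below shows why the allgather is needed: a 2-norm computed on a local dp shard alone does not reproduce the serial result, while the norm of the gathered tensor does.

import numpy as np

np.random.seed(0)
x = np.random.rand(4, 5, 6).astype('float32')  # serial input, same shape as in the unit tests
shards = np.split(x, 2, axis=0)                # dp2: dim 0 split across two ranks

serial_norm = np.linalg.norm(x.reshape(-1), ord=2)

# What a rank would see from its local shard alone: not the serial value.
local_norm = np.linalg.norm(shards[0].reshape(-1), ord=2)

# After a c_allgather along dim 0 every rank holds the full tensor again
# and computes the same norm as the serial program.
gathered = np.concatenate(shards, axis=0)
allgather_norm = np.linalg.norm(gathered.reshape(-1), ord=2)

assert not np.isclose(local_norm, serial_norm)
assert np.isclose(allgather_norm, serial_norm)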