From 3a14857b68a66fe2c3cfe58fd774694d5af089e1 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Thu, 10 Nov 2022 15:58:04 +0800
Subject: [PATCH] fix dp completion (#47804)

---
 .../distributed/auto_parallel/completion.py     | 13 ++++++++++++-
 .../auto_parallel/operators/dist_slice.py       | 16 ++++++++++++++--
 .../unittests/auto_parallel/test_dist_slice.py  |  1 +
 3 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index c109c861a5..02a8c17247 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -980,7 +980,7 @@ class Completer:
             # Copy the corresponding distributed attribute from graph to serial_main_program
             self._dist_context.copy_dist_attr_from_graph_to_program()
         else:
-            self._logger.info("Default data parallel will be set.")
+            self._logger.info("Default distributed attributed will be set.")
             self._dist_context.initialize(with_graph=False)
             # A fast and special completion for data parallel
             self._update_dist_attr_for_dp()
@@ -1050,6 +1050,17 @@
             for arg_name in serial_op.output_arg_names:
                 op_dist_attr = dist_op.dist_attr
                 serial_tensor = dist_op.get_serial_output(arg_name)
+                if serial_op.type in ["fill_constant"]:
+                    old_dims_mapping = op_dist_attr.get_output_dims_mapping(
+                        arg_name
+                    )
+                    if len(old_dims_mapping) > 0:
+                        new_dims_mapping = [0] + [
+                            -1 for _ in range(len(old_dims_mapping) - 1)
+                        ]
+                        op_dist_attr.set_output_dims_mapping(
+                            arg_name, new_dims_mapping
+                        )
                 dist_tensor = self._dist_context.get_dist_tensor_for_program(
                     serial_tensor
                 )
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_slice.py b/python/paddle/distributed/auto_parallel/operators/dist_slice.py
index 76ce824047..b0d31f12da 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_slice.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_slice.py
@@ -39,10 +39,16 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
         in_name = op_desc.input('Input')[0]
+        out_name = op_desc.output('Out')[0]
+        in_var = dist_op.serial_op.block.var(in_name)
+        out_var = dist_op.serial_op.block.var(out_name)
         axes = op_desc.attr('axes')
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
         for axis in axes:
-            if is_dim_shard(in_dims_mapping[axis]):
+            if (
+                is_dim_shard(in_dims_mapping[axis])
+                and in_var.shape[axis] != out_var.shape[axis]
+            ):
                 return False

         return True
@@ -51,6 +57,8 @@
         op_dist_attr = dist_op.dist_attr
         in_name = op_desc.input('Input')[0]
         out_name = op_desc.output('Out')[0]
+        in_var = dist_op.serial_op.block.var(in_name)
+        out_var = dist_op.serial_op.block.var(out_name)
         axes = op_desc.attr('axes')
         decrease_axis = op_desc.attr('decrease_axis')
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
@@ -67,7 +75,11 @@
         else:
             for i in range(len(out_dims_mapping)):
                 ref_index = ref_indices[i]
-                if ref_index in axes and is_dim_shard(out_dims_mapping[i]):
+                if (
+                    ref_index in axes
+                    and is_dim_shard(out_dims_mapping[i])
+                    and in_var.shape[ref_index] != out_var.shape[ref_index]
+                ):
                     return False

         return True
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
index 7e4d5eaee9..089951d66b 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
@@ -32,6 +32,7 @@ def make_program_dp2():
         tmp_1 = x[:, 0, :]
         tmp_2 = x[:, :, 1]
         tmp_3 = x[:2, :2, :2]
+        tmp_3 = x[:4, :2, :2]
     return main_program, start_program

--
GitLab
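
Reviewer note, not part of the patch: a minimal plain-Python sketch of the two rules the diff encodes, for readers who want to trace the logic without loading Paddle. The helper names and the example shapes are hypothetical and are not Paddle auto-parallel APIs; the sketch only mirrors the arithmetic of the changed lines.

# Illustrative sketch only; names below are made up, not Paddle APIs.

def default_dp_dims_mapping(ndim):
    """Shard the first axis on mesh dim 0, replicate the rest,
    mirroring the new fill_constant handling in _update_dist_attr_for_dp."""
    if ndim == 0:
        return []
    return [0] + [-1] * (ndim - 1)


def slice_shard_compatible(in_shape, out_shape, axes, dims_mapping):
    """A sliced axis may stay sharded as long as the slice keeps that
    axis's extent unchanged (the relaxed check in dist_slice.py)."""
    for axis in axes:
        is_shard = dims_mapping[axis] >= 0  # -1 means replicated
        if is_shard and in_shape[axis] != out_shape[axis]:
            return False
    return True


if __name__ == "__main__":
    # A rank-2 fill_constant output gets its batch axis sharded, rest replicated.
    print(default_dp_dims_mapping(2))  # [0, -1]

    # With a hypothetical [4, 6, 8] input sharded on axis 0:
    # x[:4, :2, :2] keeps axis 0 at extent 4, so the shard stays compatible;
    # x[:2, :2, :2] shrinks axis 0, so it does not.
    print(slice_shard_compatible([4, 6, 8], [4, 2, 2], [0, 1, 2], [0, -1, -1]))  # True
    print(slice_shard_compatible([4, 6, 8], [2, 2, 2], [0, 1, 2], [0, -1, -1]))  # False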