diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index c109c861a5e579288fc0f51a5da3c578d8e989aa..02a8c17247534732a175a5ab5472799152cc67a3 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -980,7 +980,7 @@ class Completer:
             # Copy the corresponding distributed attribute from graph to serial_main_program
             self._dist_context.copy_dist_attr_from_graph_to_program()
         else:
-            self._logger.info("Default data parallel will be set.")
+            self._logger.info("Default distributed attributed will be set.")
             self._dist_context.initialize(with_graph=False)
             # A fast and special completion for data parallel
             self._update_dist_attr_for_dp()
@@ -1050,6 +1050,17 @@ class Completer:
             for arg_name in serial_op.output_arg_names:
                 op_dist_attr = dist_op.dist_attr
                 serial_tensor = dist_op.get_serial_output(arg_name)
+                if serial_op.type in ["fill_constant"]:
+                    old_dims_mapping = op_dist_attr.get_output_dims_mapping(
+                        arg_name
+                    )
+                    if len(old_dims_mapping) > 0:
+                        new_dims_mapping = [0] + [
+                            -1 for _ in range(len(old_dims_mapping) - 1)
+                        ]
+                        op_dist_attr.set_output_dims_mapping(
+                            arg_name, new_dims_mapping
+                        )
                 dist_tensor = self._dist_context.get_dist_tensor_for_program(
                     serial_tensor
                 )
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_slice.py b/python/paddle/distributed/auto_parallel/operators/dist_slice.py
index 76ce8240473135d60bc099c01ff3b9ab120fd9a1..b0d31f12dace9965191f6fb92ba8c84393ee0c50 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_slice.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_slice.py
@@ -39,10 +39,16 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
         in_name = op_desc.input('Input')[0]
+        out_name = op_desc.output('Out')[0]
+        in_var = dist_op.serial_op.block.var(in_name)
+        out_var = dist_op.serial_op.block.var(out_name)
         axes = op_desc.attr('axes')
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
         for axis in axes:
-            if is_dim_shard(in_dims_mapping[axis]):
+            if (
+                is_dim_shard(in_dims_mapping[axis])
+                and in_var.shape[axis] != out_var.shape[axis]
+            ):
                 return False

         return True
@@ -51,6 +57,8 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         op_dist_attr = dist_op.dist_attr
         in_name = op_desc.input('Input')[0]
         out_name = op_desc.output('Out')[0]
+        in_var = dist_op.serial_op.block.var(in_name)
+        out_var = dist_op.serial_op.block.var(out_name)
         axes = op_desc.attr('axes')
         decrease_axis = op_desc.attr('decrease_axis')
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
@@ -67,7 +75,11 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         else:
             for i in range(len(out_dims_mapping)):
                 ref_index = ref_indices[i]
-                if ref_index in axes and is_dim_shard(out_dims_mapping[i]):
+                if (
+                    ref_index in axes
+                    and is_dim_shard(out_dims_mapping[i])
+                    and in_var.shape[ref_index] != out_var.shape[ref_index]
+                ):
                     return False

         return True
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
index 7e4d5eaee90dd90e046da922a20e2b0cfba0f7f8..089951d66bdce2a92d3fa51413ec8be569d0ed3c 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
@@ -32,6 +32,7 @@ def make_program_dp2():
         tmp_1 = x[:, 0, :]
         tmp_2 = x[:, :, 1]
         tmp_3 = x[:2, :2, :2]
+        tmp_3 = x[:4, :2, :2]
     return main_program, start_program
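
A minimal standalone sketch of the two rules this patch encodes, kept separate from the diff above. The helper names are illustrative only and are not Paddle APIs: fill_constant outputs default to a data-parallel mapping that shards just the first dimension, and a sharded slice axis is rejected only when the slice actually changes that axis's length.

def default_dp_dims_mapping(ndim):
    # Mirrors the mapping set for fill_constant outputs above:
    # shard the first dimension on mesh dim 0 and replicate the rest.
    if ndim == 0:
        return []
    return [0] + [-1] * (ndim - 1)


def slice_axis_is_compatible(axis, dims_mapping, in_shape, out_shape):
    # A sharded axis blocks the distributed slice only when the slice
    # shrinks that axis, i.e. in_shape[axis] != out_shape[axis].
    sharded = dims_mapping[axis] >= 0
    return not (sharded and in_shape[axis] != out_shape[axis])


assert default_dp_dims_mapping(3) == [0, -1, -1]
# x has shape (4, 5, 6) and is sharded on axis 0: x[:4, :2, :2] leaves axis 0
# untouched, so the shard remains valid, while x[:2, :2, :2] does not.
assert slice_axis_is_compatible(0, [0, -1, -1], (4, 5, 6), (4, 2, 2))
assert not slice_axis_is_compatible(0, [0, -1, -1], (4, 5, 6), (2, 2, 2))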