From 02146ba529d58a274e2b252ab8f5514a12b4efcf Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Fri, 25 Mar 2022 17:46:06 +0800
Subject: [PATCH] [Auto parallel] align infer accuracy for ernie generator mode (#40077)

* [Auto Parallel] Support the auto completion of while_op

* align infer accuracy
---
 .../dist_fill_constant_batch_size_like.py    |  5 +--
 .../auto_parallel/operators/dist_matmul.py   |  3 +-
 .../distributed/auto_parallel/partitioner.py | 32 ++++++++-----------
 3 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
index e71ece47abf..80ac019e830 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
@@ -65,7 +65,8 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         if (not self.is_input_compatible(dist_op)) or \
             (not self.is_output_compatible(dist_op)):
             return False
-
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
         out_name = op_desc.output('Out')[0]
         out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
         in_name = op_desc.input('Input')[0]
@@ -78,7 +79,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         changed = False
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
-        x_name = op_desc.input('X')[0]
+        x_name = op_desc.input('Input')[0]
         out_name = op_desc.output('Out')[0]
         x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
         out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
index 6f526a4d323..68167c1c4f7 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -1837,6 +1837,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
                                              out_var_dist_attr)
 
         intermediate_var_0 = main_block.create_var(
+            name=unique_name.generate_with_ignorable_key(".".join(
+                ["c_allreduce_sum", 'tmp'])),
             shape=Out_var.shape,
             dtype=Out_var.dtype,
             type=Out_var.type,
@@ -1936,7 +1938,6 @@ class DistributedMulImpl2(DistributedOperatorImpl):
         if is_valid_list_index(x_dims_mapping,
                                -2) and is_dim_shard(x_dims_mapping[-2]):
             return False
-
         if is_dim_shard(y_dims_mapping[-1]):
             return False
         if is_valid_list_index(y_dims_mapping,
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index ed5ec85d84f..e27c859c4e2 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -329,13 +329,7 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
         belong_to_optimizer=src_var.belong_to_optimizer,
         **copied_kwargs)
 
-    # set dist attr uid
-    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
-    # param.desc.set_distributed_attr_uid(distributed_attr_uid)
-    dist_attr = copy.deepcopy(
-        dist_context.get_tensor_dist_attr_for_program(src_var))
-    assert dist_attr is not None
-    dist_context.set_tensor_dist_attr_for_program(param, dist_attr)
+    return param
 
 
 def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
@@ -352,13 +346,7 @@ def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
         is_data=src_var.is_data,
         belong_to_optimizer=src_var.belong_to_optimizer)
 
-    # set dist attr uid
-    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
-    # var.desc.set_distributed_attr_uid(distributed_attr_uid)
-    dist_attr = copy.deepcopy(
-        dist_context.get_tensor_dist_attr_for_program(src_var))
-    assert dist_attr is not None
-    dist_context.set_tensor_dist_attr_for_program(var, dist_attr)
+    return var
 
 
 def _partition_var(dist_context, src_block, dst_block, src_varname,
@@ -369,7 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
     src_var = src_block.var(src_varname)
 
     if src_var.type in __not_shape_var_type__:
-        dst_block.create_var(
+        new_var = dst_block.create_var(
             type=src_var.type,
             name=dst_varname,
             persistable=True,
@@ -380,11 +368,17 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
         target_shape = _get_dist_shape(src_var, dist_attr)
 
         if isinstance(src_var, Parameter):
-            _partition_parameter(dist_context, src_var, dst_block, dst_varname,
-                                 target_shape)
+            new_var = _partition_parameter(dist_context, src_var, dst_block,
+                                           dst_varname, target_shape)
         else:
-            _partition_intermediate_var(dist_context, src_var, dst_block,
-                                        dst_varname, target_shape)
+            new_var = _partition_intermediate_var(
+                dist_context, src_var, dst_block, dst_varname, target_shape)
+
+    dist_attr = copy.deepcopy(
+        dist_context.get_tensor_dist_attr_for_program(src_var))
+    assert dist_attr is not None
+    dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)
+
     return target_shape
 
 
-- 
GitLab
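Illustrative note (not part of the upstream patch): the partitioner hunks above change _partition_parameter and _partition_intermediate_var so that they only create and return the new variable, while _partition_var copies the source tensor's dist attr and attaches it to the returned variable in one place. Below is a minimal, self-contained sketch of that "create in the helper, annotate once in the caller" pattern. DistContext, make_param, make_intermediate and partition_var are hypothetical stand-ins invented for this sketch, not the Paddle API.

import copy

# Hypothetical stand-in for the real dist_context bookkeeping (illustration only).
class DistContext:
    def __init__(self):
        self._attrs = {}

    def get_tensor_dist_attr(self, name):
        return self._attrs.get(name)

    def set_tensor_dist_attr(self, name, attr):
        self._attrs[name] = attr


def make_param(name, shape):
    # Mirrors _partition_parameter after the patch: create and return, no dist-attr work.
    return {"name": name, "shape": shape, "kind": "param"}


def make_intermediate(name, shape):
    # Mirrors _partition_intermediate_var after the patch: create and return.
    return {"name": name, "shape": shape, "kind": "var"}


def partition_var(dist_ctx, src_name, dst_name, is_param, shape):
    # Each branch only creates the destination variable ...
    if is_param:
        new_var = make_param(dst_name, shape)
    else:
        new_var = make_intermediate(dst_name, shape)

    # ... and the dist attr is copied from the source and attached exactly once, here.
    dist_attr = copy.deepcopy(dist_ctx.get_tensor_dist_attr(src_name))
    assert dist_attr is not None
    dist_ctx.set_tensor_dist_attr(new_var["name"], dist_attr)
    return new_var


if __name__ == "__main__":
    ctx = DistContext()
    ctx.set_tensor_dist_attr("w", {"dims_mapping": [0, -1]})
    v = partition_var(ctx, "w", "w@RANK0", is_param=True, shape=[512, 512])
    print(v, ctx.get_tensor_dist_attr("w@RANK0"))

Keeping the annotation in the caller means the Parameter and intermediate branches cannot diverge in how they tag the new variable, so the partitioned program's tensor dist attrs stay consistent regardless of variable kind.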