Unverified · Commit 02146ba5 authored by JZ-LIANG, committed by GitHub

[Auto parallel] align infer accuracy for ernie generator mode (#40077)

* [Auto Parallel] Support the auto completion of while_op
* Align infer accuracy
Parent f027b2ad
@@ -65,7 +65,8 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         if (not self.is_input_compatible(dist_op)) or \
             (not self.is_output_compatible(dist_op)):
             return False
-
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
         out_name = op_desc.output('Out')[0]
         out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
         in_name = op_desc.input('Input')[0]
@@ -78,7 +79,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         changed = False
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
-        x_name = op_desc.input('X')[0]
+        x_name = op_desc.input('Input')[0]
         out_name = op_desc.output('Out')[0]
         x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
         out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
...
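For context, the two hunks above make the fill_constant_batch_size_like rule read the tensor behind the op's 'Input' slot (its real input name) rather than a non-existent 'X' slot. The sketch below is not part of the commit and assumes the fluid-era static-graph API that was current at the time; it only prints the op's slot names to show where 'Input' and 'Out' come from.

```python
# Hypothetical illustration, not from the commit: inspect the slot names of a
# fill_constant_batch_size_like op in a static-graph program.
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[-1, 16], dtype='float32')
    out = fluid.layers.fill_constant_batch_size_like(
        input=x, shape=[1, 16], dtype='float32', value=0.0)

op = main_prog.global_block().ops[-1]
assert op.type == 'fill_constant_batch_size_like'
print(op.input('Input'))  # ['x']  -- the input slot is named 'Input', not 'X'
print(op.output('Out'))   # a single auto-generated temp-var name
```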
@@ -1837,6 +1837,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
                                                  out_var_dist_attr)

         intermediate_var_0 = main_block.create_var(
+            name=unique_name.generate_with_ignorable_key(".".join(
+                ["c_allreduce_sum", 'tmp'])),
             shape=Out_var.shape,
             dtype=Out_var.dtype,
             type=Out_var.type,
@@ -1936,7 +1938,6 @@ class DistributedMulImpl2(DistributedOperatorImpl):
         if is_valid_list_index(x_dims_mapping,
                                -2) and is_dim_shard(x_dims_mapping[-2]):
             return False
-
         if is_dim_shard(y_dims_mapping[-1]):
             return False
         if is_valid_list_index(y_dims_mapping,
...
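The added name= argument above gives the intermediate allreduce output an explicit, collision-free name via Paddle's unique_name helper instead of relying on an auto-generated default. A small sketch, not part of the commit, of what that helper returns:

```python
# Illustration only: generate_with_ignorable_key yields a fresh name on every
# call, so two temporary vars created from the same key never collide.
from paddle.fluid import unique_name

key = ".".join(["c_allreduce_sum", "tmp"])
n1 = unique_name.generate_with_ignorable_key(key)
n2 = unique_name.generate_with_ignorable_key(key)

print(n1, n2)   # two distinct names derived from the same key
assert n1 != n2
```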
@@ -329,13 +329,7 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
         belong_to_optimizer=src_var.belong_to_optimizer,
         **copied_kwargs)

-    # set dist attr uid
-    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
-    # param.desc.set_distributed_attr_uid(distributed_attr_uid)
-    dist_attr = copy.deepcopy(
-        dist_context.get_tensor_dist_attr_for_program(src_var))
-    assert dist_attr is not None
-    dist_context.set_tensor_dist_attr_for_program(param, dist_attr)
+    return param


 def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
@@ -352,13 +346,7 @@ def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
         is_data=src_var.is_data,
         belong_to_optimizer=src_var.belong_to_optimizer)

-    # set dist attr uid
-    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
-    # var.desc.set_distributed_attr_uid(distributed_attr_uid)
-    dist_attr = copy.deepcopy(
-        dist_context.get_tensor_dist_attr_for_program(src_var))
-    assert dist_attr is not None
-    dist_context.set_tensor_dist_attr_for_program(var, dist_attr)
+    return var


 def _partition_var(dist_context, src_block, dst_block, src_varname,
@@ -369,7 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
     src_var = src_block.var(src_varname)

     if src_var.type in __not_shape_var_type__:
-        dst_block.create_var(
+        new_var = dst_block.create_var(
             type=src_var.type,
             name=dst_varname,
             persistable=True,
@@ -380,11 +368,17 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
         target_shape = _get_dist_shape(src_var, dist_attr)

         if isinstance(src_var, Parameter):
-            _partition_parameter(dist_context, src_var, dst_block, dst_varname,
-                                 target_shape)
+            new_var = _partition_parameter(dist_context, src_var, dst_block,
+                                           dst_varname, target_shape)
         else:
-            _partition_intermediate_var(dist_context, src_var, dst_block,
-                                        dst_varname, target_shape)
+            new_var = _partition_intermediate_var(
+                dist_context, src_var, dst_block, dst_varname, target_shape)
+
+    dist_attr = copy.deepcopy(
+        dist_context.get_tensor_dist_attr_for_program(src_var))
+    assert dist_attr is not None
+    dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)

     return target_shape
...
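Taken together, the partitioner hunks make _partition_parameter and _partition_intermediate_var return the var they create, so _partition_var can attach the tensor dist attr in one place for every branch. The deepcopy keeps the attr attached to the partitioned var independent of the one recorded for the serial program; a toy, plain-Python sketch (not Paddle code) of that point:

```python
# Toy illustration of the deepcopy: mutating the copy attached to the
# partitioned var leaves the serial program's attr untouched.
import copy

serial_attr = {"dims_mapping": [0, -1], "process_mesh": [0, 1]}

partitioned_attr = copy.deepcopy(serial_attr)
partitioned_attr["dims_mapping"][0] = -1   # adjust only the copy

assert serial_attr["dims_mapping"] == [0, -1]
print(serial_attr, partitioned_attr)
```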