Unverified · Commit 02146ba5 authored by JZ-LIANG, committed by GitHub

[Auto parallel] align infer accuracy for ernie generator mode (#40077)

* [Auto Parallel] Support the auto completion of while_op
* align infer accuracy
Parent f027b2ad
@@ -65,7 +65,8 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
in_name = op_desc.input('Input')[0]
@@ -78,7 +79,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
x_name = op_desc.input('Input')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
......
@@ -1837,6 +1837,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
out_var_dist_attr)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_allreduce_sum", 'tmp'])),
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
@@ -1936,7 +1938,6 @@ class DistributedMulImpl2(DistributedOperatorImpl):
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
......
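For orientation, a hedged sketch of the pattern the DistributedMulImpl1 excerpt above belongs to: the rank-local partial product is written into a temporary variable that mirrors the real output, and a c_allreduce_sum later accumulates the partials into Out_var. Only the create_var arguments visible above are taken from the excerpt; the helper name and the persistable/stop_gradient settings are illustrative assumptions.

from paddle.fluid import unique_name


def create_partial_out_holder(main_block, Out_var):
    # Temporary variable that receives the rank-local partial result; it
    # mirrors Out_var's shape/dtype/type so downstream shape inference sees
    # the same dimensions as the final output.
    return main_block.create_var(
        name=unique_name.generate_with_ignorable_key(".".join(
            ["c_allreduce_sum", 'tmp'])),
        shape=Out_var.shape,
        dtype=Out_var.dtype,
        type=Out_var.type,
        persistable=False,
        stop_gradient=Out_var.stop_gradient)
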
@@ -329,13 +329,7 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs)
# set dist attr uid
# distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
# param.desc.set_distributed_attr_uid(distributed_attr_uid)
dist_attr = copy.deepcopy(
dist_context.get_tensor_dist_attr_for_program(src_var))
assert dist_attr is not None
dist_context.set_tensor_dist_attr_for_program(param, dist_attr)
return param
def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
@@ -352,13 +346,7 @@ def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer)
# set dist attr uid
# distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
# var.desc.set_distributed_attr_uid(distributed_attr_uid)
dist_attr = copy.deepcopy(
dist_context.get_tensor_dist_attr_for_program(src_var))
assert dist_attr is not None
dist_context.set_tensor_dist_attr_for_program(var, dist_attr)
return var
def _partition_var(dist_context, src_block, dst_block, src_varname,
@@ -369,7 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
src_var = src_block.var(src_varname)
if src_var.type in __not_shape_var_type__:
dst_block.create_var(
new_var = dst_block.create_var(
type=src_var.type,
name=dst_varname,
persistable=True,
@@ -380,11 +368,17 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
target_shape = _get_dist_shape(src_var, dist_attr)
if isinstance(src_var, Parameter):
_partition_parameter(dist_context, src_var, dst_block, dst_varname,
target_shape)
new_var = _partition_parameter(dist_context, src_var, dst_block,
dst_varname, target_shape)
else:
_partition_intermediate_var(dist_context, src_var, dst_block,
dst_varname, target_shape)
new_var = _partition_intermediate_var(
dist_context, src_var, dst_block, dst_varname, target_shape)
dist_attr = copy.deepcopy(
dist_context.get_tensor_dist_attr_for_program(src_var))
assert dist_attr is not None
dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)
return target_shape
......
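To make the fragmented partitioner hunks above easier to follow, here is a hedged reconstruction of how _partition_var reads once the per-type helpers simply return the variable they create and the dist attr is attached in a single place. It is assembled only from the lines visible above; the target_shape = None assignment, the stop_gradient=True argument, and the placement of the final dist-attr block at function scope are assumptions rather than the file's verbatim contents, and Parameter, __not_shape_var_type__, _get_dist_shape and the two helpers come from the surrounding module.

import copy


def _partition_var(dist_context, src_block, dst_block, src_varname,
                   dst_varname):
    # Sketch: create the destination variable, then copy the source tensor's
    # dist attr onto it exactly once before returning.
    src_var = src_block.var(src_varname)

    if src_var.type in __not_shape_var_type__:
        # Variable types without a concrete shape are only re-declared.
        new_var = dst_block.create_var(
            type=src_var.type,
            name=dst_varname,
            persistable=True,
            stop_gradient=True)
        target_shape = None
    else:
        dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
        target_shape = _get_dist_shape(src_var, dist_attr)

        if isinstance(src_var, Parameter):
            new_var = _partition_parameter(dist_context, src_var, dst_block,
                                           dst_varname, target_shape)
        else:
            new_var = _partition_intermediate_var(
                dist_context, src_var, dst_block, dst_varname, target_shape)

    # Single point where the distributed attribute is copied over, replacing
    # the per-helper copies removed above.
    dist_attr = copy.deepcopy(
        dist_context.get_tensor_dist_attr_for_program(src_var))
    assert dist_attr is not None
    dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)

    return target_shape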