Unverified commit 3a14857b, authored by zhaoyingli, committed by GitHub

fix dp completion (#47804)

Parent 831db343
@@ -980,7 +980,7 @@ class Completer:
             # Copy the corresponding distributed attribute from graph to serial_main_program
             self._dist_context.copy_dist_attr_from_graph_to_program()
         else:
-            self._logger.info("Default data parallel will be set.")
+            self._logger.info("Default distributed attribute will be set.")
             self._dist_context.initialize(with_graph=False)
             # A fast and special completion for data parallel
             self._update_dist_attr_for_dp()
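For context, the dims_mapping convention manipulated throughout this diff maps each tensor dimension either to a process-mesh dimension (0, 1, ...) or to -1, meaning "replicated". Below is a minimal runnable sketch of the data-parallel default that a fast completion pass like `_update_dist_attr_for_dp` applies; the helper body is a simplified stand-in for Paddle's `is_dim_shard`, not the library's actual code:

    def is_dim_shard(mapping):
        # Simplified stand-in: a dimension is sharded iff it is mapped
        # to some process-mesh dimension rather than to -1.
        return mapping != -1

    # Default data-parallel annotation for a rank-3 tensor: shard the
    # batch dimension (dim 0) on mesh dim 0, replicate everything else.
    dp_dims_mapping = [0, -1, -1]
    assert is_dim_shard(dp_dims_mapping[0])
    assert not any(is_dim_shard(m) for m in dp_dims_mapping[1:])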
@@ -1050,6 +1050,17 @@ class Completer:
             for arg_name in serial_op.output_arg_names:
                 op_dist_attr = dist_op.dist_attr
                 serial_tensor = dist_op.get_serial_output(arg_name)
+                if serial_op.type in ["fill_constant"]:
+                    old_dims_mapping = op_dist_attr.get_output_dims_mapping(
+                        arg_name
+                    )
+                    if len(old_dims_mapping) > 0:
+                        new_dims_mapping = [0] + [
+                            -1 for _ in range(len(old_dims_mapping) - 1)
+                        ]
+                        op_dist_attr.set_output_dims_mapping(
+                            arg_name, new_dims_mapping
+                        )
                 dist_tensor = self._dist_context.get_dist_tensor_for_program(
                     serial_tensor
                 )
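The hunk above special-cases `fill_constant` during the fast data-parallel completion: if the output has at least one dimension, its dims_mapping is rewritten so that only dim 0 is sharded (on mesh dim 0). A runnable sketch of just that mapping rule, lifted out of the `Completer` machinery:

    def dp_fill_constant_mapping(old_dims_mapping):
        # Mirrors the added branch: leave 0-D outputs untouched, otherwise
        # shard dim 0 on mesh dim 0 and replicate the remaining dims.
        if len(old_dims_mapping) == 0:
            return old_dims_mapping
        return [0] + [-1 for _ in range(len(old_dims_mapping) - 1)]

    assert dp_fill_constant_mapping([-1, -1, -1]) == [0, -1, -1]
    assert dp_fill_constant_mapping([]) == []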
@@ -39,10 +39,16 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
         in_name = op_desc.input('Input')[0]
+        out_name = op_desc.output('Out')[0]
+        in_var = dist_op.serial_op.block.var(in_name)
+        out_var = dist_op.serial_op.block.var(out_name)
         axes = op_desc.attr('axes')
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
         for axis in axes:
-            if is_dim_shard(in_dims_mapping[axis]):
+            if (
+                is_dim_shard(in_dims_mapping[axis])
+                and in_var.shape[axis] != out_var.shape[axis]
+            ):
                 return False

         return True
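This change relaxes `is_input_compatible`: a slice axis whose input dimension is sharded no longer disqualifies the implementation outright, only a slice that actually changes that axis's length does. A self-contained illustration of the relaxed rule; the helper and the concrete shapes are assumptions for illustration, not the real `DistributedOperator` machinery:

    def slice_input_compatible(in_shape, out_shape, axes, in_dims_mapping):
        # A sharded slice axis is rejected only if the slice changes
        # the length of that axis (i.e. it actually drops elements).
        for axis in axes:
            if in_dims_mapping[axis] != -1 and in_shape[axis] != out_shape[axis]:
                return False
        return True

    # x[:4, :2, :] on a [4, 5, 6] input sharded along dim 0: the sharded
    # axis keeps its full length, so DP sharding may be preserved
    # (previously this case was rejected unconditionally).
    assert slice_input_compatible([4, 5, 6], [4, 2, 6], [0, 1], [0, -1, -1])
    # x[:2, ...] genuinely shortens the sharded axis: still incompatible.
    assert not slice_input_compatible([4, 5, 6], [2, 5, 6], [0], [0, -1, -1])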
@@ -51,6 +57,8 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         op_dist_attr = dist_op.dist_attr
         in_name = op_desc.input('Input')[0]
         out_name = op_desc.output('Out')[0]
+        in_var = dist_op.serial_op.block.var(in_name)
+        out_var = dist_op.serial_op.block.var(out_name)
         axes = op_desc.attr('axes')
         decrease_axis = op_desc.attr('decrease_axis')
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
@@ -67,7 +75,11 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         else:
             for i in range(len(out_dims_mapping)):
                 ref_index = ref_indices[i]
-                if ref_index in axes and is_dim_shard(out_dims_mapping[i]):
+                if (
+                    ref_index in axes
+                    and is_dim_shard(out_dims_mapping[i])
+                    and in_var.shape[ref_index] != out_var.shape[ref_index]
+                ):
                     return False

         return True
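The mirrored change on the output side applies the same exemption through `ref_indices`, which map each output dimension back to its input dimension when `decrease_axis` has removed dimensions. A simplified, self-contained sketch of that mapping; the indexing here is an approximation of the real code, not a copy of it:

    def slice_output_compatible(in_shape, out_shape, axes, decrease_axis,
                                out_dims_mapping):
        # Output dim i corresponds to input dim ref_indices[i] once the
        # decreased (squeezed-out) dimensions are skipped.
        ref_indices = [i for i in range(len(in_shape)) if i not in decrease_axis]
        for i, mapping in enumerate(out_dims_mapping):
            ref_index = ref_indices[i]
            if (ref_index in axes and mapping != -1
                    and in_shape[ref_index] != out_shape[i]):
                return False
        return True

    # x[:, 0, :] with decrease_axis=[1]: output dim 0 refers back to input
    # dim 0, which the slice leaves untouched, so its sharding can survive.
    assert slice_output_compatible([4, 5, 6], [4, 6], [1], [1], [0, -1])
    # Slicing the sharded dim 0 down to 2 rows is still rejected.
    assert not slice_output_compatible([4, 5, 6], [2, 5, 6], [0], [], [0, -1, -1])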
@@ -32,6 +32,7 @@ def make_program_dp2():
         tmp_1 = x[:, 0, :]
         tmp_2 = x[:, :, 1]
         tmp_3 = x[:2, :2, :2]
+        tmp_3 = x[:4, :2, :2]
     return main_program, start_program
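The added test line exercises exactly the new exemption: assuming `x` in this test is the data-parallel-sharded input with dim 0 of length 4 (the concrete shape is an assumption here), `x[:4, :2, :2]` slices the sharded axis without shortening it, whereas the pre-existing `x[:2, :2, :2]` does shorten it. A tiny self-contained check of that distinction:

    def keeps_dp_shard(in_shape, out_shape, sharded_axis=0):
        # Under the relaxed rule, DP sharding on `sharded_axis` survives
        # a slice iff that axis keeps its full length.
        return in_shape[sharded_axis] == out_shape[sharded_axis]

    assert keeps_dp_shard([4, 5, 6], [4, 2, 2])      # tmp_3 = x[:4, :2, :2]
    assert not keeps_dp_shard([4, 5, 6], [2, 2, 2])  # tmp_3 = x[:2, :2, :2]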