Unverified commit 3a14857b authored by zhaoyingli, committed by GitHub

fix dp completion (#47804)

Parent 831db343
@@ -980,7 +980,7 @@ class Completer:
             # Copy the corresponding distributed attribute from graph to serial_main_program
             self._dist_context.copy_dist_attr_from_graph_to_program()
         else:
-            self._logger.info("Default data parallel will be set.")
+            self._logger.info("Default distributed attributed will be set.")
             self._dist_context.initialize(with_graph=False)
             # A fast and special completion for data parallel
             self._update_dist_attr_for_dp()
@@ -1050,6 +1050,17 @@ class Completer:
             for arg_name in serial_op.output_arg_names:
                 op_dist_attr = dist_op.dist_attr
                 serial_tensor = dist_op.get_serial_output(arg_name)
+                if serial_op.type in ["fill_constant"]:
+                    old_dims_mapping = op_dist_attr.get_output_dims_mapping(
+                        arg_name
+                    )
+                    if len(old_dims_mapping) > 0:
+                        new_dims_mapping = [0] + [
+                            -1 for _ in range(len(old_dims_mapping) - 1)
+                        ]
+                        op_dist_attr.set_output_dims_mapping(
+                            arg_name, new_dims_mapping
+                        )
                 dist_tensor = self._dist_context.get_dist_tensor_for_program(
                     serial_tensor
                 )
...
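For context, the `fill_constant` branch added above rewrites the op's output dims mapping so that dim 0 is sharded on mesh dim 0 (the data-parallel dimension) and every remaining dim is replicated (`-1`). A minimal standalone sketch of that computation; the helper name is mine, not part of the patch:

```python
def dp_output_dims_mapping(old_dims_mapping):
    # Shard dim 0 on mesh dim 0 and replicate (-1) every other dim,
    # mirroring what the fill_constant branch computes.
    if len(old_dims_mapping) == 0:
        return old_dims_mapping
    return [0] + [-1 for _ in range(len(old_dims_mapping) - 1)]

assert dp_output_dims_mapping([-1, -1, -1]) == [0, -1, -1]
assert dp_output_dims_mapping([]) == []  # 0-D outputs are left untouched
```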
@@ -39,10 +39,16 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
         in_name = op_desc.input('Input')[0]
+        out_name = op_desc.output('Out')[0]
+        in_var = dist_op.serial_op.block.var(in_name)
+        out_var = dist_op.serial_op.block.var(out_name)
         axes = op_desc.attr('axes')
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
         for axis in axes:
-            if is_dim_shard(in_dims_mapping[axis]):
+            if (
+                is_dim_shard(in_dims_mapping[axis])
+                and in_var.shape[axis] != out_var.shape[axis]
+            ):
                 return False
         return True
@@ -51,6 +57,8 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         op_dist_attr = dist_op.dist_attr
         in_name = op_desc.input('Input')[0]
         out_name = op_desc.output('Out')[0]
+        in_var = dist_op.serial_op.block.var(in_name)
+        out_var = dist_op.serial_op.block.var(out_name)
         axes = op_desc.attr('axes')
         decrease_axis = op_desc.attr('decrease_axis')
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
@@ -67,7 +75,11 @@ class DistributedSliceImpl(DistributedOperatorImpl):
         else:
             for i in range(len(out_dims_mapping)):
                 ref_index = ref_indices[i]
-                if ref_index in axes and is_dim_shard(out_dims_mapping[i]):
+                if (
+                    ref_index in axes
+                    and is_dim_shard(out_dims_mapping[i])
+                    and in_var.shape[ref_index] != out_var.shape[ref_index]
+                ):
                     return False
         return True
...
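Both relaxed checks encode the same rule: a sliced axis may stay sharded as long as the slice takes it at full extent, i.e. the input and output sizes on that axis are equal. A self-contained sketch under that reading; the function names and concrete shapes are illustrative, not the patch's API:

```python
def is_dim_shard(mapping):
    # In auto-parallel dims mappings, -1 means replicated and any
    # mesh-dim index >= 0 means sharded on that mesh dimension.
    return mapping != -1

def slice_keeps_sharding(in_dims_mapping, in_shape, out_shape, axes):
    # A sliced axis may stay sharded only if the slice keeps it at
    # full extent (same size before and after the slice).
    for axis in axes:
        if is_dim_shard(in_dims_mapping[axis]) and in_shape[axis] != out_shape[axis]:
            return False
    return True

# Suppose x has shape [4, 5, 6] and is sharded on dim 0 over 2 ranks.
# x[:4, :2, :2] keeps dim 0 whole, so the sharding can be preserved:
assert slice_keeps_sharding([0, -1, -1], [4, 5, 6], [4, 2, 2], axes=[0, 1, 2])
# x[:2, :2, :2] cuts the sharded dim, so it is still incompatible:
assert not slice_keeps_sharding([0, -1, -1], [4, 5, 6], [2, 2, 2], axes=[0, 1, 2])
```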
@@ -32,6 +32,7 @@ def make_program_dp2():
         tmp_1 = x[:, 0, :]
         tmp_2 = x[:, :, 1]
         tmp_3 = x[:2, :2, :2]
+        tmp_3 = x[:4, :2, :2]
     return main_program, start_program
...
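The added `tmp_3 = x[:4, :2, :2]` exercises the newly accepted case: assuming dim 0 of `x` has extent 4 and is sharded over the 2-rank dp2 mesh, this slice takes the sharded axis in full, while the existing `x[:2, :2, :2]` keeps covering the rejected case. A rough illustration of why the full-extent condition matters, assuming an even block sharding:

```python
# dp2 with dim 0 of size 4 sharded over 2 ranks: each rank holds a
# local block of 2 rows (rank 0: rows 0-1, rank 1: rows 2-3).
def local_rows(rank, global_size=4, num_ranks=2):
    block = global_size // num_ranks  # assumes even block sharding
    return list(range(rank * block, (rank + 1) * block))

# x[:4] touches every row, so each rank can serve its own block.
assert local_rows(0) + local_rows(1) == [0, 1, 2, 3]
# x[:2] only needs rows 0-1, which all live on rank 0, so keeping dim 0
# sharded would leave rank 1 without any useful local data.
assert local_rows(0) == [0, 1]
```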