未验证 提交 81e06752 编写于 作者: C Charles-hit 提交者: GitHub

fix distributed bug caused by fill_any_like (#45978)

上级 b2122239
...@@ -1037,7 +1037,7 @@ class Completer: ...@@ -1037,7 +1037,7 @@ class Completer:
grad_op_dist_attr.set_output_dims_mapping( grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping) output_name, ref_fwd_dims_mapping)
elif grad_op.type == 'fill_zeros_like': elif grad_op.type == 'fill_any_like':
ref_var_name = grad_op.input_arg_names[0] ref_var_name = grad_op.input_arg_names[0]
ref_var = vars[ref_var_name] ref_var = vars[ref_var_name]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
...@@ -1274,7 +1274,7 @@ class Completer: ...@@ -1274,7 +1274,7 @@ class Completer:
grad_op_dist_attr.impl_type = "default" grad_op_dist_attr.impl_type = "default"
grad_op_dist_attr.impl_idx = 0 grad_op_dist_attr.impl_idx = 0
elif grad_op.type == 'fill_zeros_like': elif grad_op.type == 'fill_any_like':
ref_var_name = grad_op.input_arg_names[0] ref_var_name = grad_op.input_arg_names[0]
ref_var = vars[ref_var_name] ref_var = vars[ref_var_name]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
......
...@@ -348,7 +348,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -348,7 +348,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
else: else:
batch_dim_mappings.append(dims_mapping[1]) batch_dim_mappings.append(dims_mapping[1])
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like": if op_desc.type() == 'fill_any_like':
input_tensor = dist_op.get_serial_input( input_tensor = dist_op.get_serial_input(
op_desc.input_arg_names()[0]) op_desc.input_arg_names()[0])
if input_tensor.is_parameter: if input_tensor.is_parameter:
...@@ -387,7 +387,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -387,7 +387,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
dims_mapping[1] = compatible_dim_mapping dims_mapping[1] = compatible_dim_mapping
changed = True changed = True
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like": if op_desc.type() == 'fill_any_like':
input_tensor = dist_op.get_serial_input( input_tensor = dist_op.get_serial_input(
op_desc.input_arg_names()[0]) op_desc.input_arg_names()[0])
if input_tensor.is_parameter: if input_tensor.is_parameter:
......
...@@ -1040,7 +1040,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1040,7 +1040,7 @@ def set_grad_var_shape(program, dist_context):
if op.type in [ if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast", "c_allreduce_sum", "c_identity", "scale", "cast",
"fill_zeros_like" 'fill_any_like'
]: ]:
forward_var_name = op.input_arg_names[0] forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad": elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册