未验证 提交 9012e8bc 编写于 作者: C Charles-hit 提交者: GitHub

fix distributed bug caused by fill_any_like (#45978) (#46041)

上级 526e0323
......@@ -1037,7 +1037,7 @@ class Completer:
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping)
elif grad_op.type == 'fill_zeros_like':
elif grad_op.type == 'fill_any_like':
ref_var_name = grad_op.input_arg_names[0]
ref_var = vars[ref_var_name]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
......@@ -1274,7 +1274,7 @@ class Completer:
grad_op_dist_attr.impl_type = "default"
grad_op_dist_attr.impl_idx = 0
elif grad_op.type == 'fill_zeros_like':
elif grad_op.type == 'fill_any_like':
ref_var_name = grad_op.input_arg_names[0]
ref_var = vars[ref_var_name]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
......
......@@ -348,7 +348,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
else:
batch_dim_mappings.append(dims_mapping[1])
for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like":
if op_desc.type() == 'fill_any_like':
input_tensor = dist_op.get_serial_input(
op_desc.input_arg_names()[0])
if input_tensor.is_parameter:
......@@ -387,7 +387,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
dims_mapping[1] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like":
if op_desc.type() == 'fill_any_like':
input_tensor = dist_op.get_serial_input(
op_desc.input_arg_names()[0])
if input_tensor.is_parameter:
......
......@@ -1040,7 +1040,7 @@ def set_grad_var_shape(program, dist_context):
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast",
"fill_zeros_like"
'fill_any_like'
]:
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册