From 81e06752e32e1990731a8fa5e7a25d1474b80d4a Mon Sep 17 00:00:00 2001
From: Charles-hit <56987902+Charles-hit@users.noreply.github.com>
Date: Tue, 13 Sep 2022 20:12:08 +0800
Subject: [PATCH] fix distributed bug caused by fill_any_like (#45978)

---
 python/paddle/distributed/auto_parallel/completion.py    | 4 ++--
 .../distributed/auto_parallel/operators/dist_default.py  | 4 ++--
 python/paddle/distributed/auto_parallel/utils.py         | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index 1775a823c5..8712288204 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -1037,7 +1037,7 @@ class Completer:
                         grad_op_dist_attr.set_output_dims_mapping(
                             output_name, ref_fwd_dims_mapping)
 
-                elif grad_op.type == 'fill_zeros_like':
+                elif grad_op.type == 'fill_any_like':
                     ref_var_name = grad_op.input_arg_names[0]
                     ref_var = vars[ref_var_name]
                     ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
@@ -1274,7 +1274,7 @@ class Completer:
                         grad_op_dist_attr.impl_type = "default"
                         grad_op_dist_attr.impl_idx = 0
 
-                elif grad_op.type == 'fill_zeros_like':
+                elif grad_op.type == 'fill_any_like':
                     ref_var_name = grad_op.input_arg_names[0]
                     ref_var = vars[ref_var_name]
                     ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index 08c81c4a30..a5139e0018 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -348,7 +348,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
             else:
                 batch_dim_mappings.append(dims_mapping[1])
         for arg_name in op_desc.output_arg_names():
-            if op_desc.type() == "fill_zeros_like":
+            if op_desc.type() == "fill_any_like":
                 input_tensor = dist_op.get_serial_input(
                     op_desc.input_arg_names()[0])
                 if input_tensor.is_parameter:
@@ -387,7 +387,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                     dims_mapping[1] = compatible_dim_mapping
                     changed = True
         for arg_name in op_desc.output_arg_names():
-            if op_desc.type() == "fill_zeros_like":
+            if op_desc.type() == "fill_any_like":
                 input_tensor = dist_op.get_serial_input(
                     op_desc.input_arg_names()[0])
                 if input_tensor.is_parameter:
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index d276df6ddb..cfe2161286 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -1040,7 +1040,7 @@ def set_grad_var_shape(program, dist_context):
 
             if op.type in [
                     "c_allreduce_sum", "c_identity", "scale", "cast",
-                    "fill_zeros_like"
+                    "fill_any_like"
             ]:
                 forward_var_name = op.input_arg_names[0]
             elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
-- 
GitLab
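
Note (not part of the patch): the reason for the rename is that newer Paddle versions emit the fill_any_like op (a full_like with value 0) as the zero-gradient filler during backward construction, in place of the older fill_zeros_like, so auto-parallel code that dispatches on the op-type string must match the new name. The following is a minimal sketch, assuming a Paddle build with the static-graph API, of how one could confirm which fill op a program actually records; the program and variable names are illustrative only.

import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    # zeros_like lowers to a fill-style op in the static program
    y = paddle.zeros_like(x)

# Print the op types recorded in the global block; on versions where
# this patch applies, the filler appears as fill_any_like rather than
# fill_zeros_like.
for op in main.global_block().ops:
    print(op.type)

If the probe prints fill_zeros_like instead, the installed Paddle predates the renaming and the string comparisons fixed by this patch would still be correct there.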