diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index 5840c16fc019c3e57029bf1cedd8233a2f930763..636b3218c8a0b5635e9a7abc85afcd95fc976955 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -396,7 +396,7 @@ class ShardingPass(PassBase):
 
         dp_ring_ids = [group.id for group in self.dp_groups]
         for idx, op in reversed(list(enumerate(main_block.ops))):
-            if is_data_parallel_reduce_op(op):
+            if _is_param_grad_allreduce_op(op, main_block):
                 input_name = op.input_arg_names[0]
                 base_name = _get_base_name_from_grad_name(input_name)
                 sharding_info = self.varname_to_sharding_info[base_name]
@@ -653,6 +653,20 @@ def _get_base_name_from_grad_name(grad_name):
     return base_name
 
 
+def _is_param_grad_allreduce_op(op, block):
+
+    if not is_data_parallel_reduce_op(op):
+        return False
+
+    output_name = op.output_arg_names[0]
+    base_name = _get_base_name_from_grad_name(output_name)
+
+    if not block.has_var(base_name):
+        return False
+
+    return block.var(base_name).is_parameter
+
+
 def _is_param_grad_sum_op(op, block):
 
     if not is_backward_op(op):
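
For readers skimming the diff: the new helper narrows the condition used when the pass walks `main_block.ops` — an allreduce is rewritten only if the gradient it produces maps back to a variable that exists in the block and is marked as a parameter. Below is a minimal standalone sketch of that filter; `FakeVar`, `FakeBlock`, and the simplified grad-name stripping are hypothetical stand-ins rather than Paddle APIs, and only the `has_var` / `var(...)` / `is_parameter` shape mirrors the helper added in this diff.

```python
# Hypothetical sketch of the filtering logic behind _is_param_grad_allreduce_op.
# FakeVar/FakeBlock imitate only the small slice of Block API the helper uses
# (has_var, var, and a variable's is_parameter flag); they are not Paddle types.


def _strip_grad_suffix(grad_name):
    # Simplified stand-in for _get_base_name_from_grad_name: drop "@GRAD..." suffix.
    idx = grad_name.find("@GRAD")
    return grad_name[:idx] if idx != -1 else grad_name


class FakeVar:
    def __init__(self, is_parameter):
        self.is_parameter = is_parameter


class FakeBlock:
    def __init__(self, variables):
        self._vars = variables

    def has_var(self, name):
        return name in self._vars

    def var(self, name):
        return self._vars[name]


def is_param_grad_allreduce(grad_name, block):
    # Same shape as the new helper, minus the is_data_parallel_reduce_op check:
    # keep only allreduces whose gradient belongs to an actual parameter.
    base_name = _strip_grad_suffix(grad_name)
    if not block.has_var(base_name):
        return False
    return block.var(base_name).is_parameter


if __name__ == "__main__":
    block = FakeBlock(
        {
            "linear_0.w_0": FakeVar(is_parameter=True),  # model weight
            "tmp_3": FakeVar(is_parameter=False),        # intermediate tensor
        }
    )
    print(is_param_grad_allreduce("linear_0.w_0@GRAD", block))  # True  -> handled by the pass
    print(is_param_grad_allreduce("tmp_3@GRAD", block))         # False -> left untouched
    print(is_param_grad_allreduce("unknown@GRAD", block))       # False -> not in the block
```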