From e65cdaeed31c2df6d1d3a919b91cc00efb51b698 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Wed, 28 Sep 2022 20:40:00 +0800
Subject: [PATCH] [AutoParallel] fix sharding (#46572)

---
 .../distributed/passes/auto_parallel_sharding.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index 5840c16fc01..636b3218c8a 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -396,7 +396,7 @@ class ShardingPass(PassBase):
         dp_ring_ids = [group.id for group in self.dp_groups]
 
         for idx, op in reversed(list(enumerate(main_block.ops))):
-            if is_data_parallel_reduce_op(op):
+            if _is_param_grad_allreduce_op(op, main_block):
                 input_name = op.input_arg_names[0]
                 base_name = _get_base_name_from_grad_name(input_name)
                 sharding_info = self.varname_to_sharding_info[base_name]
@@ -653,6 +653,20 @@ def _get_base_name_from_grad_name(grad_name):
     return base_name
 
 
+def _is_param_grad_allreduce_op(op, block):
+
+    if not is_data_parallel_reduce_op(op):
+        return False
+
+    output_name = op.output_arg_names[0]
+    base_name = _get_base_name_from_grad_name(output_name)
+
+    if not block.has_var(base_name):
+        return False
+
+    return block.var(base_name).is_parameter
+
+
 def _is_param_grad_sum_op(op, block):
 
     if not is_backward_op(op):
-- 
GitLab
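
Note: the patch narrows the condition in ShardingPass from "any data-parallel reduce op" to "a reduce op whose output gradient maps back to a parameter in the block", so allreduces over non-parameter gradients are no longer re-scoped to sharding groups. Below is a minimal, self-contained sketch of that predicate for illustration only; the FakeOp/FakeBlock/FakeVar stubs, the simplified is_data_parallel_reduce_op op-type list, and the "@GRAD" suffix handling are assumptions standing in for Paddle's real Op/Block API, not the framework's actual implementation.

# sketch_param_grad_allreduce_check.py
from dataclasses import dataclass, field


@dataclass
class FakeVar:
    # Stand-in for a framework variable; only the fields the check needs.
    name: str
    is_parameter: bool = False


@dataclass
class FakeOp:
    # Stand-in for an operator with a type and output argument names.
    type: str
    output_arg_names: list


@dataclass
class FakeBlock:
    # Stand-in for a program block that owns named variables.
    vars: dict = field(default_factory=dict)

    def has_var(self, name):
        return name in self.vars

    def var(self, name):
        return self.vars[name]


def _get_base_name_from_grad_name(grad_name):
    # Strip the gradient suffix, e.g. "linear_0.w_0@GRAD" -> "linear_0.w_0".
    return grad_name.split("@GRAD")[0]


def is_data_parallel_reduce_op(op):
    # Simplified stand-in: treat these op types as data-parallel reduces.
    return op.type in ("c_allreduce_sum", "c_reduce_sum")


def _is_param_grad_allreduce_op(op, block):
    # Same structure as the helper added in the patch: a reduce op only
    # qualifies if its reduced gradient belongs to a parameter in the block.
    if not is_data_parallel_reduce_op(op):
        return False
    output_name = op.output_arg_names[0]
    base_name = _get_base_name_from_grad_name(output_name)
    if not block.has_var(base_name):
        return False
    return block.var(base_name).is_parameter


if __name__ == "__main__":
    block = FakeBlock(vars={"linear_0.w_0": FakeVar("linear_0.w_0", is_parameter=True)})
    param_grad_ar = FakeOp("c_allreduce_sum", ["linear_0.w_0@GRAD"])
    other_ar = FakeOp("c_allreduce_sum", ["loss@GRAD"])
    print(_is_param_grad_allreduce_op(param_grad_ar, block))  # True: rewired to its sharding group
    print(_is_param_grad_allreduce_op(other_ar, block))       # False: left untouched by the pass

In this sketch, only the allreduce over the parameter gradient passes the check, which mirrors why the pass after this fix skips reduce ops whose outputs do not correspond to block parameters.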