From 6319dd830f5bfb1ab57a0584176ac83132f6b20a Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Tue, 31 May 2022 14:37:05 +0800
Subject: [PATCH] fix bugs (#43115)

---
 .../fleet/utils/hybrid_parallel_util.py       | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index d0b5c915e11..5e2ad43c164 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -140,17 +140,12 @@ def broadcast_dp_parameters(model, hcg):
 
 
 def fused_allreduce_gradients(parameter_list, hcg):
-    if _in_legacy_dygraph():
-        data_parallel_group = None if hcg is None else hcg.get_data_parallel_group(
-        )
-        logger.debug("dp start fuse allreduce gradients")
-        with framework.no_grad():
-            _apply_collective_grads(parameter_list, data_parallel_group)
-    elif in_dygraph_mode():
-        assert hcg is None, "It's not support to use hcg in EagerDygraph now."
-        data_parallel_group = paddle.distributed.collective._get_default_group()
-        with framework.no_grad():
-            _apply_collective_grads_eager(parameter_list, data_parallel_group)
+    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
+    logger.debug("dp start fuse allreduce gradients")
+    apply_func = _apply_collective_grads_eager if in_dygraph_mode(
+    ) else _apply_collective_grads
+    with framework.no_grad():
+        apply_func(parameter_list, data_parallel_group)
 
 
 def sharding_reduce_gradients(parameter_list, hcg):
-- 
GitLab
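
Usage note (not part of the patch): the sketch below shows how fused_allreduce_gradients is commonly invoked after this change, with gradient synchronization deferred via no_sync and then triggered manually across the data-parallel group. The model, optimizer, and tensor shapes are illustrative placeholders; the script is assumed to be launched on multiple processes via paddle.distributed.launch; and the fallback to the default communication group when hcg is None is inferred from the patched code path rather than stated in the patch.

import paddle
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
    fused_allreduce_gradients,
)

# Assumed to run under `python -m paddle.distributed.launch ...` so that
# a default communication group exists.
paddle.distributed.init_parallel_env()

model = paddle.DataParallel(paddle.nn.Linear(16, 16))  # placeholder model
opt = paddle.optimizer.SGD(parameters=model.parameters())

x = paddle.randn([4, 16])  # placeholder batch
with model.no_sync():  # skip the automatic per-backward allreduce
    loss = model(x).mean()
    loss.backward()

# Fuse and allreduce the accumulated gradients in one shot. Passing
# hcg=None lets the helper fall back to the default data-parallel
# group (inferred from the patched code, not guaranteed by the patch text).
fused_allreduce_gradients(list(model.parameters()), None)
opt.step()
opt.clear_grad()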