From bbdc168331edde0db6da676fb33e6bc01c79a7ef Mon Sep 17 00:00:00 2001
From: pangengzheng <117730991+pangengzheng@users.noreply.github.com>
Date: Tue, 1 Aug 2023 10:15:11 +0800
Subject: [PATCH] fix cuda mem in sharding parallel (#55653)

---
 python/paddle/distributed/fleet/base/topology.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index 0272fdd086d..4f5e3566148 100644
--- a/python/paddle/distributed/fleet/base/topology.py
+++ b/python/paddle/distributed/fleet/base/topology.py
@@ -187,10 +187,11 @@ class HybridCommunicateGroup:
             "data"
         )
 
-        (
-            self.sharding_check_group,
-            self.sharding_check_comm_group,
-        ) = self._set_check_group("sharding")
+        if self._sharding_degree > 1:
+            (
+                self.sharding_check_group,
+                self.sharding_check_comm_group,
+            ) = self._set_check_group("sharding")
 
         # create p2p group
         self.is_first_stage = self.stage_id == 0
-- 
GitLab