diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py index e905a4c1fc5fced1eec62c0cebb6d8aa2e623dc7..f9221f4bb7621a659a405ec352df17eb2d287133 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py @@ -184,7 +184,10 @@ class GroupShardedOptimizerStage2(Optimizer): # Enable gradients' reduces overlap with backward calculation. self._reduce_overlap = reduce_overlap - def _set_broadcast_overlap(self, broadcast_overlap, layers=None): + def _set_broadcast_overlap(self, + broadcast_overlap, + layers=None, + num_groups=None): # Enable post optimizer broadcasts overlap with the forward calculation of next batch. self._broadcast_overlap = broadcast_overlap if self._broadcast_overlap: @@ -202,6 +205,27 @@ class GroupShardedOptimizerStage2(Optimizer): "overlap broadcast may harm the performance.") self._broadcast_order_params = self._local_params + if num_groups is None or num_groups > len(self._broadcast_order_params): + warnings.warn( + "The num_groups for broadcast is larger than the number of params to be broadcast. " + "It will set to default value: 1 (use the default sharding group)." + ) + num_groups = 1 + + assert isinstance( + num_groups, + int) and num_groups > 0, "num_groups should be a positive integer" + + self._number_of_broadcast_groups = num_groups + self._broadcast_groups = [ + None for _ in range(self._number_of_broadcast_groups) + ] + self._broadcast_groups[0] = self._group + + ranks = self._group.ranks + for i in range(1, self._number_of_broadcast_groups): + self._broadcast_groups[i] = new_group(ranks) + def _generate_master_params(self, trainable_params): if self.offload: for param in trainable_params: @@ -484,14 +508,17 @@ class GroupShardedOptimizerStage2(Optimizer): def _broadcast_params_overlap_forward(self): # Exchange all the shards with the other ranks, # but overlap the broadcast with next batch's calculation. + group_idx = 0 + param2task = {} for x in self._broadcast_order_params: if x.trainable: - task = broadcast( - tensor=x, - src=self._group.ranks[self._param2rank[x.name]], - group=self._group, - sync_op=False) + group = self._broadcast_groups[group_idx] + group_idx = (group_idx + 1) % self._number_of_broadcast_groups + task = broadcast(tensor=x, + src=group.ranks[self._param2rank[x.name]], + group=group, + sync_op=False) assert x.name not in param2task param2task[x.name] = task