未验证 提交 95768115 编写于 作者: Y Yuang Liu 提交者: GitHub

Multi groups for broadcast of sharding stage 2 (#46894)

上级 a9cc5482
...@@ -184,7 +184,10 @@ class GroupShardedOptimizerStage2(Optimizer): ...@@ -184,7 +184,10 @@ class GroupShardedOptimizerStage2(Optimizer):
# Enable gradients' reduces overlap with backward calculation. # Enable gradients' reduces overlap with backward calculation.
self._reduce_overlap = reduce_overlap self._reduce_overlap = reduce_overlap
def _set_broadcast_overlap(self, broadcast_overlap, layers=None): def _set_broadcast_overlap(self,
broadcast_overlap,
layers=None,
num_groups=None):
# Enable post optimizer broadcasts overlap with the forward calculation of next batch. # Enable post optimizer broadcasts overlap with the forward calculation of next batch.
self._broadcast_overlap = broadcast_overlap self._broadcast_overlap = broadcast_overlap
if self._broadcast_overlap: if self._broadcast_overlap:
...@@ -202,6 +205,27 @@ class GroupShardedOptimizerStage2(Optimizer): ...@@ -202,6 +205,27 @@ class GroupShardedOptimizerStage2(Optimizer):
"overlap broadcast may harm the performance.") "overlap broadcast may harm the performance.")
self._broadcast_order_params = self._local_params self._broadcast_order_params = self._local_params
if num_groups is None or num_groups > len(self._broadcast_order_params):
warnings.warn(
"The num_groups for broadcast is larger than the number of params to be broadcast. "
"It will set to default value: 1 (use the default sharding group)."
)
num_groups = 1
assert isinstance(
num_groups,
int) and num_groups > 0, "num_groups should be a positive integer"
self._number_of_broadcast_groups = num_groups
self._broadcast_groups = [
None for _ in range(self._number_of_broadcast_groups)
]
self._broadcast_groups[0] = self._group
ranks = self._group.ranks
for i in range(1, self._number_of_broadcast_groups):
self._broadcast_groups[i] = new_group(ranks)
def _generate_master_params(self, trainable_params): def _generate_master_params(self, trainable_params):
if self.offload: if self.offload:
for param in trainable_params: for param in trainable_params:
...@@ -484,13 +508,16 @@ class GroupShardedOptimizerStage2(Optimizer): ...@@ -484,13 +508,16 @@ class GroupShardedOptimizerStage2(Optimizer):
def _broadcast_params_overlap_forward(self): def _broadcast_params_overlap_forward(self):
# Exchange all the shards with the other ranks, # Exchange all the shards with the other ranks,
# but overlap the broadcast with next batch's calculation. # but overlap the broadcast with next batch's calculation.
group_idx = 0
param2task = {} param2task = {}
for x in self._broadcast_order_params: for x in self._broadcast_order_params:
if x.trainable: if x.trainable:
task = broadcast( group = self._broadcast_groups[group_idx]
tensor=x, group_idx = (group_idx + 1) % self._number_of_broadcast_groups
src=self._group.ranks[self._param2rank[x.name]], task = broadcast(tensor=x,
group=self._group, src=group.ranks[self._param2rank[x.name]],
group=group,
sync_op=False) sync_op=False)
assert x.name not in param2task assert x.name not in param2task
param2task[x.name] = task param2task[x.name] = task
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册