Unverified · Commit bc3ca678 · Author: Baibaifan · Committer: GitHub

Fix sharding group (#39668)

* fix_sharding_group

* fix_sharding_group
Parent e674af23
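Context for the patch below: `paddle.distributed` collectives such as `broadcast` and `reduce` address `src`/`dst` by *global* rank even when an explicit communication `group` is passed, while the sharding code was passing group-local indices (`dst_rank`, `grad_storage.destination`, or a hard-coded `0`). When the sharding group is a strict subgroup of the world, the local index and the global rank diverge, so collectives could target the wrong process. The fix maps every local index through `group.ranks[...]`. Below is a minimal sketch of that mapping; the 8-rank job and the `[4, 5, 6, 7]` subgroup are illustrative assumptions, not part of the patch:

```python
# Sketch: local-to-global rank mapping, as applied by this patch.
# Assumes an 8-rank job launched with `python -m paddle.distributed.launch`
# and a sharding subgroup built from global ranks [4, 5, 6, 7].
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group(ranks=[4, 5, 6, 7])  # sharding subgroup

local_dst = 2                        # group-local index of the shard owner
global_dst = group.ranks[local_dst]  # -> 6, the global rank collectives expect

if dist.get_rank() in group.ranks:
    t = paddle.ones([2, 2])
    # The pre-fix code passed dst=local_dst, i.e. global rank 2, which is
    # not even a member of this subgroup; the mapped value is correct.
    dist.reduce(t, dst=global_dst, group=group)
```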
@@ -98,7 +98,7 @@ class ShardingOptimizerStage2(Optimizer):
         self.world_size = self.group.nranks
         self.rank = self.group.rank
-        self._global_root_rank = 0
+        self._global_root_rank = self.group.ranks[0]

         # Synchronous all ranks models
         if pertrain_sync_models:
@@ -403,7 +403,7 @@ class ShardingOptimizerStage2(Optimizer):
         for dst_rank, internal_storage in dtype_per_rank.items():
             dist.broadcast(
                 tensor=internal_storage.buffer,
-                src=dst_rank,
+                src=self.group.ranks[dst_rank],
                 group=self.group,
                 use_calc_stream=True)
@@ -85,7 +85,8 @@ class ShardingStage2(nn.Layer):
         self._world_size_scaling = 1.0 / self._group.nranks
         assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
         self._rank = self._group.rank
-        self._global_root_rank = 0  # picking rank 0 as the reference
+        self._global_root_rank = self._group.ranks[
+            0]  # picking rank 0 as the reference
         self._default_device = device

         # Global statistical parameters
@@ -319,7 +320,7 @@ class ShardingStage2(nn.Layer):
                     Taskflow(
                         task=dist.reduce(
                             tensor=param.grad,
-                            dst=dst_rank,
+                            dst=self._group.ranks[dst_rank],
                             group=self._group,
                             use_calc_stream=True),
                         callback=cleanup))
@@ -377,7 +378,8 @@ class ShardingStage2(nn.Layer):
                     Taskflow(
                         task=dist.reduce(
                             tensor=grad_storage.buffer,
-                            dst=grad_storage.destination,
+                            dst=self._group.ranks[
+                                grad_storage.destination],
                             group=self._group,
                             use_calc_stream=True),
                         callback=cleanup))
@@ -101,7 +101,8 @@ class ShardingStage3(nn.Layer):
         self._world_size_scaling = 1.0 / self._group.nranks
         assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1."
         self._rank = self._group.rank
-        self._global_root_rank = 0  # picking rank 0 as the reference
+        self._global_root_rank = self._group.ranks[
+            0]  # picking rank 0 as the reference
         self._global_ranks = self._group.ranks

         # Parameter segmentation for global ranks
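The invariant behind all of these changes: for any process inside the group, indexing the group's global-rank list with the local rank recovers the process's global rank. Continuing the sketch above (same assumed `[4, 5, 6, 7]` subgroup):

```python
# On global rank 6, group.rank == 2 (local index), and indexing the
# group's rank list with it recovers the global rank:
assert group.ranks[group.rank] == dist.get_rank()
```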