diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
index 08baeae89ad4a7fb4e8e7067b18cc457074b980a..112c3887fcfa5d383d61a0dd32f6a0a73e5aea92 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
@@ -98,7 +98,7 @@ class ShardingOptimizerStage2(Optimizer):
         self.world_size = self.group.nranks
         self.rank = self.group.rank
-        self._global_root_rank = 0
+        self._global_root_rank = self.group.ranks[0]
 
         # Synchronous all ranks models
         if pertrain_sync_models:
@@ -403,7 +403,7 @@ class ShardingOptimizerStage2(Optimizer):
             for dst_rank, internal_storage in dtype_per_rank.items():
                 dist.broadcast(
                     tensor=internal_storage.buffer,
-                    src=dst_rank,
+                    src=self.group.ranks[dst_rank],
                     group=self.group,
                     use_calc_stream=True)
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
index e654f88f0b7b8c2a6473cfc534f4e09fde3a7632..392a7f3ac5d8fe01ec2b5fdf0b36030d79124be4 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
@@ -85,7 +85,8 @@ class ShardingStage2(nn.Layer):
         self._world_size_scaling = 1.0 / self._group.nranks
         assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
         self._rank = self._group.rank
-        self._global_root_rank = 0  # picking rank 0 as the reference
+        self._global_root_rank = self._group.ranks[
+            0]  # picking rank 0 as the reference
         self._default_device = device
 
         # Global statistical parameters
@@ -319,7 +320,7 @@ class ShardingStage2(nn.Layer):
                         Taskflow(
                             task=dist.reduce(
                                 tensor=param.grad,
-                                dst=dst_rank,
+                                dst=self._group.ranks[dst_rank],
                                 group=self._group,
                                 use_calc_stream=True),
                             callback=cleanup))
@@ -377,7 +378,8 @@ class ShardingStage2(nn.Layer):
                         Taskflow(
                             task=dist.reduce(
                                 tensor=grad_storage.buffer,
-                                dst=grad_storage.destination,
+                                dst=self._group.ranks[
+                                    grad_storage.destination],
                                 group=self._group,
                                 use_calc_stream=True),
                             callback=cleanup))
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 9f9811b9eb0fcf07866d21f70bb99040f2179c8c..de69836fdba14da14c179c5e00fa4cc14481e534 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -101,7 +101,8 @@ class ShardingStage3(nn.Layer):
         self._world_size_scaling = 1.0 / self._group.nranks
         assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1."
         self._rank = self._group.rank
-        self._global_root_rank = 0  # picking rank 0 as the reference
+        self._global_root_rank = self._group.ranks[
+            0]  # picking rank 0 as the reference
         self._global_ranks = self._group.ranks

         # Parameter segmentation for global ranks
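
The common thread across these hunks: the sharding code tracks ranks relative to its communication group (shard index `dst_rank`, root index 0), while the ranks handed to `dist.broadcast(src=...)` and `dist.reduce(dst=...)` must identify processes globally, so hard-coding `0` or passing the shard index directly only works when the group happens to start at global rank 0. The diff maps every such index through `group.ranks[...]` before handing it to the collective. Below is a minimal sketch of that mapping under those assumptions; `_GroupSketch` and `to_global_rank` are hypothetical names for illustration, not Paddle's actual `Group` API.

```python
# Minimal sketch of the group-relative -> global rank mapping the diff applies.
# "_GroupSketch" and "to_global_rank" are hypothetical, not Paddle's Group API.
class _GroupSketch:
    def __init__(self, ranks):
        self.ranks = list(ranks)   # global ranks that belong to this group
        self.nranks = len(ranks)


def to_global_rank(group, local_rank):
    """Map a group-relative rank index to the global rank collectives expect."""
    return group.ranks[local_rank]


# A sharding group built from global ranks 4..7: its root is global rank 4,
# not 0, and shard index 2 lives on global rank 6.
group = _GroupSketch([4, 5, 6, 7])
assert to_global_rank(group, 0) == 4
assert to_global_rank(group, 2) == 6
```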