From bc3ca6786668c732c451d1742e1aee08fac3cb1f Mon Sep 17 00:00:00 2001
From: Baibaifan <39549453+Baibaifan@users.noreply.github.com>
Date: Fri, 18 Feb 2022 11:34:36 +0800
Subject: [PATCH] Fix sharding group (#39668)

* fix_sharding_group

* fix_sharding_group
---
 .../dygraph_optimizer/sharding_optimizer_stage2.py       | 4 ++--
 .../fleet/meta_parallel/sharding/sharding_stage2.py      | 8 +++++---
 .../fleet/meta_parallel/sharding/sharding_stage3.py      | 3 ++-
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
index 08baeae89ad..112c3887fcf 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
@@ -98,7 +98,7 @@ class ShardingOptimizerStage2(Optimizer):
         self.world_size = self.group.nranks
         self.rank = self.group.rank
 
-        self._global_root_rank = 0
+        self._global_root_rank = self.group.ranks[0]
 
         # Synchronous all ranks models
         if pertrain_sync_models:
@@ -403,7 +403,7 @@ class ShardingOptimizerStage2(Optimizer):
         for dst_rank, internal_storage in dtype_per_rank.items():
             dist.broadcast(
                 tensor=internal_storage.buffer,
-                src=dst_rank,
+                src=self.group.ranks[dst_rank],
                 group=self.group,
                 use_calc_stream=True)
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
index e654f88f0b7..392a7f3ac5d 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
@@ -85,7 +85,8 @@ class ShardingStage2(nn.Layer):
         self._world_size_scaling = 1.0 / self._group.nranks
         assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
         self._rank = self._group.rank
-        self._global_root_rank = 0  # picking rank 0 as the reference
+        self._global_root_rank = self._group.ranks[
+            0]  # picking rank 0 as the reference
         self._default_device = device
 
         # Global statistical parameters
@@ -319,7 +320,7 @@ class ShardingStage2(nn.Layer):
                         Taskflow(
                             task=dist.reduce(
                                 tensor=param.grad,
-                                dst=dst_rank,
+                                dst=self._group.ranks[dst_rank],
                                 group=self._group,
                                 use_calc_stream=True),
                             callback=cleanup))
@@ -377,7 +378,8 @@ class ShardingStage2(nn.Layer):
                         Taskflow(
                             task=dist.reduce(
                                 tensor=grad_storage.buffer,
-                                dst=grad_storage.destination,
+                                dst=self._group.ranks[
+                                    grad_storage.destination],
                                 group=self._group,
                                 use_calc_stream=True),
                             callback=cleanup))
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 9f9811b9eb0..de69836fdba 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -101,7 +101,8 @@ class ShardingStage3(nn.Layer):
         self._world_size_scaling = 1.0 / self._group.nranks
         assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1."
         self._rank = self._group.rank
-        self._global_root_rank = 0  # picking rank 0 as the reference
+        self._global_root_rank = self._group.ranks[
+            0]  # picking rank 0 as the reference
         self._global_ranks = self._group.ranks
 
         # Parameter segmentation for global ranks
-- 
GitLab
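
Note on the change: when the sharding communication group is only a sub-group of the global world (for example, sharding combined with other parallelism), the shard index of a rank (0 .. group.nranks - 1) is not the same value as the global rank id that collectives such as `dist.reduce` and `dist.broadcast` expect, so the patch maps indices through `group.ranks` instead of using them directly. The sketch below illustrates only that index-to-global-rank mapping; it is a minimal stand-alone example, `FakeGroup` and the rank numbers are made-up stand-ins and do not use the real `paddle.distributed` API.

```python
# Minimal sketch of the shard-index -> global-rank mapping used by the patch.
# "FakeGroup" and the rank values are illustrative assumptions, not Paddle code.


class FakeGroup:
    """Stand-in for a paddle.distributed communication group."""

    def __init__(self, ranks):
        self.ranks = ranks          # global rank ids that belong to this group
        self.nranks = len(ranks)    # group size


def shard_index_to_global_rank(group, shard_index):
    """Translate a shard index (0 .. nranks-1) into the global rank id that
    collective ops such as dist.reduce / dist.broadcast expect."""
    return group.ranks[shard_index]


if __name__ == "__main__":
    # Suppose sharding runs on a sub-group made of global ranks 4..7.
    group = FakeGroup(ranks=[4, 5, 6, 7])

    # Before the patch: passing the shard index directly would target global
    # rank 2, which is not even a member of this group.
    dst_shard_index = 2
    print("unmapped dst:", dst_shard_index)                                     # 2
    # After the patch: the index is mapped to the correct global rank, 6.
    print("mapped dst:  ", shard_index_to_global_rank(group, dst_shard_index))  # 6

    # The same mapping replaces the hard-coded root rank 0 with the group's
    # first member.
    print("root rank:   ", group.ranks[0])                                      # 4
```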