From 818132a0ac7f634ca286c8d8eac0073790e4a302 Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Thu, 3 Nov 2022 16:34:13 +0800
Subject: [PATCH] support sharding stage2 + mp hybrid_parallel. (#47535)

* support sharding stage2 + mp hybrid_parallel.

* fix the group of check_nan_inf.

* update hcg.
---
 .../group_sharded_optimizer_stage2.py  | 20 +++++++++++++++++---
 .../sharding/group_sharded_utils.py    | 11 ++++++++++-
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 2668b58089..0414798e68 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -31,6 +31,11 @@ import paddle
 from paddle.fluid import core
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
+from paddle.distributed import fleet, ParallelMode
+
+HybridParallelClipGrad = (
+    fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer.HybridParallelClipGrad
+)
 from paddle.distributed.collective import (
     _get_global_group,
     broadcast,
@@ -157,9 +162,18 @@ class GroupShardedOptimizerStage2(Optimizer):
                 "While using ClipGradByGlobalNorm in GroupShardedOptimizerStage2, the grad clip of original optimizer will be changed."
             )
 
-            self._optim._grad_clip = GroupShardedClipGrad(
-                self._optim._grad_clip, paddle.get_device(), self._group
-            )
+            hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+            if (
+                hcg
+                and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+            ):
+                self._optim._grad_clip = HybridParallelClipGrad(
+                    self._optim._grad_clip, hcg
+                )
+            else:
+                self._optim._grad_clip = GroupShardedClipGrad(
+                    self._optim._grad_clip, paddle.get_device(), self._group
+                )
         if self._optim._parameter_list and isinstance(
             self._optim._parameter_list[0], dict
         ):
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 4107bd83db..6832e9a7ca 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -23,6 +23,7 @@ from paddle.fluid import core
 from paddle.fluid import layers
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import dygraph_only
+from paddle.distributed import fleet, ParallelMode
 
 
 class Taskflow:
@@ -244,10 +245,18 @@ def GroupShardedScaler(scaler):
 
         self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
         is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
+        hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+        hybrid_parallel = (
+            hcg is not None
+            and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+        )
+
         paddle.distributed.all_reduce(
             is_found_inf,
             op=paddle.distributed.ReduceOp.MAX,
-            group=optimizer._group,
+            group=hcg.get_check_parallel_group()
+            if hybrid_parallel
+            else optimizer._group,
         )
         self._found_inf = is_found_inf.numpy()[0]
 
-- 
GitLab
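
For context, a minimal usage sketch (not part of the patch) of how the new code path might be exercised: a hybrid strategy mixing data and model parallelism combined with group sharded stage 2. The model, parallel degrees, hyperparameters, and the choice of hcg.get_data_parallel_group() as the sharding group are assumptions for illustration; the script assumes 4 GPUs launched with python -m paddle.distributed.launch.

import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

# Hybrid strategy: 2-way data parallel x 2-way model parallel (4 ranks total).
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 2,
    "pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()

# A plain Linear keeps the sketch short; a real model-parallel network would
# use fleet.meta_parallel layers such as ColumnParallelLinear.
model = paddle.nn.Linear(1024, 1024)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters(), grad_clip=clip
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# Level "os_g" (stage 2) shards optimizer states and gradients; sharding over
# the data-parallel group is an assumed choice here. With this patch the
# wrapped optimizer swaps ClipGradByGlobalNorm for HybridParallelClipGrad
# because the hcg reports a non-pure-data-parallel mode, and the scaler's
# inf/nan all_reduce runs over hcg.get_check_parallel_group().
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    level="os_g",
    scaler=scaler,
    group=hcg.get_data_parallel_group(),
)

# One AMP training step to touch both the grad-clip and the check_nan_inf paths.
data = paddle.randn([8, 1024])
with paddle.amp.auto_cast():
    loss = model(data).mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()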