diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 2668b5808996d3598cc1e8f6abb2b27867145bc7..0414798e685141d336f05e59d09a906b565c83b0 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -31,6 +31,11 @@ import paddle
 from paddle.fluid import core
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
+from paddle.distributed import fleet, ParallelMode
+
+HybridParallelClipGrad = (
+    fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer.HybridParallelClipGrad
+)
 from paddle.distributed.collective import (
     _get_global_group,
     broadcast,
@@ -157,9 +162,18 @@ class GroupShardedOptimizerStage2(Optimizer):
                 "While using ClipGradByGlobalNorm in GroupShardedOptimizerStage2, the grad clip of original optimizer will be changed."
             )
 
-            self._optim._grad_clip = GroupShardedClipGrad(
-                self._optim._grad_clip, paddle.get_device(), self._group
-            )
+            hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+            if (
+                hcg
+                and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+            ):
+                self._optim._grad_clip = HybridParallelClipGrad(
+                    self._optim._grad_clip, hcg
+                )
+            else:
+                self._optim._grad_clip = GroupShardedClipGrad(
+                    self._optim._grad_clip, paddle.get_device(), self._group
+                )
             if self._optim._parameter_list and isinstance(
                 self._optim._parameter_list[0], dict
             ):
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 4107bd83db9363704232b9d9c169dd1a63ad6af4..6832e9a7caa2125852cbdf0a56a6d64e5e294982 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -23,6 +23,7 @@ from paddle.fluid import core
 from paddle.fluid import layers
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import dygraph_only
+from paddle.distributed import fleet, ParallelMode
 
 
 class Taskflow:
@@ -244,10 +245,18 @@ def GroupShardedScaler(scaler):
 
         self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
         is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
+        hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+        hybrid_parallel = (
+            hcg is not None
+            and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+        )
+
         paddle.distributed.all_reduce(
             is_found_inf,
             op=paddle.distributed.ReduceOp.MAX,
-            group=optimizer._group,
+            group=hcg.get_check_parallel_group()
+            if hybrid_parallel
+            else optimizer._group,
         )
         self._found_inf = is_found_inf.numpy()[0]
 
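
For reviewers, a minimal usage sketch of how the new hybrid-parallel branch would be reached. This is not part of the diff; the model, parallel degrees, group choice, and launch command are illustrative assumptions. Once `fleet.init` is called with a hybrid configuration, `fleet.fleet._hcg` is populated and `get_parallel_mode()` is no longer `ParallelMode.DATA_PARALLEL`, so the stage-2 optimizer above swaps in `HybridParallelClipGrad` and the scaler all-reduces the found-inf flag over `hcg.get_check_parallel_group()`:

```python
# Hypothetical sketch, launched with e.g.:
#   python -m paddle.distributed.launch --gpus "0,1,2,3" train.py
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

strategy = fleet.DistributedStrategy()
# mp_degree > 1 makes hcg.get_parallel_mode() != ParallelMode.DATA_PARALLEL,
# which is the condition the new branches in this diff test for.
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()

model = paddle.nn.Linear(1024, 1024)  # placeholder model
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4, parameters=model.parameters(), grad_clip=clip
)
scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)

# level="os_g" wraps the optimizer in GroupShardedOptimizerStage2 and the scaler
# via GroupShardedScaler; with _hcg present, the ClipGradByGlobalNorm above is
# replaced by HybridParallelClipGrad rather than GroupShardedClipGrad.
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    level="os_g",
    scaler=scaler,
    # assumption: shard only across the data-parallel ranks in a hybrid setup
    group=hcg.get_data_parallel_group(),
)
```

With this setup, the found-inf flag in `GroupShardedScaler` is reduced over `hcg.get_check_parallel_group()` rather than only the sharding group, so a NaN/Inf detected on any rank in that group skips the optimizer step on all of them.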