Unverified commit 818132a0, authored by wuhuachaocoding, committed by GitHub

support sharding stage2 + mp hybrid_parallel. (#47535)

* support sharding stage2 + mp hybrid_parallel.

* fix the group of check_nan_inf.

* update hcg.
Parent 605bc003
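
For context, here is a minimal, hypothetical setup sketch (not part of this commit) of how sharding stage2 might be combined with model parallelism after this change. The parallel degrees, the toy `paddle.nn.Linear` model, the `AdamW` optimizer, and the `GradScaler` settings are illustrative assumptions, not values from the PR; the script would be launched with `paddle.distributed.launch` on 4 devices.

```python
# Hypothetical hybrid setup: 2-way model parallel x 2-way sharding (stage2).
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,        # model (tensor) parallel degree -- assumed
    "pp_degree": 1,
    "sharding_degree": 2,  # sharding stage2 degree -- assumed
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
sharding_group = hcg.get_sharding_parallel_group()

# Toy model, optimizer and scaler purely for illustration.
model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# level="os_g" shards optimizer states and gradients, i.e. sharding stage2.
model, optimizer, scaler = group_sharded_parallel(
    model, optimizer, level="os_g", scaler=scaler, group=sharding_group
)
```

With this commit, the grad clip and the inf/nan check inside the stage2 optimizer and scaler pick up the hybrid communicate group instead of assuming pure data parallelism, as the two hunks below show.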
@@ -31,6 +31,11 @@ import paddle
 from paddle.fluid import core
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
+from paddle.distributed import fleet, ParallelMode
+
+HybridParallelClipGrad = (
+    fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer.HybridParallelClipGrad
+)
 from paddle.distributed.collective import (
     _get_global_group,
     broadcast,
@@ -157,6 +162,15 @@ class GroupShardedOptimizerStage2(Optimizer):
                 "While using ClipGradByGlobalNorm in GroupShardedOptimizerStage2, the grad clip of original optimizer will be changed."
             )
-            self._optim._grad_clip = GroupShardedClipGrad(
-                self._optim._grad_clip, paddle.get_device(), self._group
-            )
+            hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+            if (
+                hcg
+                and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+            ):
+                self._optim._grad_clip = HybridParallelClipGrad(
+                    self._optim._grad_clip, hcg
+                )
+            else:
+                self._optim._grad_clip = GroupShardedClipGrad(
+                    self._optim._grad_clip, paddle.get_device(), self._group
+                )
...
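
Restating the branch above as a standalone check (assuming `fleet.init` has already been called with hybrid settings, as in the sketch after the commit header):

```python
from paddle.distributed import fleet, ParallelMode

# Same condition as in the hunk above: fall back to GroupShardedClipGrad only
# when there is no hcg or the mode is pure data parallel.
hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
use_hybrid_clip = (
    hcg is not None
    and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
)
print("wrap grad clip with HybridParallelClipGrad:", use_hybrid_clip)
```

When the condition holds, global-norm clipping has to accumulate norms across the model-parallel (and pipeline) groups as well, which is what `HybridParallelClipGrad` does; `GroupShardedClipGrad` only reduces within the sharding group.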
@@ -23,6 +23,7 @@ from paddle.fluid import core
 from paddle.fluid import layers
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import dygraph_only
+from paddle.distributed import fleet, ParallelMode


 class Taskflow:
@@ -244,10 +245,18 @@ def GroupShardedScaler(scaler):
         self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0

         is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

+        hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+        hybrid_parallel = (
+            hcg is not None
+            and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+        )
+
         paddle.distributed.all_reduce(
             is_found_inf,
             op=paddle.distributed.ReduceOp.MAX,
-            group=optimizer._group,
+            group=hcg.get_check_parallel_group()
+            if hybrid_parallel
+            else optimizer._group,
         )
         self._found_inf = is_found_inf.numpy()[0]
...
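
Similarly, a small sketch of the check_nan_inf fix in isolation, assuming the hybrid setup from the first sketch; `found_inf_local` and `skip_update` are placeholder names for the per-rank flag and the resulting decision:

```python
import paddle
from paddle.distributed import fleet

hcg = fleet.get_hybrid_communicate_group()
# The check group spans every rank whose local gradients can differ
# (sharding + model parallel), so a MAX reduce over it lets all ranks agree.
check_group = hcg.get_check_parallel_group()

found_inf_local = 0  # placeholder: 1 if this rank saw inf/nan gradients
is_found_inf = paddle.to_tensor([found_inf_local], dtype="int32")
paddle.distributed.all_reduce(
    is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=check_group
)
skip_update = bool(is_found_inf.numpy()[0])
```

Previously the reduce only covered `optimizer._group` (the sharding group), so under sharding stage2 + mp different model-parallel ranks could disagree about whether to skip the update.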