Unverified commit 818132a0 authored by wuhuachaocoding, committed by GitHub

support sharding stage2 + mp hybrid_parallel. (#47535)

* support sharding stage2 + mp hybrid_parallel.

* fix the group of check_nan_inf.

* update hcg.
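For context, this is roughly how the new combination is driven from user code. The sketch below is illustrative and not part of the commit: it assumes the public fleet hybrid_configs and group_sharded_parallel APIs, the parallel degrees are placeholder values, and MyModel is a hypothetical user model.

    # Illustrative sketch (not part of this commit): sharding stage2 + MP.
    # Run under `python -m paddle.distributed.launch`; degrees are placeholders.
    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.sharding import group_sharded_parallel

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 2,        # model-parallel ranks
        "pp_degree": 1,
        "sharding_degree": 2,  # sharding stage2 ranks
    }
    fleet.init(is_collective=True, strategy=strategy)
    hcg = fleet.get_hybrid_communicate_group()

    model = MyModel()  # hypothetical user model
    optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)

    # level="os_g" selects optimizer-state + gradient sharding (stage 2);
    # the sharding group comes from the hybrid communicate group.
    model, optimizer, scaler = group_sharded_parallel(
        model,
        optimizer,
        level="os_g",
        scaler=scaler,
        group=hcg.get_sharding_parallel_group(),
    )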
Parent 605bc003
@@ -31,6 +31,11 @@
 import paddle
 from paddle.fluid import core
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
+from paddle.distributed import fleet, ParallelMode
+
+HybridParallelClipGrad = (
+    fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer.HybridParallelClipGrad
+)
 from paddle.distributed.collective import (
     _get_global_group,
     broadcast,
@@ -157,9 +162,18 @@ class GroupShardedOptimizerStage2(Optimizer):
                 "While using ClipGradByGlobalNorm in GroupShardedOptimizerStage2, the grad clip of original optimizer will be changed."
             )
-            self._optim._grad_clip = GroupShardedClipGrad(
-                self._optim._grad_clip, paddle.get_device(), self._group
-            )
+            hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+            if (
+                hcg
+                and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+            ):
+                self._optim._grad_clip = HybridParallelClipGrad(
+                    self._optim._grad_clip, hcg
+                )
+            else:
+                self._optim._grad_clip = GroupShardedClipGrad(
+                    self._optim._grad_clip, paddle.get_device(), self._group
+                )
         if self._optim._parameter_list and isinstance(
             self._optim._parameter_list[0], dict
         ):
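The hunk above swaps the grad-clip wrapper whenever the topology is more than pure data parallelism. The helper below is a hypothetical standalone paraphrase of that dispatch (names mirror the diff; it is not an API added by this commit):

    # Hypothetical helper paraphrasing the dispatch above (illustration only).
    from paddle.distributed import ParallelMode

    def select_grad_clip(inner_clip, hcg, device, sharding_group):
        if hcg is not None and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL:
            # With MP/PP in the topology, the global gradient norm must also
            # be reduced across model-parallel ranks; HybridParallelClipGrad
            # knows the full hybrid topology via the hcg.
            return HybridParallelClipGrad(inner_clip, hcg)
        # Pure data/sharding parallelism: reducing the norm over the sharding
        # group suffices, which is what GroupShardedClipGrad does.
        return GroupShardedClipGrad(inner_clip, device, sharding_group)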
@@ -23,6 +23,7 @@
 from paddle.fluid import core
 from paddle.fluid import layers
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import dygraph_only
+from paddle.distributed import fleet, ParallelMode


 class Taskflow:
@@ -244,10 +245,18 @@ def GroupShardedScaler(scaler):
         self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0

         is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
+        hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+        hybrid_parallel = (
+            hcg is not None
+            and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+        )
         paddle.distributed.all_reduce(
             is_found_inf,
             op=paddle.distributed.ReduceOp.MAX,
-            group=optimizer._group,
+            group=hcg.get_check_parallel_group()
+            if hybrid_parallel
+            else optimizer._group,
         )
         self._found_inf = is_found_inf.numpy()[0]
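This second change fixes which process group agrees on the inf/nan flag, per the "fix the group of check_nan_inf" bullet in the commit message. Under hybrid parallelism a parameter's gradients also live on ranks outside the sharding group (its model-parallel peers), so the MAX-reduction must run over the wider check-parallel group, or ranks would disagree on whether to skip the step. A minimal sketch of that semantics, with assumed names and a generic process group:

    # Minimal sketch (assumed names): make the skip-step decision unanimous
    # by MAX-reducing each rank's local inf/nan flag over the check group.
    import paddle
    import paddle.distributed as dist

    def sync_found_inf(local_found_inf, check_group):
        flag = paddle.to_tensor([1 if local_found_inf else 0], dtype="int32")
        # After the reduction, the flag is 1 on every rank if any rank saw
        # inf/nan, so all ranks skip the optimizer step together.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=check_group)
        return bool(flag.numpy()[0])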