Unverified commit 1c8531ce, authored by wuhuachaocoding, committed by GitHub

fix a bug of stage2 offload. (#49767)

Parent commit: d58cca9e
@@ -149,6 +149,11 @@ class GroupShardedOptimizerStage2(Optimizer):
         self._rank = self._group.rank
         self._global_root_rank = self._group.ranks[0]
 
+        if self._dp_group is not None and self._dp_group.nranks > 1:
+            assert (
+                not offload
+            ), "Not support! when using offload with sharding stage2, please use pure sharding stage2, exclude data parallel."
+
         # Synchronous all ranks models
         if pertrain_sync_models:
             self._sync_params_and_buffers()
@@ -164,6 +169,7 @@ class GroupShardedOptimizerStage2(Optimizer):
             if (
                 hcg
                 and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+                and not offload
             ):
                 self._optim._grad_clip = HybridParallelClipGrad(
                     self._optim._grad_clip, hcg
...
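Taken together, the two hunks above tighten offload handling in GroupShardedOptimizerStage2: passing a data-parallel group with more than one rank alongside offload=True now fails fast with an AssertionError, and the gradient clip is no longer wrapped in HybridParallelClipGrad when offload is enabled. The sketch below is not part of the patch; it only illustrates the rejected combination, assumes a multi-GPU run launched through paddle.distributed, and uses a hypothetical Linear model and AdamW inner optimizer (the import path is inferred from the class name and may differ between Paddle versions):

    import paddle
    from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
        GroupShardedOptimizerStage2,
    )  # assumed internal import path

    paddle.distributed.init_parallel_env()

    # Hypothetical model and inner optimizer, standing in for the MLP and
    # optimizer_setting() used in the test file below.
    model = paddle.nn.Linear(8, 8)
    inner_opt = paddle.optimizer.AdamW(parameters=model.parameters())

    # A data-parallel group spanning every rank (nranks > 1 on a multi-GPU run).
    dp_group = paddle.distributed.new_group(
        list(range(paddle.distributed.get_world_size()))
    )

    try:
        GroupShardedOptimizerStage2(
            params=inner_opt._parameter_list,
            optim=inner_opt,
            offload=True,       # offload=True plus a >1-rank dp_group ...
            dp_group=dp_group,  # ... is the combination the new assert rejects
        )
    except AssertionError as e:
        print(e)  # "Not support! when using offload with sharding stage2, ..."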
@@ -37,17 +37,29 @@ np.random.seed(seed)
 paddle.seed(seed)
 
 
-def train_mlp(model, offload=False):
+def train_mlp(model, offload=False, test=False):
     optimizer = optimizer_setting(model=model, use_pure_fp16=True)
     model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
     scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
     scaler = GroupShardedScaler(scaler)
 
+    dp_group = (
+        None
+        if not test
+        else paddle.distributed.new_group(
+            list(range(paddle.distributed.get_world_size()))
+        )
+    )
+
     optimizer = GroupShardedOptimizerStage2(
-        params=optimizer._parameter_list, optim=optimizer, offload=offload
+        params=optimizer._parameter_list,
+        optim=optimizer,
+        offload=offload,
+        dp_group=dp_group,
+    )
+    model = GroupShardedStage2(
+        model, optimizer, buffer_max_size=2**21, dp_group=dp_group
     )
-    model = GroupShardedStage2(model, optimizer, buffer_max_size=2**21)
 
     paddle.seed(2023)
     np.random.seed(2023)
@@ -103,6 +115,13 @@ def test_sharding_stage2_offload():
         rtol=5e-3,
         atol=5e-3,
     )
 
+    # just to test assert error for the rate of coverage
+    try:
+        train_mlp(mlp_offload, offload=True, test=True)
+    except Exception as e:
+        assert isinstance(e, AssertionError)
+
     return
...
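As a usage note (not code from the patch): with the default test=False, train_mlp keeps dp_group=None, which is the pure sharding stage2 offload setup the assertion message asks for, and that is the path the accuracy comparison above already exercises; the new call with test=True exists only to cover the assertion branch. Roughly, and assuming the existing comparison calls train_mlp as sketched in the first line:

    # Pure sharding stage2 offload (dp_group stays None) -- the supported path
    # checked against the non-offload run earlier in this test.
    train_mlp(mlp_offload, offload=True)

    # Offload combined with a world-size data-parallel group -- expected to
    # raise the AssertionError added in GroupShardedOptimizerStage2.
    train_mlp(mlp_offload, offload=True, test=True)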