Unverified commit 1c8531ce, authored by wuhuachaocoding, committed by GitHub

fix a bug of stage2 offload. (#49767)

Parent d58cca9e
@@ -149,6 +149,11 @@ class GroupShardedOptimizerStage2(Optimizer):
         self._rank = self._group.rank
         self._global_root_rank = self._group.ranks[0]
 
+        if self._dp_group is not None and self._dp_group.nranks > 1:
+            assert (
+                not offload
+            ), "Not support! when using offload with sharding stage2, please use pure sharding stage2, exclude data parallel."
+
         # Synchronous all ranks models
         if pertrain_sync_models:
             self._sync_params_and_buffers()
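The hunk above adds a guard in `GroupShardedOptimizerStage2.__init__`: CPU offload is rejected whenever an extra data-parallel group with more than one rank is attached to sharding stage 2. The sketch below illustrates the rejected combination; the module path, the multi-GPU `paddle.distributed.launch` setup, the toy `Linear` model and the `AdamW` choice are illustrative assumptions, not part of this commit.

```python
# Sketch only -- shows the combination the new assert rejects.
# Assumed setup: run under `python -m paddle.distributed.launch` on 2+ GPUs,
# with pure-FP16 parameters (stage-2 offload expects an AMP O2 setup).
import paddle
from paddle.distributed import get_world_size, init_parallel_env, new_group
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
    GroupShardedOptimizerStage2,
)

init_parallel_env()

model = paddle.nn.Linear(1024, 1024)
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
opt = paddle.optimizer.AdamW(parameters=model.parameters(), multi_precision=True)

# An all-rank group used as an extra data-parallel group on top of sharding.
dp_group = new_group(list(range(get_world_size())))

try:
    # offload=True together with dp_group.nranks > 1 now fails fast in __init__.
    GroupShardedOptimizerStage2(
        params=opt._parameter_list, optim=opt, offload=True, dp_group=dp_group
    )
except AssertionError as err:
    print("rejected as expected:", err)
```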
@@ -164,6 +169,7 @@ class GroupShardedOptimizerStage2(Optimizer):
             if (
                 hcg
                 and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+                and not offload
             ):
                 self._optim._grad_clip = HybridParallelClipGrad(
                     self._optim._grad_clip, hcg
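The second hunk extends the grad-clip condition with `and not offload`, so an offload run keeps its original `ClipGradByGlobalNorm` instead of having it wrapped in `HybridParallelClipGrad`. A standalone truth-table sketch of the resulting decision (the names here are illustrative, not Paddle APIs):

```python
# Standalone sketch of the condition edited above; not Paddle source.
def wraps_hybrid_clip(has_hcg: bool, is_pure_data_parallel: bool, offload: bool) -> bool:
    return has_hcg and not is_pure_data_parallel and not offload

assert wraps_hybrid_clip(True, False, False)       # hybrid parallel, no offload: wrapped
assert not wraps_hybrid_clip(True, False, True)    # new: offload keeps the original clip
assert not wraps_hybrid_clip(False, True, False)   # pure data parallel: never wrapped (unchanged)
```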
@@ -37,17 +37,29 @@ np.random.seed(seed)
 paddle.seed(seed)
 
 
-def train_mlp(model, offload=False):
+def train_mlp(model, offload=False, test=False):
     optimizer = optimizer_setting(model=model, use_pure_fp16=True)
     model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
 
     scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
     scaler = GroupShardedScaler(scaler)
+    dp_group = (
+        None
+        if not test
+        else paddle.distributed.new_group(
+            list(range(paddle.distributed.get_world_size()))
+        )
+    )
 
     optimizer = GroupShardedOptimizerStage2(
-        params=optimizer._parameter_list, optim=optimizer, offload=offload
+        params=optimizer._parameter_list,
+        optim=optimizer,
+        offload=offload,
+        dp_group=dp_group,
     )
-    model = GroupShardedStage2(model, optimizer, buffer_max_size=2**21)
+    model = GroupShardedStage2(
+        model, optimizer, buffer_max_size=2**21, dp_group=dp_group
+    )
 
     paddle.seed(2023)
     np.random.seed(2023)
@@ -103,6 +115,13 @@ def test_sharding_stage2_offload():
             rtol=5e-3,
             atol=5e-3,
         )
+
+    # just to test assert error for the rate of coverage
+    try:
+        train_mlp(mlp_offload, offload=True, test=True)
+    except Exception as e:
+        assert isinstance(e, AssertionError)
+    return
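For contrast with the rejected combination, the configuration that remains supported is offload with pure sharding stage 2, i.e. `dp_group` left as `None`; this is exactly the `test=False` path that the accuracy comparison above keeps exercising. A sketch, with the same caveats as before (module paths, toy model, and launch setup are assumptions):

```python
# Sketch of the still-supported path: stage-2 sharding + offload, no extra dp_group.
import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
    GroupShardedOptimizerStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
    GroupShardedStage2,
)

init_parallel_env()

model = paddle.nn.Linear(1024, 1024)
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(), multi_precision=True
)

# dp_group defaults to None, so the new guard is not triggered.
optimizer = GroupShardedOptimizerStage2(
    params=optimizer._parameter_list, optim=optimizer, offload=True
)
model = GroupShardedStage2(model, optimizer, buffer_max_size=2**21)
```

Training then proceeds as in the rest of `train_mlp` (not shown in this hunk): forward/backward under AMP with `GroupShardedScaler`, followed by the scaler-driven optimizer step.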