Unverified commit ffd8adca, authored by Haohongxiang, committed by GitHub

fix_bugs_of_sharding (#44982)

Parent 031debb7
@@ -1856,9 +1856,8 @@ class Fleet(object):
group=None)
self._found_inf = is_found_inf.numpy()[0]
# Only tensor_parallel and pipeline_parallel need to modify scaler
if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
ParallelMode.PIPELINE_PARALLEL):
# Only data_parallel doesn't need to modify scaler
if self._hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL:
scaler._unscale = MethodType(unscale_method, scaler)
return scaler
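
The hunk above widens the scaler patch from an allow-list of two modes to everything except pure data parallelism, which (going by the commit title) brings sharding parallelism into the patched-unscale path. A minimal sketch of the behavioral difference, assuming a ParallelMode enumeration along the lines of fleet's hybrid topology; the members and values below are illustrative, not the library's own definition:

from enum import Enum

class ParallelMode(Enum):
    # Illustrative stand-in for fleet's parallel-mode constants.
    DATA_PARALLEL = 0
    TENSOR_PARALLEL = 1
    PIPELINE_PARALLEL = 2
    SHARDING_PARALLEL = 3

def patches_scaler_old(mode):
    # Old allow-list: sharding parallel falls through and keeps the
    # default GradScaler unscale behavior.
    return mode in (ParallelMode.TENSOR_PARALLEL,
                    ParallelMode.PIPELINE_PARALLEL)

def patches_scaler_new(mode):
    # New check: every hybrid mode except pure data parallelism gets the
    # patched unscale step.
    return mode is not ParallelMode.DATA_PARALLEL

assert not patches_scaler_old(ParallelMode.SHARDING_PARALLEL)
assert patches_scaler_new(ParallelMode.SHARDING_PARALLEL)
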
@@ -159,7 +159,7 @@ def group_sharded_parallel(model,
sync_comm=sync_comm)
else:
raise ValueError("Please enter the correct level.")
if params_fp16 and isinstance(scaler, paddle.amp.GradScaler):
if isinstance(scaler, paddle.amp.GradScaler):
if in_dygraph_mode():
scaler = GroupShardedScaler(scaler)
else:
......
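
The second hunk drops the "params_fp16 and" guard, so a passed-in paddle.amp.GradScaler is wrapped into GroupShardedScaler even when the parameters stay in fp32 (AMP O1). A hedged usage sketch of that path, not the library's own test: it assumes two or more GPUs under a distributed launch (python -m paddle.distributed.launch), and the Linear layer stands in for the test's MLP:

import paddle
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

model = paddle.nn.Linear(1000, 1000)  # illustrative stand-in for the MLP
optimizer = paddle.optimizer.Momentum(parameters=model.parameters(),
                                      multi_precision=True)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# Parameters stay fp32 here (AMP O1), yet the scaler is still wrapped into
# GroupShardedScaler now that the params_fp16 guard is gone.
model, optimizer, scaler = group_sharded_parallel(model=model,
                                                  optimizer=optimizer,
                                                  level="os_g",
                                                  scaler=scaler)

with paddle.amp.auto_cast(True, level='O1'):
    loss = paddle.mean(model(paddle.randn([4, 1000])))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
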
@@ -61,7 +61,7 @@ def reader_decorator(linear_size=1000):
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
def optimizer_setting(model, use_multi_precision, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.Momentum(
parameters=[{
@@ -70,16 +70,23 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
multi_precision=use_multi_precision)
return optimizer
def train_mlp(model, shard_level, use_pure_fp16, output_dir):
def train_mlp(model,
shard_level,
use_multi_precision,
output_dir,
amp_level='O1'):
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
optimizer = optimizer_setting(model=model,
use_multi_precision=use_multi_precision)
model = paddle.amp.decorate(models=model,
level=amp_level,
save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
model, optimizer, scaler = group_sharded_parallel(model=model,
@@ -104,13 +111,13 @@ def train_mlp(model, shard_level, use_pure_fp16, output_dir):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
with paddle.amp.auto_cast(True, level=amp_level):
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out,
label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not use_pure_fp16:
if not use_multi_precision:
avg_loss.backward()
optimizer.step()
else:
@@ -135,12 +142,36 @@ def test_sharding_api():
# fp16
stage2_params = train_mlp(mlp1,
shard_level="os_g",
use_pure_fp16=True,
output_dir=output_dir)
use_multi_precision=True,
output_dir=output_dir,
amp_level='O2')
stage3_params = train_mlp(mlp2,
shard_level="p_g_os",
use_pure_fp16=True,
output_dir=output_dir)
use_multi_precision=True,
output_dir=output_dir,
amp_level='O2')
for i in range(len(stage3_params)):
np.testing.assert_allclose(stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-3)
# AMP
mlp3, mlp4 = MLP(), MLP()
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
stage2_params = train_mlp(mlp3,
shard_level="os_g",
use_multi_precision=True,
output_dir=output_dir,
amp_level='O1')
stage3_params = train_mlp(mlp4,
shard_level="p_g_os",
use_multi_precision=True,
output_dir=output_dir,
amp_level='O1')
for i in range(len(stage3_params)):
np.testing.assert_allclose(stage2_params[i].numpy(),
......
@@ -61,7 +61,7 @@ def reader_decorator(linear_size=1000):
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
def optimizer_setting(model, use_multi_precision, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.Momentum(
parameters=[{
@@ -70,14 +70,21 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
multi_precision=use_multi_precision)
return optimizer
def train_mlp(model, shard_level, use_pure_fp16, output_dir):
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
def train_mlp(model,
shard_level,
use_multi_precision,
output_dir,
amp_level='O1'):
optimizer = optimizer_setting(model=model,
use_multi_precision=use_multi_precision)
model = paddle.amp.decorate(models=model,
level=amp_level,
save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
model, optimizer, scaler = group_sharded_parallel(model=model,
@@ -102,13 +109,13 @@ def train_mlp(model, shard_level, use_pure_fp16, output_dir):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
with paddle.amp.auto_cast(True, level=amp_level):
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out,
label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not use_pure_fp16:
if not use_multi_precision:
avg_loss.backward()
optimizer.step()
else:
@@ -134,19 +141,36 @@ def test_sharding_api():
# fp16
stage2_params = train_mlp(mlp1,
shard_level="os_g",
use_pure_fp16=True,
output_dir=output_dir)
use_multi_precision=True,
output_dir=output_dir,
amp_level='O2')
stage3_params = train_mlp(mlp2,
shard_level="p_g_os",
use_pure_fp16=True,
output_dir=output_dir)
use_multi_precision=True,
output_dir=output_dir,
amp_level='O2')
for i in range(len(stage3_params)):
np.testing.assert_allclose(stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-3)
shutil.rmtree(output_dir)
# AMP
mlp3, mlp4 = MLP(), MLP()
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
stage2_params = train_mlp(mlp3,
shard_level="os_g",
use_multi_precision=True,
output_dir=output_dir,
amp_level='O1')
stage3_params = train_mlp(mlp4,
shard_level="p_g_os",
use_multi_precision=True,
output_dir=output_dir,
amp_level='O1')
if __name__ == '__main__':
......
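
The test changes above run the same stage2/stage3 comparison twice, once with amp_level='O2' (pure fp16, parameters cast to float16) and once with amp_level='O1' (fp32 parameters, fp16 compute for whitelisted ops only). A minimal single-card sketch of that distinction outside of sharding, assuming a float16-capable GPU; the layer and data are illustrative:

import paddle

def run_step(amp_level):
    model = paddle.nn.Linear(1000, 1000)  # illustrative layer
    opt = paddle.optimizer.Momentum(parameters=model.parameters(),
                                    multi_precision=True)
    # O2 casts the parameters themselves to float16 ("pure fp16");
    # O1 keeps fp32 parameters and only runs whitelisted ops in fp16.
    model = paddle.amp.decorate(models=model, level=amp_level,
                                save_dtype='float32')
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

    with paddle.amp.auto_cast(True, level=amp_level):
        loss = paddle.mean(model(paddle.randn([4, 1000])))
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

run_step('O1')  # AMP path added by the new tests
run_step('O2')  # pure-fp16 path the tests already covered
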