Unverified commit ffd8adca, authored by: Haohongxiang, committed by: GitHub

fix_bugs_of_sharding (#44982)

Parent 031debb7
@@ -1856,9 +1856,8 @@ class Fleet(object):
                 group=None)
             self._found_inf = is_found_inf.numpy()[0]
 
-        # Only tensor_parallel and pipeline_parallel need to modify scaler
-        if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
-                                             ParallelMode.PIPELINE_PARALLEL):
+        # Only data_parallel doesn't need to modify scaler
+        if self._hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL:
             scaler._unscale = MethodType(unscale_method, scaler)
 
         return scaler
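The patch above hinges on a small Python idiom: types.MethodType binds a plain function to one particular object as an instance method, which is how Fleet attaches its hybrid-parallel unscale_method to a stock paddle.amp.GradScaler whenever the parallel mode is anything other than plain data parallelism. A minimal, self-contained sketch of just that rebinding (the DummyScaler class and custom_unscale function are illustrative stand-ins, not Paddle code):

from types import MethodType


class DummyScaler:
    """Illustrative stand-in for paddle.amp.GradScaler."""

    def _unscale(self, optimizer):
        print("default unscale")


def custom_unscale(self, optimizer):
    # A hybrid-parallel unscale would also all-reduce the found_inf flag
    # across the checking group before touching the gradients.
    print("patched unscale on", type(self).__name__)


scaler = DummyScaler()
# Same pattern as `scaler._unscale = MethodType(unscale_method, scaler)`:
# only this one instance gets the replacement behaviour.
scaler._unscale = MethodType(custom_unscale, scaler)
scaler._unscale(optimizer=None)  # prints: patched unscale on DummyScaler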
@@ -159,7 +159,7 @@ def group_sharded_parallel(model,
                                      sync_comm=sync_comm)
     else:
         raise ValueError("Please enter the correct level.")
-    if params_fp16 and isinstance(scaler, paddle.amp.GradScaler):
+    if isinstance(scaler, paddle.amp.GradScaler):
         if in_dygraph_mode():
             scaler = GroupShardedScaler(scaler)
         else:
...
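With the relaxed check above, group_sharded_parallel wraps any paddle.amp.GradScaler it receives in a sharding-aware scaler, not only the pure-fp16 case previously guarded by params_fp16. A hedged usage sketch, assuming a two-GPU job started with `python -m paddle.distributed.launch --gpus=0,1`; the layer size and hyper-parameters are illustrative only:

import paddle
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

model = paddle.nn.Linear(1000, 1000)
optimizer = paddle.optimizer.Momentum(learning_rate=0.001,
                                      parameters=model.parameters(),
                                      multi_precision=True)

# O1 AMP: parameters stay fp32, only selected ops run in fp16.
model = paddle.amp.decorate(models=model, level='O1', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# After this fix the scaler is wrapped even though the parameters are not
# pure fp16, so the inf/nan check is synchronized across sharded ranks.
model, optimizer, scaler = group_sharded_parallel(model=model,
                                                  optimizer=optimizer,
                                                  level='os_g',
                                                  scaler=scaler)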
@@ -61,7 +61,7 @@ def reader_decorator(linear_size=1000):
     return __reader__
 
 
-def optimizer_setting(model, use_pure_fp16, opt_group=False):
+def optimizer_setting(model, use_multi_precision, opt_group=False):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.Momentum(
         parameters=[{
@@ -70,16 +70,23 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
         learning_rate=0.001,
         weight_decay=0.00001,
         grad_clip=clip,
-        multi_precision=use_pure_fp16)
+        multi_precision=use_multi_precision)
 
     return optimizer
 
 
-def train_mlp(model, shard_level, use_pure_fp16, output_dir):
+def train_mlp(model,
+              shard_level,
+              use_multi_precision,
+              output_dir,
+              amp_level='O1'):
     group = paddle.distributed.new_group([0, 1])
 
-    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
-    model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
+    optimizer = optimizer_setting(model=model,
+                                  use_multi_precision=use_multi_precision)
+    model = paddle.amp.decorate(models=model,
+                                level=amp_level,
+                                save_dtype='float32')
 
     scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
     model, optimizer, scaler = group_sharded_parallel(model=model,
@@ -104,13 +111,13 @@ def train_mlp(model, shard_level, use_pure_fp16, output_dir):
             img, label = data
             label.stop_gradient = True
             img.stop_gradient = True
 
-            with paddle.amp.auto_cast(True, level='O2'):
+            with paddle.amp.auto_cast(True, level=amp_level):
                 out = model(img)
                 loss = paddle.nn.functional.cross_entropy(input=out,
                                                           label=label)
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
-            if not use_pure_fp16:
+            if not use_multi_precision:
                 avg_loss.backward()
                 optimizer.step()
             else:
@@ -135,12 +142,36 @@ def test_sharding_api():
     # fp16
     stage2_params = train_mlp(mlp1,
                               shard_level="os_g",
-                              use_pure_fp16=True,
-                              output_dir=output_dir)
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O2')
     stage3_params = train_mlp(mlp2,
                               shard_level="p_g_os",
-                              use_pure_fp16=True,
-                              output_dir=output_dir)
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O2')
+
+    for i in range(len(stage3_params)):
+        np.testing.assert_allclose(stage2_params[i].numpy(),
+                                   stage3_params[i].numpy(),
+                                   rtol=1e-4,
+                                   atol=1e-3)
+
+    # AMP
+    mlp3, mlp4 = MLP(), MLP()
+    mlp3.set_state_dict(state_dict)
+    mlp4.set_state_dict(state_dict)
+
+    stage2_params = train_mlp(mlp3,
+                              shard_level="os_g",
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O1')
+    stage3_params = train_mlp(mlp4,
+                              shard_level="p_g_os",
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O1')
 
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(stage2_params[i].numpy(),
...
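The else: branch of train_mlp is collapsed in the hunk above; it is the scaled path, which relies on the standard paddle.amp.GradScaler protocol. As a reminder of that flow only, here is a self-contained sketch with a toy model, not the test's exact code:

import paddle

model = paddle.nn.Linear(8, 8)
optimizer = paddle.optimizer.Momentum(learning_rate=0.001,
                                      parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

x = paddle.randn([4, 8])
with paddle.amp.auto_cast(True, level='O1'):
    loss = paddle.mean(model(x))

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
scaler.update()                # adjusts the loss-scaling factor
optimizer.clear_grad()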
@@ -61,7 +61,7 @@ def reader_decorator(linear_size=1000):
     return __reader__
 
 
-def optimizer_setting(model, use_pure_fp16, opt_group=False):
+def optimizer_setting(model, use_multi_precision, opt_group=False):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.Momentum(
         parameters=[{
@@ -70,14 +70,21 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
         learning_rate=0.001,
         weight_decay=0.00001,
         grad_clip=clip,
-        multi_precision=use_pure_fp16)
+        multi_precision=use_multi_precision)
 
     return optimizer
 
 
-def train_mlp(model, shard_level, use_pure_fp16, output_dir):
-    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
-    model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
+def train_mlp(model,
+              shard_level,
+              use_multi_precision,
+              output_dir,
+              amp_level='O1'):
+    optimizer = optimizer_setting(model=model,
+                                  use_multi_precision=use_multi_precision)
+    model = paddle.amp.decorate(models=model,
+                                level=amp_level,
+                                save_dtype='float32')
 
     scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
     model, optimizer, scaler = group_sharded_parallel(model=model,
@@ -102,13 +109,13 @@ def train_mlp(model, shard_level, use_pure_fp16, output_dir):
             img, label = data
             label.stop_gradient = True
             img.stop_gradient = True
 
-            with paddle.amp.auto_cast(True, level='O2'):
+            with paddle.amp.auto_cast(True, level=amp_level):
                 out = model(img)
                 loss = paddle.nn.functional.cross_entropy(input=out,
                                                           label=label)
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
-            if not use_pure_fp16:
+            if not use_multi_precision:
                 avg_loss.backward()
                 optimizer.step()
             else:
@@ -134,19 +141,36 @@ def test_sharding_api():
     # fp16
     stage2_params = train_mlp(mlp1,
                               shard_level="os_g",
-                              use_pure_fp16=True,
-                              output_dir=output_dir)
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O2')
    stage3_params = train_mlp(mlp2,
                              shard_level="p_g_os",
-                              use_pure_fp16=True,
-                              output_dir=output_dir)
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O2')
 
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(stage2_params[i].numpy(),
                                    stage3_params[i].numpy(),
                                    rtol=1e-4,
                                    atol=1e-3)
-    shutil.rmtree(output_dir)
+
+    # AMP
+    mlp3, mlp4 = MLP(), MLP()
+    mlp3.set_state_dict(state_dict)
+    mlp4.set_state_dict(state_dict)
+
+    stage2_params = train_mlp(mlp3,
+                              shard_level="os_g",
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O1')
+    stage3_params = train_mlp(mlp4,
+                              shard_level="p_g_os",
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O1')
 
 
 if __name__ == '__main__':
...
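The reason the tests now sweep both amp_level='O2' and amp_level='O1' is that the two levels treat parameters differently, which appears to be what the old params_fp16 guard missed. A small sketch of the difference, with behaviour described for a CUDA device; the toy Linear layers are illustrative:

import paddle

net_o1 = paddle.amp.decorate(models=paddle.nn.Linear(8, 8), level='O1')
net_o2 = paddle.amp.decorate(models=paddle.nn.Linear(8, 8), level='O2')

# On a CUDA device: O1 keeps float32 parameters and only casts selected ops
# at runtime, while O2 casts the parameters themselves to float16.
print(net_o1.weight.dtype)
print(net_o2.weight.dtype)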