diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index 1a9b3f565b77ab79a535f255575515223b6b4539..51e1c5281a87f54b0c9c4922fad6bcbcc123dac7 100755
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -1856,9 +1856,8 @@ class Fleet(object):
                                           group=None)
             self._found_inf = is_found_inf.numpy()[0]
 
-        # Only tensor_parallel and pipeline_parallel need to modify scaler
-        if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
-                                             ParallelMode.PIPELINE_PARALLEL):
+        # Only data_parallel doesn't need to modify scaler
+        if self._hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL:
             scaler._unscale = MethodType(unscale_method, scaler)
 
         return scaler
diff --git a/python/paddle/distributed/sharding/group_sharded.py b/python/paddle/distributed/sharding/group_sharded.py
index ad270c1a5173318133a5f2fef16bfaf44c3087ea..58fb51b62b9a3d320f79d8051bf7fa5bdfaa1e79 100644
--- a/python/paddle/distributed/sharding/group_sharded.py
+++ b/python/paddle/distributed/sharding/group_sharded.py
@@ -159,7 +159,7 @@ def group_sharded_parallel(model,
                                    sync_comm=sync_comm)
     else:
         raise ValueError("Please enter the correct level.")
-    if params_fp16 and isinstance(scaler, paddle.amp.GradScaler):
+    if isinstance(scaler, paddle.amp.GradScaler):
         if in_dygraph_mode():
             scaler = GroupShardedScaler(scaler)
         else:
diff --git a/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api.py b/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api.py
index 34b485a8bd4623b1efa8c1173d8531dad34787a2..35be51213607b555b6a428a29b7d720e3603ccf9 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api.py
@@ -61,7 +61,7 @@ def reader_decorator(linear_size=1000):
     return __reader__
 
 
-def optimizer_setting(model, use_pure_fp16, opt_group=False):
+def optimizer_setting(model, use_multi_precision, opt_group=False):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.Momentum(
         parameters=[{
@@ -70,16 +70,23 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
         learning_rate=0.001,
         weight_decay=0.00001,
         grad_clip=clip,
-        multi_precision=use_pure_fp16)
+        multi_precision=use_multi_precision)
 
     return optimizer
 
 
-def train_mlp(model, shard_level, use_pure_fp16, output_dir):
+def train_mlp(model,
+              shard_level,
+              use_multi_precision,
+              output_dir,
+              amp_level='O1'):
     group = paddle.distributed.new_group([0, 1])
 
-    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
-    model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
+    optimizer = optimizer_setting(model=model,
+                                  use_multi_precision=use_multi_precision)
+    model = paddle.amp.decorate(models=model,
+                                level=amp_level,
+                                save_dtype='float32')
 
     scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
     model, optimizer, scaler = group_sharded_parallel(model=model,
@@ -104,13 +111,13 @@ def train_mlp(model, shard_level, use_pure_fp16, output_dir):
             img, label = data
             label.stop_gradient = True
             img.stop_gradient = True
-            with paddle.amp.auto_cast(True, level='O2'):
+            with paddle.amp.auto_cast(True, level=amp_level):
                 out = model(img)
                 loss = paddle.nn.functional.cross_entropy(input=out,
                                                           label=label)
 
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
-            if not use_pure_fp16:
+            if not use_multi_precision:
                 avg_loss.backward()
                 optimizer.step()
             else:
@@ -135,12 +142,36 @@ def test_sharding_api():
     # fp16
     stage2_params = train_mlp(mlp1,
                               shard_level="os_g",
-                              use_pure_fp16=True,
-                              output_dir=output_dir)
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O2')
     stage3_params = train_mlp(mlp2,
                               shard_level="p_g_os",
-                              use_pure_fp16=True,
-                              output_dir=output_dir)
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O2')
+
+    for i in range(len(stage3_params)):
+        np.testing.assert_allclose(stage2_params[i].numpy(),
+                                   stage3_params[i].numpy(),
+                                   rtol=1e-4,
+                                   atol=1e-3)
+
+    # AMP
+    mlp3, mlp4 = MLP(), MLP()
+    mlp3.set_state_dict(state_dict)
+    mlp4.set_state_dict(state_dict)
+
+    stage2_params = train_mlp(mlp3,
+                              shard_level="os_g",
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O1')
+    stage3_params = train_mlp(mlp4,
+                              shard_level="p_g_os",
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O1')
 
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(stage2_params[i].numpy(),
diff --git a/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api_eager.py b/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api_eager.py
index 8f6dadb5ce97890c0045fb6213660ac9a63f5cf2..5de9b5ecea084b8ccc19ccf2a0938dfce573f97b 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api_eager.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api_eager.py
@@ -61,7 +61,7 @@ def reader_decorator(linear_size=1000):
     return __reader__
 
 
-def optimizer_setting(model, use_pure_fp16, opt_group=False):
+def optimizer_setting(model, use_multi_precision, opt_group=False):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.Momentum(
         parameters=[{
@@ -70,14 +70,21 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
         learning_rate=0.001,
         weight_decay=0.00001,
         grad_clip=clip,
-        multi_precision=use_pure_fp16)
+        multi_precision=use_multi_precision)
 
     return optimizer
 
 
-def train_mlp(model, shard_level, use_pure_fp16, output_dir):
-    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
-    model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
+def train_mlp(model,
+              shard_level,
+              use_multi_precision,
+              output_dir,
+              amp_level='O1'):
+    optimizer = optimizer_setting(model=model,
+                                  use_multi_precision=use_multi_precision)
+    model = paddle.amp.decorate(models=model,
+                                level=amp_level,
+                                save_dtype='float32')
 
     scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
     model, optimizer, scaler = group_sharded_parallel(model=model,
@@ -102,13 +109,13 @@ def train_mlp(model, shard_level, use_pure_fp16, output_dir):
             img, label = data
             label.stop_gradient = True
             img.stop_gradient = True
-            with paddle.amp.auto_cast(True, level='O2'):
+            with paddle.amp.auto_cast(True, level=amp_level):
                 out = model(img)
                 loss = paddle.nn.functional.cross_entropy(input=out,
                                                           label=label)
 
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
-            if not use_pure_fp16:
+            if not use_multi_precision:
                 avg_loss.backward()
                 optimizer.step()
             else:
@@ -134,19 +141,36 @@ def test_sharding_api():
     # fp16
     stage2_params = train_mlp(mlp1,
                               shard_level="os_g",
-                              use_pure_fp16=True,
-                              output_dir=output_dir)
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O2')
     stage3_params = train_mlp(mlp2,
                               shard_level="p_g_os",
-                              use_pure_fp16=True,
-                              output_dir=output_dir)
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O2')
 
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(stage2_params[i].numpy(),
                                    stage3_params[i].numpy(),
                                    rtol=1e-4,
                                    atol=1e-3)
-    shutil.rmtree(output_dir)
+
+    # AMP
+    mlp3, mlp4 = MLP(), MLP()
+    mlp3.set_state_dict(state_dict)
+    mlp4.set_state_dict(state_dict)
+
+    stage2_params = train_mlp(mlp3,
+                              shard_level="os_g",
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O1')
+    stage3_params = train_mlp(mlp4,
+                              shard_level="p_g_os",
+                              use_multi_precision=True,
+                              output_dir=output_dir,
+                              amp_level='O1')
 
 
 if __name__ == '__main__':