diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
index ba1b5222394e2d274107a2faafe7abfb536c8c1c..ffd24add50a4d01ad3dc28c8e31a15a6b14b9415 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
@@ -83,8 +83,14 @@ class ShardingOptimizerStage2(Optimizer):
         # Default information
         self._optim_defaults = kw
         self._optim = optim
+        assert hasattr(self._optim, "_master_weights"
+                       ), "Must use optimizer with _master_weights attribute"
         self._local_params = params
         self._default_device = device
+        self._pfp16 = len(
+            list(
+                filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
+                       self._local_params))) > 0
 
         assert group is not None, "Distributed communication group is must be gived"
         self.group = group
@@ -98,6 +104,12 @@ class ShardingOptimizerStage2(Optimizer):
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self.update_opt_status()
 
+    def _generate_master_params(self, trainable_params):
+        for param in trainable_params:
+            if param.dtype == Type.fp16.value:
+                self._optim._master_weights[param.name] = paddle.cast(
+                    param, Type.fp32.value)
+
     def update_opt_status(self):
         """Update optimizer status and parameter storage information, and special functions to be developed.
         """
@@ -207,6 +219,8 @@ class ShardingOptimizerStage2(Optimizer):
                 # Merge all the trainable params in a single InternalStorage
                 trainable_params = list(
                     filter(lambda x: x.trainable, params))
+                if self._pfp16 and dst_rank == self.rank:
+                    self._generate_master_params(trainable_params)
                 if trainable_params:
                     param_storage = ParamStorage(
                         size=self.rank_buffer_size[dtype][dst_rank],
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
index 8ac4a7e99c7d71f88f777d9cb61e835fd42a072f..329dc9eaa4e575cccb57b7aacf2214feab2f41f0 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
@@ -30,6 +30,7 @@ from paddle import nn
 import paddle.distributed as dist
 
 from ...utils.internal_storage import GradStorage
+from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from .sharding_utils import Taskflow, Type
 
 
@@ -70,6 +71,11 @@ class ShardingStage2(nn.Layer):
         self._layer = layer
         self._sharding_optimizers = [sharding_optimizer] if not isinstance(
             sharding_optimizer, list) else sharding_optimizer
+        assert all(
+            list(
+                map(lambda opt: isinstance(opt, ShardingOptimizerStage2),
+                    self._sharding_optimizers))
+        ), "Please use ShardingOptimizerStage2 optimizer"
         self._sync_buffers = sync_buffers
         self._auto_refresh_trainable = auto_refresh_trainable
 
@@ -88,8 +94,7 @@ class ShardingStage2(nn.Layer):
 
         # Global statistical parameters
         self._all_params = list(
-            chain(
-                * [optim.local_params for optim in self._sharding_optimizers]))
+            chain(*[optim.local_params for optim in self._sharding_optimizers]))
         self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
@@ -436,7 +441,7 @@ class ShardingStage2(nn.Layer):
                         ._fill))
 
         self._grad_storage_list = list(
-            chain(* [
+            chain(*[
                 self._grad_storages[dtype].values()
                 for dtype in self._grad_storages.keys()
             ]))
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
index bc62d18c860226bd5430a50a93e7e3af0aa68766..05008a3bc12f7ed6077342cf4e3723f2e7002b6c 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
@@ -24,7 +24,6 @@ from paddle.fluid.dygraph.nn import Linear
 from paddle.distributed import fleet
 from paddle.fluid.dygraph import nn
 
-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
 
@@ -70,7 +69,7 @@ def reader_decorator():
     return __reader__
 
 
-def optimizer_setting(model, use_pure_fp16, stage=1):
+def optimizer_setting(model, use_pure_fp16):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.AdamW(
         parameters=model.parameters(),
@@ -87,20 +86,16 @@ def train_mlp(model,
               use_pure_fp16=False,
               all_test=False,
               accumulate_grad=False):
-    if sharding_stage == 1:
+    if sharding_stage == "dp":
         hcg = fleet.get_hybrid_communicate_group()
         group = hcg.get_check_parallel_group()
     else:
         group = paddle.distributed.new_group([0, 1])
-    optimizer = optimizer_setting(
-        model=model, use_pure_fp16=use_pure_fp16, stage=sharding_stage)
+    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
 
     if use_pure_fp16:
-        model, optimizer = paddle.amp.decorate(
-            models=model,
-            optimizers=optimizer,
-            level='O2',
-            save_dtype='float32')
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype='float32')
 
     if sharding_stage == 2:
         optimizer = ShardingOptimizerStage2(
@@ -164,7 +159,7 @@ def train_mlp(model,
     return model.parameters()
 
 
-def test_stage1_stage2():
+def test_dp_stage2():
     mlp = MLP()
     state_dict = mlp.state_dict()
     mlp1 = MLP()
@@ -175,11 +170,13 @@
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
-    stage1_params = train_mlp(mlp, sharding_stage=1, use_pure_fp16=False)
-    stage2_params = train_mlp(mlp, sharding_stage=2, use_pure_fp16=False)
-    for i in range(len(stage1_params)):
-        np.testing.assert_allclose(
-            stage1_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
+    dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False)
+    stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False)
+    for i in range(len(dp_params)):
+        for j in range(len(stage2_params)):
+            if dp_params[i].name == stage2_params[j].name:
+                np.testing.assert_allclose(
+                    dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)
 
     stage2_params = train_mlp(
         mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
@@ -201,4 +198,4 @@
 
 
 if __name__ == '__main__':
-    test_stage1_stage2()
+    test_dp_stage2()
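
A minimal usage sketch of the pure-FP16 stage-2 path this patch enables, mirroring train_mlp in the updated dygraph_sharding_stage2.py test. It assumes a two-GPU job started with paddle.distributed.launch; the small Sequential model is a placeholder, and the ShardingOptimizerStage2/ShardingStage2 keyword names are inferred from the attribute assignments visible in this diff rather than confirmed signatures.

# Sketch only: pure-FP16 sharding stage 2 flow, following the updated unit test.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

dist.init_parallel_env()
group = dist.new_group([0, 1])  # two-rank group, as in the test

# Placeholder model; any paddle.nn.Layer works.
model = paddle.nn.Sequential(
    paddle.nn.Linear(1000, 100), paddle.nn.Linear(100, 10))

# AdamW carries a _master_weights dict, satisfying the new assert in
# ShardingOptimizerStage2.__init__.
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0))

# Pure FP16: only the model is decorated now; the FP32 master weights are
# created inside ShardingOptimizerStage2 via _generate_master_params.
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')

optimizer = ShardingOptimizerStage2(
    params=model.parameters(), optim=optimizer, group=group)
model = ShardingStage2(model, optimizer, group=group)

From here the training loop presumably proceeds as in the test: forward under paddle.amp.auto_cast, backward, then optimizer.step() and clear_grad() on the wrapped objects.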