diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
index 4ec43aa1e05ed448f50096ae840ff03775be4013..a2797adff251aea3535f86e5c423463d748c37b3 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
@@ -70,7 +70,7 @@ class ShardingOptimizerStage2(Optimizer):
                  device="gpu",
                  **kw):
 
-        super().__init__(optim._learning_rate, params, kw)
+        # super().__init__(optim._learning_rate, params, kw)
 
         # Segmentation information
         self._dtype_rank_params = OrderedDict(
@@ -83,8 +83,6 @@ class ShardingOptimizerStage2(Optimizer):
         # Default information
         self._optim_defaults = kw
         self._optim = optim
-        self._ori_parameter_list = self._optim._parameter_list
-        self._ori_param_groups = self._optim._param_groups
 
         assert hasattr(self._optim, "_master_weights"
                        ), "Must use optimizer with _master_weights attribute"
@@ -336,24 +334,11 @@ class ShardingOptimizerStage2(Optimizer):
 
         if self.offload:
             params_list = [self.offload_params.buffer]
-        else:
-            # Synchronize optimizer parameters for the current rank
-            params_list = []
-            for dtype in self.dtype_rank_params.keys():
-                params_list.extend(self.dtype_rank_params[dtype][self.rank])
 
-        params_name_list = list(map(lambda p: p.name, params_list))
-        if not isinstance(self._optim._param_groups[0], dict):
-            self._optim._parameter_list = params_list
-            self._optim._param_groups = params_list
-        else:
-            for param_group in self._optim._param_groups:
-                p_group = []
-                for param in param_group['params']:
-                    if param.name in params_name_list:
-                        p_group.append(params_list[params_name_list.index(
-                            param.name)])
-                param_group['params'] = p_group
+            #TODO(Baibaifan): Offload will support param_groups later
+            if not isinstance(self._optim._param_groups[0], dict):
+                self._optim._parameter_list = params_list
+                self._optim._param_groups = params_list
 
         # Run the optimizer of the current rank step
         if self.offload:
@@ -371,10 +356,6 @@ class ShardingOptimizerStage2(Optimizer):
         # Synchronize all the updated shards in between the ranks
         self._broadcast_params()
 
-        # Return full parameters to optimizer parameters
-        self._optim._parameter_list = self._ori_parameter_list
-        self._optim._param_groups = self._ori_param_groups
-
     def _clear_cache(self):
         self.__segment_params.clear()
         self._dtype_rank_params.clear()
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
index e08b4db1e98def0257e5123147a3fadc1130f444..9206d744990008496e7af43d67e000f9d00f6dab 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
@@ -29,7 +29,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import Shar
 
 seed = 2021
 epoch = 2
-batch_size = 32
 linear_size = 1000
 
 strategy = fleet.DistributedStrategy()
@@ -86,6 +85,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
 
 def train_mlp(model,
               sharding_stage,
+              batch_size=100,
               use_pure_fp16=False,
               accumulate_grad=False,
               opt_group=False):
@@ -103,16 +103,13 @@ def train_mlp(model,
     if sharding_stage == 2:
         optimizer = ShardingOptimizerStage2(
             params=model.parameters(), optim=optimizer, group=group)
-        if accumulate_grad:
-            model = ShardingStage2(
-                model,
-                optimizer,
-                group=group,
-                buffer_max_size=2**21,
-                accumulate_grads=accumulate_grad)
-        else:
-            model = ShardingStage2(
-                model, optimizer, group=group, buffer_max_size=2**21)
+
+        model = ShardingStage2(
+            model,
+            optimizer,
+            group=group,
+            buffer_max_size=2**21,
+            accumulate_grads=batch_size == 20)
     else:
         optimizer = fleet.distributed_optimizer(optimizer)
         model = fleet.distributed_model(model)
@@ -145,12 +142,13 @@ def train_mlp(model,
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
             avg_loss.backward()
 
+            if not accumulate_grad:
+                optimizer.step()
+                optimizer.clear_grad()
+
+        if accumulate_grad:
             optimizer.step()
             optimizer.clear_grad()
-
-        if accumulate_grad and batch_id == 2:
-            return model.parameters()
-
     return model.parameters()
 
 
@@ -166,25 +164,22 @@ def test_dp_stage2():
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
     dp_params = train_mlp(
-        mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=True)
+        mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
     stage2_params = train_mlp(
-        mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
+        mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=False)
     for i in range(len(dp_params)):
-        for j in range(len(stage2_params)):
-            if dp_params[i].name == stage2_params[j].name:
-                np.testing.assert_allclose(
-                    dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)
+        np.testing.assert_allclose(
+            dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
 
     stage2_params = train_mlp(mlp3, sharding_stage=2)
     stage2_accumulate_grad = train_mlp(
-        mlp4, sharding_stage=2, accumulate_grad=True)
+        mlp4, sharding_stage=2, batch_size=20, accumulate_grad=True)
     for i in range(len(stage2_params)):
-        for j in range(len(stage2_accumulate_grad)):
-            if stage2_params[i].name == stage2_accumulate_grad[j].name:
-                np.testing.assert_allclose(
-                    stage2_params[i].numpy(),
-                    stage2_accumulate_grad[j].numpy(),
-                    rtol=1e-6)
+        np.testing.assert_allclose(
+            stage2_params[i].numpy(),
+            stage2_accumulate_grad[i].numpy(),
+            rtol=1e-5,
+            atol=1e-5)
 
     return
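
Note (not part of the patch): the reworked train_mlp loop now steps the optimizer after every batch when accumulate_grad is False, and defers a single step until after the inner loop when it is True; the test drives the accumulation path with batch_size=20, which also switches ShardingStage2 to accumulate_grads=True. Below is a minimal sketch of that control flow under these assumptions; the run_epoch helper and the batches iterable are illustrative names only, not symbols from the patch.

    import paddle

    def run_epoch(model, optimizer, batches, accumulate_grad=False):
        # Forward/backward on every batch; the update policy depends on accumulate_grad.
        for img, label in batches:
            loss = paddle.nn.functional.cross_entropy(input=model(img), label=label)
            avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
            avg_loss.backward()
            if not accumulate_grad:
                optimizer.step()       # update after every batch
                optimizer.clear_grad()
        if accumulate_grad:
            optimizer.step()           # one update from the accumulated gradients
            optimizer.clear_grad()
        return model.parameters()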