diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
index fc5b93c6e25499a0ae50c19cacae4a9395520fe9..ea17f96f7a1ca0f99bc93eb046baebd60458614a 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
@@ -70,7 +70,7 @@ class ShardingOptimizerStage2(Optimizer):
                  device="gpu",
                  **kw):
 
-        # super().__init__(optim._learning_rate, params, kw)
+        super().__init__(optim._learning_rate, params, kw)
 
         # Segmentation information
         self._dtype_rank_params = OrderedDict(
@@ -363,6 +363,10 @@ class ShardingOptimizerStage2(Optimizer):
         # Synchronize all the updated shards in between the ranks
         self._broadcast_params()
 
+    def minimize(self):
+        raise RuntimeError(
+            "optimizer.minimize() not support now, please use optimizer.step()")
+
     def _clear_cache(self):
         self.__segment_params.clear()
         self._dtype_rank_params.clear()
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 9d7bd937411882541d9cb1311c241d3e84316c90..484cd223949c6cb461ce760f6240300f46f06885 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -506,7 +506,13 @@ class ShardingStage3(nn.Layer):
             else:
                 opt_step()
 
+        def _opt_minimize(self):
+            raise RuntimeError(
+                "optimizer.minimize() not support now, please use optimizer.step()"
+            )
+
         self._optim.step = MethodType(_opt_step, self._optim)
+        self._optim.minimize = MethodType(_opt_minimize, self._optim)
 
     def _redefine_opt_clear(self):
         clear_func = self._clear_gradients
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py
index 6a9005b8ce6c1a430cc719bc962a4546c5b13c68..705831d50f171966a81635e97a149c6d9f4ba16d 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py
@@ -124,8 +124,17 @@ def train_mlp():
         avg_loss.backward()
 
         oss_optimizer.step()
-        # oss_optimizer clear cache
-        oss_optimizer._clear_cache()
+        # oss_optimizer clear cache
+        oss_optimizer._clear_cache()
+
+    # check optimizer.minimize() error
+    try:
+        oss_optimizer.minimize()
+    except:
+        print(
+            "====== Find sharding_stage2_optimizer.minimize() error ======"
+        )
+    return
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
index ddd31bc057f2e3f6eeeae571615f5e2991e6a8a2..9b218bf13027a0bac7e55e4b146351bafdedfb7a 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
@@ -83,7 +83,8 @@ def train_mlp(model,
               accumulate_grad=False,
               batch_size=100,
               opt_group=False,
-              recompute=False):
+              recompute=False,
+              test_minimize=False):
     group = paddle.distributed.new_group([0, 1])
     if opt_group:
         optimizer = optimizer_setting(
@@ -113,6 +114,15 @@ def train_mlp(model,
             accumulate_grads=batch_size == 20,
             sync_comm=recompute)
 
+    # check optimizer.minimize() error
+    if test_minimize:
+        try:
+            optimizer.minimize()
+        except:
+            print(
+                "====== Find sharding_stage3_optimizer.minimize() error ======")
+        return
+
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True)
 
@@ -160,8 +170,8 @@ def train_mlp(model,
 
 
 def test_stage2_stage3():
-    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = MLP(), MLP(), MLP(
-    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
+    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9 = MLP(), MLP(
+    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
     state_dict = mlp.state_dict()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
@@ -171,6 +181,8 @@ def test_stage2_stage3():
     mlp6.set_state_dict(state_dict)
     mlp7.set_state_dict(state_dict)
     mlp8.set_state_dict(state_dict)
+    mlp9.set_state_dict(state_dict)
+
     # fp32
     stage2_params = train_mlp(
         mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=False)
@@ -229,7 +241,14 @@ def test_stage2_stage3():
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(
             stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)
-    return
+
+    # check optimizer.minimize() error
+    train_mlp(
+        mlp9,
+        sharding_stage=3,
+        use_pure_fp16=False,
+        opt_group=False,
+        test_minimize=True)
 
 
 if __name__ == '__main__':