From 90f44c6f7f63dc72cebf377f7da00d2b00efef0b Mon Sep 17 00:00:00 2001
From: Baibaifan <39549453+Baibaifan@users.noreply.github.com>
Date: Fri, 28 Jan 2022 10:56:48 +0800
Subject: [PATCH] fix_stage2_minimize (#39285)

---
 .../sharding_optimizer_stage2.py              |  6 ++++-
 .../meta_parallel/sharding/sharding_stage3.py |  6 +++++
 .../dygraph_sharding_optimizer_stage2.py      | 13 +++++++--
 .../unittests/dygraph_sharding_stage3.py      | 27 ++++++++++++++++---
 4 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
index fc5b93c6e25..ea17f96f7a1 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
@@ -70,7 +70,7 @@ class ShardingOptimizerStage2(Optimizer):
                  device="gpu",
                  **kw):
 
-        # super().__init__(optim._learning_rate, params, kw)
+        super().__init__(optim._learning_rate, params, kw)
 
         # Segmentation information
         self._dtype_rank_params = OrderedDict(
@@ -363,6 +363,10 @@ class ShardingOptimizerStage2(Optimizer):
         # Synchronize all the updated shards in between the ranks
         self._broadcast_params()
 
+    def minimize(self):
+        raise RuntimeError(
+            "optimizer.minimize() not support now, please use optimizer.step()")
+
     def _clear_cache(self):
         self.__segment_params.clear()
         self._dtype_rank_params.clear()
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 9d7bd937411..484cd223949 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -506,7 +506,13 @@ class ShardingStage3(nn.Layer):
             else:
                 opt_step()
 
+        def _opt_minimize(self):
+            raise RuntimeError(
+                "optimizer.minimize() not support now, please use optimizer.step()"
+            )
+
         self._optim.step = MethodType(_opt_step, self._optim)
+        self._optim.minimize = MethodType(_opt_minimize, self._optim)
 
     def _redefine_opt_clear(self):
         clear_func = self._clear_gradients
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py
index 6a9005b8ce6..705831d50f1 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py
@@ -124,8 +124,17 @@ def train_mlp():
             avg_loss.backward()
 
             oss_optimizer.step()
-    # oss_optimizer clear cache
-    oss_optimizer._clear_cache()
+            # oss_optimizer clear cache
+            oss_optimizer._clear_cache()
+
+    # check optimizer.minimize() error
+    try:
+        oss_optimizer.minimize()
+    except:
+        print(
+            "====== Find sharding_stage2_optimizer.minimize() error ======"
+        )
+        return
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
index ddd31bc057f..9b218bf1302 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
@@ -83,7 +83,8 @@ def train_mlp(model,
               accumulate_grad=False,
               batch_size=100,
               opt_group=False,
-              recompute=False):
+              recompute=False,
+              test_minimize=False):
     group = paddle.distributed.new_group([0, 1])
     if opt_group:
         optimizer = optimizer_setting(
@@ -113,6 +114,15 @@ def train_mlp(model,
             accumulate_grads=batch_size == 20,
             sync_comm=recompute)
 
+    # check optimizer.minimize() error
+    if test_minimize:
+        try:
+            optimizer.minimize()
+        except:
+            print(
+                "====== Find sharding_stage3_optimizer.minimize() error ======")
+        return
+
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True)
 
@@ -160,8 +170,8 @@ def train_mlp(model,
 
 
 def test_stage2_stage3():
-    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = MLP(), MLP(), MLP(
-    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
+    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9 = MLP(), MLP(
+    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
     state_dict = mlp.state_dict()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
@@ -171,6 +181,8 @@ def test_stage2_stage3():
     mlp6.set_state_dict(state_dict)
     mlp7.set_state_dict(state_dict)
     mlp8.set_state_dict(state_dict)
+    mlp9.set_state_dict(state_dict)
+
     # fp32
     stage2_params = train_mlp(
         mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=False)
@@ -229,7 +241,14 @@ def test_stage2_stage3():
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(
             stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)
-    return
+
+    # check optimizer.minimize() error
+    train_mlp(
+        mlp9,
+        sharding_stage=3,
+        use_pure_fp16=False,
+        opt_group=False,
+        test_minimize=True)
 
 
 if __name__ == '__main__':
--
GitLab
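Note: the stage-2 side of this fix can simply define minimize() on its own class, because ShardingOptimizerStage2 subclasses Optimizer; ShardingStage3 instead wraps an arbitrary optimizer instance, so it rebinds minimize on that instance with types.MethodType, the same pattern the file already uses for step. Below is a minimal standalone sketch of that instance-rebinding pattern; DummyOptimizer and every other name in it are illustrative stand-ins, not code from the patch.

    from types import MethodType


    class DummyOptimizer:
        # Stand-in for the inner optimizer held by the sharding wrapper.
        def step(self):
            print("step() performs the (sharded) parameter update")


    def _opt_minimize(self):
        # Same guard the patch installs: sharded training only supports
        # step(), so minimize() should fail loudly instead of silently
        # running an unsharded update.
        raise RuntimeError(
            "optimizer.minimize() not support now, please use optimizer.step()")


    opt = DummyOptimizer()
    # Rebind minimize on this one instance, as ShardingStage3 does for
    # self._optim; other instances of the class are unaffected.
    opt.minimize = MethodType(_opt_minimize, opt)

    opt.step()  # still works
    try:
        opt.minimize()
    except RuntimeError as err:
        print(err)  # the guard fires

Rebinding per instance rather than patching the optimizer's class keeps the guard scoped to the optimizer actually managed by the sharding wrapper, which is why the tests above expect the RuntimeError only from the wrapped optimizer.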