From 07086832a33b10536ef1ea56e57585e7eec93e93 Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 7 Apr 2020 20:27:59 +0800 Subject: [PATCH] Fix problem use recompute and dgc same time, test=release/1.7 (#23022) --- python/paddle/fluid/optimizer.py | 3 +++ .../tests/unittests/test_dgc_optimizer.py | 23 ++++++++++++++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 842b75401b..aa9294b30c 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -3899,6 +3899,9 @@ class RecomputeOptimizer(Optimizer): parameter_list, no_grad_set, checkpoints=self._checkpoints) + # Note: since we can't use all_reduce_op now, + # dgc_op should be the last op of one grad. + self._optimizer._append_dgc_ops(params_grads) return params_grads def apply_optimize(self, loss, startup_program, params_grads): diff --git a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py index 521e498176..29050710c6 100644 --- a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py @@ -35,7 +35,8 @@ class TestDGCMomentumOptimizer(unittest.TestCase): def check_dgc_momentum_optimizer(self, dims=[5, 10, 8], name="momentum", - regularization=None): + regularization=None, + use_recompute=False): init_program = framework.Program() program = framework.Program() block = program.global_block() @@ -72,6 +73,13 @@ class TestDGCMomentumOptimizer(unittest.TestCase): local_grad_clip_norm=1.0, num_trainers=2, regularization=regularization) + + if use_recompute: + dgc_momentum_optimizer = optimizer.RecomputeOptimizer( + dgc_momentum_optimizer) + dgc_momentum_optimizer.get_accumulators = dgc_momentum_optimizer._optimizer.get_accumulators + dgc_momentum_optimizer.get_velocity_str = dgc_momentum_optimizer._optimizer.get_velocity_str + mean_out = block.create_var( dtype="float32", shape=[1], lod_level=0, name="mean.out") block.append_op( @@ -112,8 +120,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase): self.assertAlmostEqual(op.attr('regular_coeff'), coeff) print("dgc regular_coeff=" + str(coeff)) - with open("test_dgc_optimizer_" + name + ".log", "w") as f: - program_to_code(program, fout=f) + # for local test debug + #with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f: + # program_to_code(program, fout=f) def test_momentum_without_dgc(self): self.check_dgc_momentum_optimizer( @@ -130,6 +139,14 @@ class TestDGCMomentumOptimizer(unittest.TestCase): self.check_dgc_momentum_optimizer( dims=[16, 1024, 8], name="dgc_momentum") + def test_momentum_with_dgc_recompute(self): + # 16 * 1024 = 16384, use dgc momentum + self.check_dgc_momentum_optimizer( + dims=[16, 1024, 8], + name="dgc_momentum", + regularization=regularizer.L2Decay(1e-4), + use_recompute=True) + if __name__ == '__main__': unittest.main() -- GitLab