diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 96488bfc96a1d6fd8e65aa658dfd998fa0dafb75..2b7275b180f7ca1ff13736ff55b581d23f9c278b 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -3926,6 +3926,9 @@ class RecomputeOptimizer(Optimizer): parameter_list, no_grad_set, checkpoints=self._checkpoints) + # Note: since we can't use all_reduce_op now, + # dgc_op should be the last op of one grad. + self._optimizer._append_dgc_ops(params_grads) return params_grads def apply_optimize(self, loss, startup_program, params_grads): diff --git a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py index 521e4981769b9a1d2b77a196ad0b9db7f8a0c5d8..29050710c62fa9af08f7e5e0a2a18588a241cdba 100644 --- a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py @@ -35,7 +35,8 @@ class TestDGCMomentumOptimizer(unittest.TestCase): def check_dgc_momentum_optimizer(self, dims=[5, 10, 8], name="momentum", - regularization=None): + regularization=None, + use_recompute=False): init_program = framework.Program() program = framework.Program() block = program.global_block() @@ -72,6 +73,13 @@ class TestDGCMomentumOptimizer(unittest.TestCase): local_grad_clip_norm=1.0, num_trainers=2, regularization=regularization) + + if use_recompute: + dgc_momentum_optimizer = optimizer.RecomputeOptimizer( + dgc_momentum_optimizer) + dgc_momentum_optimizer.get_accumulators = dgc_momentum_optimizer._optimizer.get_accumulators + dgc_momentum_optimizer.get_velocity_str = dgc_momentum_optimizer._optimizer.get_velocity_str + mean_out = block.create_var( dtype="float32", shape=[1], lod_level=0, name="mean.out") block.append_op( @@ -112,8 +120,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase): self.assertAlmostEqual(op.attr('regular_coeff'), coeff) print("dgc regular_coeff=" + str(coeff)) - with open("test_dgc_optimizer_" + name + ".log", "w") as f: - program_to_code(program, fout=f) + # for local test debug + #with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f: + # program_to_code(program, fout=f) def test_momentum_without_dgc(self): self.check_dgc_momentum_optimizer( @@ -130,6 +139,14 @@ class TestDGCMomentumOptimizer(unittest.TestCase): self.check_dgc_momentum_optimizer( dims=[16, 1024, 8], name="dgc_momentum") + def test_momentum_with_dgc_recompute(self): + # 16 * 1024 = 16384, use dgc momentum + self.check_dgc_momentum_optimizer( + dims=[16, 1024, 8], + name="dgc_momentum", + regularization=regularizer.L2Decay(1e-4), + use_recompute=True) + if __name__ == '__main__': unittest.main()