From 06a3e3114899e4d6a5c621d34d38c401e071d1f0 Mon Sep 17 00:00:00 2001
From: 123malin
Date: Mon, 25 Jan 2021 14:17:58 +0800
Subject: [PATCH] test=develop, fix test_lookahead (#30677)

* test=develop, fix test_lookahead
---
 python/paddle/incubate/optimizer/lookahead.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/paddle/incubate/optimizer/lookahead.py b/python/paddle/incubate/optimizer/lookahead.py
index 3dca25c2bfb..f90d520a5df 100644
--- a/python/paddle/incubate/optimizer/lookahead.py
+++ b/python/paddle/incubate/optimizer/lookahead.py
@@ -171,6 +171,7 @@ class LookAhead(Optimizer):
         """
         self.inner_optimizer.step()
+        self._increment_global_var()
         params_grads = []
         for param in self._parameter_list:
             if not param.trainable:
@@ -188,7 +189,7 @@ class LookAhead(Optimizer):
         for p in parameters:
             self._add_accumulator(self._slow_str, p)
 
-    def _append_optimize_op(self, block, param_and_grad):
+    def _increment_global_var(self):
         if self._global_step_var is None:
             self._global_step_var = layers.create_global_var(
                 name=unique_name.generate("lookahead_step"),
@@ -203,6 +204,7 @@ class LookAhead(Optimizer):
                 outputs={'Out': [self._global_step_var]},
                 attrs={'step': 1.0})
 
+    def _append_optimize_op(self, block, param_and_grad):
         one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones')
         zero_var = paddle.zeros(
             shape=[1], dtype='int32', name='lookahead_zeros')
@@ -290,6 +292,8 @@ class LookAhead(Optimizer):
             startup_program=startup_program,
             parameters=parameters,
             no_grad_set=no_grad_set)
 
+        self._increment_global_var()
+
         _ = self._apply_optimize(
             loss, startup_program=startup_program, params_grads=params_grads)
--
GitLab
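
Note, not part of the patch: the change above moves the lookahead step-counter update out of _append_optimize_op (which runs once per parameter) into a dedicated _increment_global_var(), which step() and minimize() now each call once per update. The minimal dygraph sketch below exercises both patched entry points; it assumes only the public paddle.incubate.optimizer.LookAhead API from this file, and the layer sizes, data, and hyperparameters are arbitrary toy values, not taken from the patch or its tests.

import numpy as np
import paddle
from paddle.incubate.optimizer import LookAhead

# Toy model and inner optimizer; all sizes/hyperparameters are illustrative.
linear = paddle.nn.Linear(10, 1)
inner = paddle.optimizer.SGD(learning_rate=0.1,
                             parameters=linear.parameters())
# Slow weights sync every k=5 fast steps, interpolated with alpha=0.5.
lookahead = LookAhead(inner, alpha=0.5, k=5)

x = paddle.to_tensor(np.random.rand(4, 10).astype("float32"))

# Entry point 1: step(), patched at @@ -171,6 +171,7 @@ to bump the counter.
loss = paddle.mean(linear(x))
loss.backward()
lookahead.step()
lookahead.clear_grad()

# Entry point 2: minimize(), patched at @@ -290,6 +292,8 @@ likewise.
loss = paddle.mean(linear(x))
loss.backward()
lookahead.minimize(loss)
lookahead.clear_grad()

Because the counter now advances once per step() or minimize() call rather than once per parameter, the k-step synchronization of the slow weights stays correct for models with more than one trainable parameter; that appears to be the intent of relocating the increment.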