未验证 提交 06a3e311 编写于 作者: 1 123malin 提交者: GitHub

test=develop, fix test_lookahead (#30677)

* test=develop, fix test_lookahead
上级 846ce406
...@@ -171,6 +171,7 @@ class LookAhead(Optimizer): ...@@ -171,6 +171,7 @@ class LookAhead(Optimizer):
""" """
self.inner_optimizer.step() self.inner_optimizer.step()
self._increment_global_var()
params_grads = [] params_grads = []
for param in self._parameter_list: for param in self._parameter_list:
if not param.trainable: if not param.trainable:
...@@ -188,7 +189,7 @@ class LookAhead(Optimizer): ...@@ -188,7 +189,7 @@ class LookAhead(Optimizer):
for p in parameters: for p in parameters:
self._add_accumulator(self._slow_str, p) self._add_accumulator(self._slow_str, p)
def _append_optimize_op(self, block, param_and_grad): def _increment_global_var(self):
if self._global_step_var is None: if self._global_step_var is None:
self._global_step_var = layers.create_global_var( self._global_step_var = layers.create_global_var(
name=unique_name.generate("lookahead_step"), name=unique_name.generate("lookahead_step"),
...@@ -203,6 +204,7 @@ class LookAhead(Optimizer): ...@@ -203,6 +204,7 @@ class LookAhead(Optimizer):
outputs={'Out': [self._global_step_var]}, outputs={'Out': [self._global_step_var]},
attrs={'step': 1.0}) attrs={'step': 1.0})
def _append_optimize_op(self, block, param_and_grad):
one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones') one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones')
zero_var = paddle.zeros( zero_var = paddle.zeros(
shape=[1], dtype='int32', name='lookahead_zeros') shape=[1], dtype='int32', name='lookahead_zeros')
...@@ -290,6 +292,8 @@ class LookAhead(Optimizer): ...@@ -290,6 +292,8 @@ class LookAhead(Optimizer):
parameters=parameters, parameters=parameters,
no_grad_set=no_grad_set) no_grad_set=no_grad_set)
self._increment_global_var()
_ = self._apply_optimize( _ = self._apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads) loss, startup_program=startup_program, params_grads=params_grads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册