未验证 提交 06a3e311 编写于 作者: 1 123malin 提交者: GitHub

test=develop, fix test_lookahead (#30677)

* test=develop, fix test_lookahead
上级 846ce406
......@@ -171,6 +171,7 @@ class LookAhead(Optimizer):
"""
self.inner_optimizer.step()
self._increment_global_var()
params_grads = []
for param in self._parameter_list:
if not param.trainable:
......@@ -188,7 +189,7 @@ class LookAhead(Optimizer):
for p in parameters:
self._add_accumulator(self._slow_str, p)
def _append_optimize_op(self, block, param_and_grad):
def _increment_global_var(self):
if self._global_step_var is None:
self._global_step_var = layers.create_global_var(
name=unique_name.generate("lookahead_step"),
......@@ -203,6 +204,7 @@ class LookAhead(Optimizer):
outputs={'Out': [self._global_step_var]},
attrs={'step': 1.0})
def _append_optimize_op(self, block, param_and_grad):
one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones')
zero_var = paddle.zeros(
shape=[1], dtype='int32', name='lookahead_zeros')
......@@ -290,6 +292,8 @@ class LookAhead(Optimizer):
parameters=parameters,
no_grad_set=no_grad_set)
self._increment_global_var()
_ = self._apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册