提交 eef77fdd 编写于 作者: T tangwei12

lookup table bug fix about lr, test=develop

上级 fa2ab334
...@@ -1522,13 +1522,17 @@ class Program(object): ...@@ -1522,13 +1522,17 @@ class Program(object):
>>> with program.lr_schedule_guard(): >>> with program.lr_schedule_guard():
>>> lr = lr * decay >>> lr = lr * decay
""" """
tmp_role = self._current_role
tmp_var = self._op_role_var
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.LRSched self._current_role = OpRole.LRSched
# TODO(typhoonzero): how to set target learning rate var # TODO(typhoonzero): how to set target learning rate var
self._op_role_var = [] self._op_role_var = []
yield yield
self._op_role_var = [] self._op_role_var = tmp_var
self._current_role = OpRole.Forward self._current_role = tmp_role
def __str__(self): def __str__(self):
""" """
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from __future__ import print_function from __future__ import print_function
import re import re
from collections import defaultdict from collections import defaultdict
from paddle.fluid.framework import Program, Variable, name_scope from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
from . import framework from . import framework
from . import layers from . import layers
from .backward import append_backward from .backward import append_backward
...@@ -111,6 +111,7 @@ class Optimizer(object): ...@@ -111,6 +111,7 @@ class Optimizer(object):
if param_lr == 1.0: if param_lr == 1.0:
return self._global_learning_rate() return self._global_learning_rate()
else: else:
with default_main_program()._lr_schedule_guard():
return self._global_learning_rate() * param_lr return self._global_learning_rate() * param_lr
def _create_accumulators(self, block, parameters): def _create_accumulators(self, block, parameters):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册