提交 28ff1cda 编写于 作者: Q qiaolongfei

create learning rate for each program

上级 50a6e7c5
...@@ -36,10 +36,15 @@ class Optimizer(object): ...@@ -36,10 +36,15 @@ class Optimizer(object):
""" """
def __init__(self, learning_rate, global_step=None, regularization=None): def __init__(self, learning_rate, global_step=None, regularization=None):
assert learning_rate is not None if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
raise ValueError("learning rate should be float or Variable")
self._global_step = global_step self._global_step = global_step
self.regularization = regularization self.regularization = regularization
self._global_learning_rate = learning_rate self._learning_rate = learning_rate
# each program should have a independent learning rate
# program -> Variable(learning_rate)
self._learning_rate_map = defaultdict(lambda: None)
# Dictionary of accumulators. Some optimizer subclasses need to # Dictionary of accumulators. Some optimizer subclasses need to
# allocate and manage extra variables associated with the parameters # allocate and manage extra variables associated with the parameters
# to train. These variables are called accumulators. # to train. These variables are called accumulators.
...@@ -48,26 +53,33 @@ class Optimizer(object): ...@@ -48,26 +53,33 @@ class Optimizer(object):
self.helper = None self.helper = None
def _create_global_learning_rate(self): def _create_global_learning_rate(self):
if isinstance(self._global_learning_rate, float): lr = self.global_learning_rate()
self._global_learning_rate = layers.create_global_var(
name=unique_name.generate("learning_rate"), if isinstance(lr, framework.Variable):
shape=[1], return
value=float(self._global_learning_rate), else:
dtype='float32', if not isinstance(self._learning_rate, float):
persistable=True) raise ValueError(
"learning rate variable is create outside optimizer,"
if not isinstance(self._global_learning_rate, framework.Variable): "can not create new learning rate variable for new program")
raise ValueError("learning rate should be a Variable, "
"actual type is %s", # create learning rate in the current main program
type(self._global_learning_rate)) self._learning_rate_map[framework.default_main_program(
)] = layers.create_global_var(
@property name=unique_name.generate("learning_rate"),
def global_learning_rate(self): shape=[1],
value=float(self._learning_rate),
dtype='float32',
persistable=True)
def global_learning_rate(self, program=None):
""" """
get global decayed learning rate get global decayed learning rate
:return: :return:
""" """
return self._global_learning_rate if program is None:
program = framework.default_main_program()
return self._learning_rate_map[program]
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
""" append optimize operator to block and return all the added optimize_op """ append optimize operator to block and return all the added optimize_op
...@@ -78,7 +90,7 @@ class Optimizer(object): ...@@ -78,7 +90,7 @@ class Optimizer(object):
# create learning rate variable for every parameter # create learning rate variable for every parameter
param = param_and_grad[0] param = param_and_grad[0]
param_lr = param.optimize_attr['learning_rate'] param_lr = param.optimize_attr['learning_rate']
return self._global_learning_rate * param_lr return self.global_learning_rate() * param_lr
def _create_accumulators(self, block, parameters): def _create_accumulators(self, block, parameters):
"""Create all accumulators needed by the parameters """Create all accumulators needed by the parameters
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册