提交 48f3cbdf 编写于 作者: M minqiyang

Polish code

test=develop
上级 35c89f38
...@@ -350,7 +350,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs): ...@@ -350,7 +350,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
following cosine decay strategy. following cosine decay strategy.
decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1) decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)
Args: Args:
learning_rate(Variable|float): The initial learning rate. learning_rate(Variable|float): The initial learning rate.
step_each_epoch(int): the number of steps in an epoch. step_each_epoch(int): the number of steps in an epoch.
......
...@@ -94,13 +94,18 @@ class Optimizer(object): ...@@ -94,13 +94,18 @@ class Optimizer(object):
if imperative_base.enabled(): if imperative_base.enabled():
# create learning rate Variable # create learning rate Variable
if isinstance(self._learning_rate, float): if isinstance(self._learning_rate, float):
self._learning_rate_map[framework.default_main_program( lr = self._global_learning_rate()
)] = layers.create_global_var(
name=unique_name.generate("learning_rate"), if isinstance(lr, framework.Variable):
shape=[1], return
value=float(self._learning_rate), else:
dtype='float32' if self._dtype is None else self._dtype, self._learning_rate_map[framework.default_main_program(
persistable=True) )] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32' if self._dtype is None else self._dtype,
persistable=True)
# get learning rate Variable from LearningRateDecay # get learning rate Variable from LearningRateDecay
elif isinstance(self._learning_rate, LearningRateDecay): elif isinstance(self._learning_rate, LearningRateDecay):
self._learning_rate_map[framework.default_main_program( self._learning_rate_map[framework.default_main_program(
...@@ -114,11 +119,12 @@ class Optimizer(object): ...@@ -114,11 +119,12 @@ class Optimizer(object):
if isinstance(lr, framework.Variable): if isinstance(lr, framework.Variable):
return return
else:
if not isinstance(self._learning_rate, float): if not isinstance(self._learning_rate, float):
raise TypeError( raise TypeError(
"learning rate variable is create outside optimizer," "learning rate variable is create outside optimizer,"
"can not create new learning rate variable for new program") "can not create new learning rate variable for new program"
)
# create learning rate in the current main program # create learning rate in the current main program
self._learning_rate_map[framework.default_main_program( self._learning_rate_map[framework.default_main_program(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册