提交 48f3cbdf 编写于 作者: M minqiyang

Polish code

test=develop
上级 35c89f38
...@@ -94,6 +94,11 @@ class Optimizer(object): ...@@ -94,6 +94,11 @@ class Optimizer(object):
if imperative_base.enabled(): if imperative_base.enabled():
# create learning rate Variable # create learning rate Variable
if isinstance(self._learning_rate, float): if isinstance(self._learning_rate, float):
lr = self._global_learning_rate()
if isinstance(lr, framework.Variable):
return
else:
self._learning_rate_map[framework.default_main_program( self._learning_rate_map[framework.default_main_program(
)] = layers.create_global_var( )] = layers.create_global_var(
name=unique_name.generate("learning_rate"), name=unique_name.generate("learning_rate"),
...@@ -114,11 +119,12 @@ class Optimizer(object): ...@@ -114,11 +119,12 @@ class Optimizer(object):
if isinstance(lr, framework.Variable): if isinstance(lr, framework.Variable):
return return
else:
if not isinstance(self._learning_rate, float): if not isinstance(self._learning_rate, float):
raise TypeError( raise TypeError(
"learning rate variable is create outside optimizer," "learning rate variable is create outside optimizer,"
"can not create new learning rate variable for new program") "can not create new learning rate variable for new program"
)
# create learning rate in the current main program # create learning rate in the current main program
self._learning_rate_map[framework.default_main_program( self._learning_rate_map[framework.default_main_program(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册