提交 ea9e62b8 编写于 作者: Q qiaolongfei

optimize code

上级 a636aa58
...@@ -38,13 +38,13 @@ class Optimizer(object): ...@@ -38,13 +38,13 @@ class Optimizer(object):
def __init__(self, learning_rate, global_step=None, regularization=None): def __init__(self, learning_rate, global_step=None, regularization=None):
if not isinstance(learning_rate, float) and \ if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable): not isinstance(learning_rate, framework.Variable):
raise ValueError("learning rate should be float or Variable") raise TypeError("learning rate should be float or Variable")
self._global_step = global_step self._global_step = global_step
self.regularization = regularization self.regularization = regularization
self._learning_rate = learning_rate self._learning_rate = learning_rate
# each program should have a independent learning rate # each program should have a independent learning rate
# program -> Variable(learning_rate) # program -> Variable(learning_rate)
self._learning_rate_map = defaultdict(lambda: None) self._learning_rate_map = dict()
if isinstance(self._learning_rate, framework.Variable): if isinstance(self._learning_rate, framework.Variable):
self._learning_rate_map[framework.default_main_program( self._learning_rate_map[framework.default_main_program(
)] = self._learning_rate )] = self._learning_rate
...@@ -62,7 +62,7 @@ class Optimizer(object): ...@@ -62,7 +62,7 @@ class Optimizer(object):
return return
else: else:
if not isinstance(self._learning_rate, float): if not isinstance(self._learning_rate, float):
raise ValueError( raise TypeError(
"learning rate variable is create outside optimizer," "learning rate variable is create outside optimizer,"
"can not create new learning rate variable for new program") "can not create new learning rate variable for new program")
...@@ -82,7 +82,7 @@ class Optimizer(object): ...@@ -82,7 +82,7 @@ class Optimizer(object):
""" """
if program is None: if program is None:
program = framework.default_main_program() program = framework.default_main_program()
return self._learning_rate_map[program] return self._learning_rate_map.get(program, None)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
""" append optimize operator to block and return all the added optimize_op """ append optimize operator to block and return all the added optimize_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册