提交 ea9e62b8 编写于 作者: Q qiaolongfei

optimize code

上级 a636aa58
......@@ -38,13 +38,13 @@ class Optimizer(object):
def __init__(self, learning_rate, global_step=None, regularization=None):
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
raise ValueError("learning rate should be float or Variable")
raise TypeError("learning rate should be float or Variable")
self._global_step = global_step
self.regularization = regularization
self._learning_rate = learning_rate
# each program should have a independent learning rate
# program -> Variable(learning_rate)
self._learning_rate_map = defaultdict(lambda: None)
self._learning_rate_map = dict()
if isinstance(self._learning_rate, framework.Variable):
self._learning_rate_map[framework.default_main_program(
)] = self._learning_rate
......@@ -62,7 +62,7 @@ class Optimizer(object):
return
else:
if not isinstance(self._learning_rate, float):
raise ValueError(
raise TypeError(
"learning rate variable is create outside optimizer,"
"can not create new learning rate variable for new program")
......@@ -82,7 +82,7 @@ class Optimizer(object):
"""
if program is None:
program = framework.default_main_program()
return self._learning_rate_map[program]
return self._learning_rate_map.get(program, None)
def _append_optimize_op(self, block, param_and_grad):
""" append optimize operator to block and return all the added optimize_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册