Unverified commit 9be39bb4, authored by whs, committed by GitHub

Enhance optimizer. (#13004)

Parent: 7ad39c40
@@ -46,10 +46,12 @@ class Optimizer(object):
     def __init__(self,
                  learning_rate,
                  regularization=None,
-                 LARS_weight_decay=0.0):
+                 LARS_weight_decay=0.0,
+                 name=None):
         if not isinstance(learning_rate, float) and \
                 not isinstance(learning_rate, framework.Variable):
             raise TypeError("learning rate should be float or Variable")
+        self._name = name
         self.regularization = regularization
         self._learning_rate = learning_rate
         # the learning rate type should be inferred from loss
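The new `name` argument gives each optimizer instance a namespace of its own, so two optimizers in one program no longer collide on per-parameter state. A minimal usage sketch, assuming the optimizer subclasses forward extra keyword arguments (including `name`) to `Optimizer.__init__` — this hunk does not show the subclasses, and the optimizer names below are illustrative:

import paddle.fluid as fluid

# Two momentum optimizers over the same program; distinct `name` values
# keep their internal state (e.g. the "velocity" accumulators) apart.
opt_a = fluid.optimizer.MomentumOptimizer(
    learning_rate=0.01, momentum=0.9, name="opt_a")
opt_b = fluid.optimizer.MomentumOptimizer(
    learning_rate=0.01, momentum=0.9, name="opt_b")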
@@ -153,6 +155,8 @@ class Optimizer(object):
             dtype: data type of the accumulator variable
             fill_value: value to initialize the accumulator variable
         """
+        if self._name is not None:
+            name = self._name + "_" + name
         if (name in self._accumulators and
                 param.name in self._accumulators[name]):
             raise Exception("Accumulator {} already exists for parameter {}".
@@ -181,6 +185,8 @@ class Optimizer(object):
         Returns:
             accumulator variable for the parameter
         """
+        if self._name is not None:
+            name = self._name + "_" + name
         if (name not in self._accumulators or
                 param.name not in self._accumulators[name]):
             raise Exception("Accumulator {} does not exist for parameter {}".
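Taken together, the two hunks above prefix every accumulator key with the optimizer's name before it is stored or looked up. A standalone sketch of that bookkeeping, outside Paddle; the names "velocity" and "fc_0.w_0" are illustrative, not taken from the commit:

accumulators = {}

def add_accumulator(opt_name, acc_name, param_name):
    # Mirror of the commit's logic: prefix the accumulator key with the
    # optimizer name so two optimizers can each track state for one parameter.
    if opt_name is not None:
        acc_name = opt_name + "_" + acc_name
    if acc_name in accumulators and param_name in accumulators[acc_name]:
        raise Exception("Accumulator {} already exists for parameter {}"
                        .format(acc_name, param_name))
    accumulators.setdefault(acc_name, {})[param_name] = object()  # stand-in tensor

add_accumulator("opt_a", "velocity", "fc_0.w_0")
add_accumulator("opt_b", "velocity", "fc_0.w_0")  # no collision: keys differ
print(sorted(accumulators))  # ['opt_a_velocity', 'opt_b_velocity']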