提交 de2d85b5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!199 Fix the bug of weight decay for `Adam` optimizer

Merge pull request !199 from seatea/adam-weight-decay
...@@ -166,7 +166,8 @@ class Adam(Optimizer): ...@@ -166,7 +166,8 @@ class Adam(Optimizer):
""" """
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
use_nesterov=False, weight_decay=0.0, loss_scale=1.0): use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(Adam, self).__init__(learning_rate, params) super(Adam, self).__init__(learning_rate, params)
_check_param_value(beta1, beta2, eps, weight_decay) _check_param_value(beta1, beta2, eps, weight_decay)
validator.check_type("use_locking", use_locking, [bool]) validator.check_type("use_locking", use_locking, [bool])
...@@ -192,6 +193,7 @@ class Adam(Optimizer): ...@@ -192,6 +193,7 @@ class Adam(Optimizer):
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov) self.opt = P.Adam(use_locking, use_nesterov)
self.weight_decay = weight_decay * loss_scale self.weight_decay = weight_decay * loss_scale
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册