未验证 提交 e97905c5 编写于 作者: Z Zhang Ting 提交者: GitHub

improve performance of momentum (#30881)

上级 4b2d52a0
...@@ -104,7 +104,7 @@ class Momentum(Optimizer):
raise ValueError("learning_rate is not set")
if momentum is None:
raise ValueError("momentum is not set")
predicate = lambda regular: isinstance(regular, (L2DecayRegularizer, float))
py_regular = None if predicate(weight_decay) else weight_decay
super(Momentum, self).__init__(
learning_rate=learning_rate,
...@@ -120,6 +120,9 @@ class Momentum(Optimizer):
if (isinstance(weight_decay, L2DecayRegularizer)):
self._regularization_method = "l2_decay"
self._regularization_coeff = weight_decay._regularization_coeff
if (isinstance(weight_decay, float)):
self._regularization_method = "l2_decay"
self._regularization_coeff = weight_decay
self._multi_precision = multi_precision
self._rescale_grad = rescale_grad
self._master_weights = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册