Commit 4e95c136 authored by Megvii Engine Team

feat(sgd): sgd supports nesterov momentum

GitOrigin-RevId: 13eda179da9b79573f692916a02b2a51a4449a14
Parent ff431e72
@@ -16,7 +16,7 @@ from .optimizer import Optimizer
 class SGD(Optimizer):
     r"""Implements stochastic gradient descent.

     Nesterov momentum is based on the formula from
     `"On the importance of initialization and momentum in deep learning" <http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf>`_ .
@@ -25,6 +25,7 @@ class SGD(Optimizer):
             parameter groups.
         lr: learning rate.
         momentum: momentum factor. Default: 0.0
+        nesterov: enables Nesterov momentum. Default: False
         weight_decay: weight decay (L2 penalty). Default: 0.0
     """
@@ -33,6 +34,7 @@ class SGD(Optimizer):
         params: Union[Iterable[Parameter], dict],
         lr: float,
         momentum: float = 0.0,
+        nesterov: bool = False,
         weight_decay: float = 0.0,
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
@@ -40,9 +42,11 @@ class SGD(Optimizer):
         assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
             weight_decay
         )
+        assert not nesterov or momentum > 0.0, "Nesterov momentum requires a momentum"

         defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
         super().__init__(params, defaults)
+        self.nesterov = nesterov
         self._disable_type_convert = True

     def _create_state(self, param_group):
@@ -76,20 +80,22 @@
                 grad = grad + param * _weight_decay

             if inplace_mode:
-                if momentum:
+                if momentum != 0.0:
                     v = self._state[param]["momentum_buffer"]
                     _inplace_add_(v, grad, alpha=_momentum, beta=c1)
-                    _inplace_add_(param, v, alpha=c1, beta=_neg_lr)
-                else:
-                    _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
+                    if self.nesterov:
+                        grad = grad + v * _momentum
+                    else:
+                        grad = v
+                _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
                 continue

-            if momentum:
+            if momentum != 0.0:
                 v = self._state[param]["momentum_buffer"]
+                # v = v * _momentum + grad
                 v *= _momentum
                 v += grad
-                param -= _lr * v
-            else:
-                param -= _lr * grad
+                if self.nesterov:
+                    grad = grad + v * _momentum
+                else:
+                    grad = v
+            param -= _lr * grad
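
Read as plain Python, the hunk above implements the Nesterov variant from the paper cited in the docstring: the buffer is updated as v = momentum * v + grad, and the parameter step then uses the look-ahead gradient grad + momentum * v when nesterov is enabled, or v itself otherwise. The following minimal NumPy sketch restates one step of that arithmetic (weight decay assumed already folded into grad); sgd_step and the toy arrays are illustrative only, not MegEngine API.

import numpy as np

def sgd_step(param, grad, v, lr, momentum=0.0, nesterov=False):
    """Return (updated param, updated momentum buffer) for one SGD step."""
    if momentum != 0.0:
        v = v * momentum + grad           # momentum buffer update
        if nesterov:
            grad = grad + v * momentum    # look-ahead (Nesterov) gradient
        else:
            grad = v                      # classic heavy-ball direction
    return param - lr * grad, v

param, v = np.ones(3), np.zeros(3)
param, v = sgd_step(param, np.full(3, 0.5), v, lr=0.01, momentum=0.9, nesterov=True)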
@@ -124,6 +124,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     "case",
     [
         {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
+        {"momentum": 0.9, "lr": 0.01, "nesterov": True},  # with nesterov momentum
         {"lr": 0.01},  # simple SGD
         {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
     ],
@@ -144,9 +145,12 @@ def test_sgd(monkeypatch, case, update_lr, inplace_mode):
             grad = param.grad.numpy()
             if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
                 grad = grad + ori_params[param] * self.weight_decay
-            if hasattr(self, "momentum"):
+            if hasattr(self, "momentum") and self.momentum != 0.0:
                 self.slots[param] = grad + self.slots[param] * self.momentum
-                delta = -self.lr * self.slots[param]
+                if hasattr(self, "nesterov") and self.nesterov:
+                    delta = -self.lr * (grad + self.slots[param] * self.momentum)
+                else:
+                    delta = -self.lr * self.slots[param]
             else:
                 delta = -self.lr * grad
             np.testing.assert_almost_equal(
...
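
The new flag plugs into the existing constructor signature; a minimal usage sketch, assuming only MegEngine's public API (the Linear layer is an arbitrary stand-in for any module with trainable parameters):

import megengine.module as M
import megengine.optimizer as optim

net = M.Linear(4, 2)  # any module with trainable parameters
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, nesterov=True)

opt.clear_grad()
# ... forward pass and backward pass via megengine.autodiff.GradManager ...
opt.step()  # applies the Nesterov-momentum update from this commit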