From 4e95c13617be97baab52cef777f96e304846de99 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 2 Sep 2021 19:35:58 +0800
Subject: [PATCH] feat(sgd): sgd supports nesterov momentum

GitOrigin-RevId: 13eda179da9b79573f692916a02b2a51a4449a14
---
 imperative/python/megengine/optimizer/sgd.py  | 28 +++++++++++--------
 .../python/test/integration/test_optimizer.py |  8 ++++--
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py
index fe5efda23..a4d122820 100644
--- a/imperative/python/megengine/optimizer/sgd.py
+++ b/imperative/python/megengine/optimizer/sgd.py
@@ -16,7 +16,7 @@ from .optimizer import Optimizer
 
 class SGD(Optimizer):
     r"""Implements stochastic gradient descent.
-    
+
     Nesterov momentum is based on the formula from
     `"On the importance of initialization and momentum in deep learning" `_ .
 
@@ -25,6 +25,7 @@ class SGD(Optimizer):
             parameter groups.
         lr: learning rate.
         momentum: momentum factor. Default: 0.0
+        nesterov: enables Nesterov momentum. Default: False
        weight_decay: weight decay (L2 penalty). Default: 0.0
     """
 
@@ -33,6 +34,7 @@ class SGD(Optimizer):
         params: Union[Iterable[Parameter], dict],
         lr: float,
         momentum: float = 0.0,
+        nesterov: bool = False,
         weight_decay: float = 0.0,
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
@@ -40,9 +42,11 @@ class SGD(Optimizer):
         assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
             weight_decay
         )
+        assert not nesterov or momentum > 0.0, "Nesterov momentum requires a momentum"
 
         defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
         super().__init__(params, defaults)
+        self.nesterov = nesterov
         self._disable_type_convert = True
 
     def _create_state(self, param_group):
@@ -76,20 +80,22 @@ class SGD(Optimizer):
                 grad = grad + param * _weight_decay
 
             if inplace_mode:
-                if momentum:
+                if momentum != 0.0:
                     v = self._state[param]["momentum_buffer"]
                     _inplace_add_(v, grad, alpha=_momentum, beta=c1)
-                    _inplace_add_(param, v, alpha=c1, beta=_neg_lr)
-                else:
-                    _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
+                    if self.nesterov:
+                        grad = grad + v * _momentum
+                    else:
+                        grad = v
+                _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
                 continue
 
-            if momentum:
+            if momentum != 0.0:
                 v = self._state[param]["momentum_buffer"]
-                # v = v * _momentum + grad
                 v *= _momentum
                 v += grad
-
-                param -= _lr * v
-            else:
-                param -= _lr * grad
+                if self.nesterov:
+                    grad = grad + v * _momentum
+                else:
+                    grad = v
+            param -= _lr * grad
diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py
index d5ca99a9b..577c09f73 100644
--- a/imperative/python/test/integration/test_optimizer.py
+++ b/imperative/python/test/integration/test_optimizer.py
@@ -124,6 +124,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     "case",
     [
         {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
+        {"momentum": 0.9, "lr": 0.01, "nesterov": True},  # with nesterov momentum
         {"lr": 0.01},  # simple SGD
         {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
     ],
@@ -144,9 +145,12 @@ def test_sgd(monkeypatch, case, update_lr, inplace_mode):
                 grad = param.grad.numpy()
                 if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
                     grad = grad + ori_params[param] * self.weight_decay
-                if hasattr(self, "momentum"):
+                if hasattr(self, "momentum") and self.momentum != 0.0:
                     self.slots[param] = grad + self.slots[param] * self.momentum
-                    delta = -self.lr * self.slots[param]
+                    if hasattr(self, "nesterov") and self.nesterov:
+                        delta = -self.lr * (grad + self.slots[param] * self.momentum)
+                    else:
+                        delta = -self.lr * self.slots[param]
                 else:
                     delta = -self.lr * grad
                 np.testing.assert_almost_equal(
-- 
GitLab
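
Note (not part of the patch): with this change applied, the option is simply passed to the
constructor, e.g. SGD(net.parameters(), lr=0.01, momentum=0.9, nesterov=True), assuming a
module `net`. The sketch below is a minimal NumPy reference of the update rule the patched
code implements; the helper name sgd_step and the sample numbers are illustrative only.

    # Reference sketch of the patched SGD update rule (not from the patch itself).
    import numpy as np

    def sgd_step(param, grad, buf, lr, momentum=0.0, nesterov=False, weight_decay=0.0):
        if weight_decay != 0.0:
            grad = grad + weight_decay * param      # L2 penalty folded into the gradient
        if momentum != 0.0:
            buf = momentum * buf + grad             # momentum_buffer update
            # Nesterov applies grad + momentum * buf; classical momentum applies buf.
            grad = grad + momentum * buf if nesterov else buf
        return param - lr * grad, buf

    p, g, buf = np.array([1.0]), np.array([0.5]), np.zeros(1)
    print(sgd_step(p, g, buf, lr=0.1, momentum=0.9)[0])                 # classical -> [0.95]
    print(sgd_step(p, g, buf, lr=0.1, momentum=0.9, nesterov=True)[0])  # nesterov  -> [0.905]

This mirrors what the test's CheckValue callback verifies: the Nesterov delta is
-lr * (grad + momentum * buffer), while the classical delta is -lr * buffer.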