From 3adad4852dae1b1a62da8e94599f11b019e4495b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 18 May 2020 14:37:36 +0800
Subject: [PATCH] feat(mge/optimizer): add optimizer adagrad

GitOrigin-RevId: 60ff08c5c334c037695be9bda46037b58519afcd
---
 python_module/megengine/optimizer/__init__.py |  1 +
 python_module/megengine/optimizer/adagrad.py  | 75 +++++++++++++++++++
 .../test/unit/optimizer/test_optimizer.py     | 71 ++++++++++++++++++
 3 files changed, 147 insertions(+)
 create mode 100644 python_module/megengine/optimizer/adagrad.py

diff --git a/python_module/megengine/optimizer/__init__.py b/python_module/megengine/optimizer/__init__.py
index 133eab29..328cfb9f 100644
--- a/python_module/megengine/optimizer/__init__.py
+++ b/python_module/megengine/optimizer/__init__.py
@@ -6,6 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .adagrad import Adagrad
 from .adam import Adam
 from .lr_scheduler import LRScheduler
 from .multi_step_lr import MultiStepLR
diff --git a/python_module/megengine/optimizer/adagrad.py b/python_module/megengine/optimizer/adagrad.py
new file mode 100644
index 00000000..4683fa10
--- /dev/null
+++ b/python_module/megengine/optimizer/adagrad.py
@@ -0,0 +1,75 @@
+from typing import Iterable, Union
+
+import numpy as np
+
+from ..core import Buffer, Parameter
+from ..functional import sqrt
+from .internal import add_update_fastpath as add_update
+from .optimizer import Optimizer
+
+
+class Adagrad(Optimizer):
+    r"""Implements Adagrad algorithm.
+
+    It has been proposed in `"Adaptive Subgradient Methods for Online Learning
+    and Stochastic Optimization" <http://jmlr.org/papers/v12/duchi11a.html>`_.
+
+    :param params: iterable of parameters to optimize or dicts defining
+        parameter groups.
+    :param lr: coefficient that scales delta before it is applied
+        to the parameters (default: 1e-2).
+    :param lr_decay: learning rate decay (default: 0).
+    :param eps: term added to the denominator to improve
+        numerical stability (default: 1e-10).
+    :param weight_decay: weight decay (L2 penalty) (default: 0).
+    """
+
+    def __init__(
+        self,
+        params: Union[Iterable[Parameter], dict],
+        lr: float = 1e-2,
+        lr_decay: float = 0.0,
+        eps: float = 1e-10,
+        weight_decay: float = 0.0,
+    ):
+        assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
+        assert lr_decay >= 0, "Invalid learning rate decay: {}".format(lr_decay)
+        assert eps >= 0.0, "Invalid epsilon value: {}".format(eps)
+        assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
+            weight_decay
+        )
+
+        defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _create_state(self, param_group):
+        for param in param_group["params"]:
+            self._add_state(param, "square_avg")
+            self._add_state(param, "step", initializer=0.0)
+
+    def _updates(self, param_group):
+        lr = param_group["lr"]
+        lr_decay = param_group["lr_decay"]
+        weight_decay = param_group["weight_decay"]
+        eps = param_group["eps"]
+
+        for param in param_group["params"]:
+            if not isinstance(param.grad, Buffer):
+                raise TypeError(
+                    "grad must be a Buffer, maybe you forgot to call backward()?"
+                )
+
+            if not param.requires_grad:
+                continue
+
+            step = self._state[param]["step"]
+            step = add_update(step, 1)
+            grad = param.grad
+            if weight_decay != 0.0:
+                grad = add_update(grad, param, beta=weight_decay)
+
+            square_avg = self._state[param]["square_avg"]
+            square_avg = add_update(square_avg, grad ** 2)
+            delta = grad / sqrt(square_avg + eps)
+            clr = lr / (1 + (step - 1) * lr_decay)
+            add_update(param, delta, beta=-clr)
diff --git a/python_module/test/unit/optimizer/test_optimizer.py b/python_module/test/unit/optimizer/test_optimizer.py
index 8f496c15..0d988d07 100644
--- a/python_module/test/unit/optimizer/test_optimizer.py
+++ b/python_module/test/unit/optimizer/test_optimizer.py
@@ -187,3 +187,74 @@ def test_adam():
     for case in cases:
         _test_optimizer("Adam", case, CheckValue)
         _test_optimizer("Adam", case, CheckValue, update_lr=True)
+
+
+def test_adam():
+    class CheckValue:
+        def __init__(self, net, **kwarg):
+            self.m_slots = TensorDict()
+            self.v_slots = TensorDict()
+            for param in net.parameters():
+                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
+                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
+            for k, v in kwarg.items():
+                setattr(self, k, v)
+
+        def __call__(self, ori_params, new_params, step):
+            for param in new_params:
+                grad = param.grad.numpy()
+                m = self.m_slots[param]
+                v = self.v_slots[param]
+                m *= self.betas[0]
+                m += (1 - self.betas[0]) * grad
+                v *= self.betas[1]
+                v += (1 - self.betas[1]) * grad * grad
+                delta = (m / (1 - self.betas[0] ** step)) / (
+                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
+                )
+                assertTensorClose(param.numpy(), ori_params[param] - self.lr * delta)
+
+    cases = [
+        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
+        {
+            "betas": (0.8, 0.9),
+            "eps": 1e-04,
+            "lr": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ]
+    for case in cases:
+        _test_optimizer("Adam", case, CheckValue)
+        _test_optimizer("Adam", case, CheckValue, update_lr=True)
+
+
+def test_adagrad():
+    class CheckValue:
+        def __init__(self, net, **kwarg):
+            self.s_slots = TensorDict()
+            for param in net.parameters():
+                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
+            for k, v in kwarg.items():
+                setattr(self, k, v)
+
+        def __call__(self, ori_params, new_params, step):
+            for param in new_params:
+                grad = param.grad.numpy()
+                self.s_slots[param] += grad ** 2
+                delta = grad / (self.s_slots[param] + self.eps) ** 0.5
+                delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
+                assertTensorClose(param.numpy(), ori_params[param] + delta)
+
+    cases = [
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
+        {
+            "lr": 0.01,
+            "eps": 1e-06,
+            "lr_decay": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ]
+    for case in cases:
+        _test_optimizer("Adagrad", case, CheckValue)
+        _test_optimizer("Adagrad", case, CheckValue, update_lr=True)
--
GitLab
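
Note: the update rule implemented in Adagrad._updates (and re-derived by CheckValue in test_adagrad) can be sanity-checked outside MegEngine with plain NumPy. The sketch below is illustrative only; the function name adagrad_step and the toy loop are not part of the patch, and it assumes the same accumulation and lr-decay schedule as the code above.

import numpy as np

def adagrad_step(param, grad, square_avg, step,
                 lr=1e-2, lr_decay=0.0, eps=1e-10, weight_decay=0.0):
    # One Adagrad update on plain NumPy arrays, mirroring Adagrad._updates:
    # accumulate squared gradients, scale the step per element, and apply
    # the decayed learning rate clr = lr / (1 + (step - 1) * lr_decay).
    step = step + 1
    if weight_decay != 0.0:
        grad = grad + weight_decay * param       # L2 penalty folded into the gradient
    square_avg = square_avg + grad ** 2          # running sum of squared gradients
    delta = grad / np.sqrt(square_avg + eps)     # per-element adaptive step
    clr = lr / (1 + (step - 1) * lr_decay)       # decayed learning rate
    return param - clr * delta, square_avg, step

# Toy check: minimizing sum(p ** 2), so grad = 2 * p; values shrink toward zero.
p = np.array([1.0, -2.0, 3.0], dtype=np.float32)
s = np.zeros_like(p)
t = 0.0
for _ in range(5):
    p, s, t = adagrad_step(p, 2.0 * p, s, t, lr=0.1, lr_decay=0.01)
print(p)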