Commit 3adad485 authored by Megvii Engine Team

feat(mge/optimizer): add optimizer adagrad

GitOrigin-RevId: 60ff08c5c334c037695be9bda46037b58519afcd
Parent 8be78b11
@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .adagrad import Adagrad
from .adam import Adam
from .lr_scheduler import LRScheduler
from .multi_step_lr import MultiStepLR
......
from typing import Iterable, Union

import numpy as np

from ..core import Buffer, Parameter
from ..functional import sqrt
from .internal import add_update_fastpath as add_update
from .optimizer import Optimizer


class Adagrad(Optimizer):
    r"""Implements Adagrad algorithm.

    It has been proposed in `"Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization" <http://jmlr.org/papers/v12/duchi11a.html>`_.

    :param params: iterable of parameters to optimize or dicts defining
        parameter groups.
    :param lr: coefficient that scales delta before it is applied
        to the parameters (default: 1e-2).
    :param lr_decay: learning rate decay (default: 0).
    :param eps: term added to the denominator to improve
        numerical stability (default: 1e-10).
    :param weight_decay: weight decay (L2 penalty) (default: 0).
    """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float = 1e-2,
        lr_decay: float = 0.0,
        eps: float = 1e-10,
        weight_decay: float = 0.0,
    ):
        assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
        assert lr_decay >= 0, "Invalid learning rate decay: {}".format(lr_decay)
        assert eps >= 0.0, "Invalid epsilon value: {}".format(eps)
        assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
            weight_decay
        )

        defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def _create_state(self, param_group):
        for param in param_group["params"]:
            self._add_state(param, "square_avg")
            self._add_state(param, "step", initializer=0.0)

    def _updates(self, param_group):
        lr = param_group["lr"]
        lr_decay = param_group["lr_decay"]
        weight_decay = param_group["weight_decay"]
        eps = param_group["eps"]

        for param in param_group["params"]:
            if not isinstance(param.grad, Buffer):
                raise TypeError(
                    "grad must be a Buffer, maybe you forgot to call backward()?"
                )

            if not param.requires_grad:
                continue

            # count how many updates this parameter has received
            step = self._state[param]["step"]
            step = add_update(step, 1)

            grad = param.grad
            if weight_decay != 0.0:
                # fold the L2 penalty into the gradient: grad += weight_decay * param
                grad = add_update(grad, param, beta=weight_decay)

            # accumulate the running sum of squared gradients
            square_avg = self._state[param]["square_avg"]
            square_avg = add_update(square_avg, grad ** 2)
            delta = grad / sqrt(square_avg + eps)

            # learning rate decayed by the number of completed steps
            clr = lr / (1 + (step - 1) * lr_decay)

            # param -= clr * delta
            add_update(param, delta, beta=-clr)
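For reference, the update that `_updates` performs can be restated as a minimal NumPy sketch. This is an illustrative paraphrase of the code above, not part of the commit; the function name `adagrad_step` is made up for the example.

import numpy as np

def adagrad_step(param, grad, square_avg, step,
                 lr=1e-2, lr_decay=0.0, eps=1e-10, weight_decay=0.0):
    # Mirror of Adagrad._updates above, written for a single NumPy array.
    step = step + 1
    if weight_decay != 0.0:
        grad = grad + weight_decay * param        # L2 penalty folded into the gradient
    square_avg = square_avg + grad ** 2           # running sum of squared gradients
    delta = grad / np.sqrt(square_avg + eps)      # note: eps sits inside the sqrt here
    clr = lr / (1 + (step - 1) * lr_decay)        # learning rate decayed by step count
    param = param - clr * delta
    return param, square_avg, step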
@@ -187,3 +187,74 @@ def test_adam():
    for case in cases:
        _test_optimizer("Adam", case, CheckValue)
        _test_optimizer("Adam", case, CheckValue, update_lr=True)
def test_adagrad():
    class CheckValue:
        def __init__(self, net, **kwarg):
            # reference state: running sum of squared gradients per parameter
            self.s_slots = TensorDict()
            for param in net.parameters():
                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                self.s_slots[param] += grad ** 2
                delta = grad / (self.s_slots[param] + self.eps) ** 0.5
                # apply the decayed learning rate, matching Adagrad._updates
                delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
                assertTensorClose(param.numpy(), ori_params[param] + delta)

    cases = [
        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
        {
            "lr": 0.01,
            "eps": 1e-06,
            "lr_decay": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ]
    for case in cases:
        _test_optimizer("Adagrad", case, CheckValue)
        _test_optimizer("Adagrad", case, CheckValue, update_lr=True)