Unverified · Commit f15a8aea authored by Xiaoyu Zhang, committed by GitHub

add adamw optimizer (#4824)

* init adamw optimizer

* fix adamw optimizer bug

* fix comment

* fix comment

* code format

* fix comment
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent f476d48d
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List, Dict, Callable, Union, Iterator, Tuple
from types import GeneratorType
import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.nn.parameter import Parameter
from oneflow.python.nn.optimizer.optimizer import ParamGroup, Optimizer
@oneflow_export("optim.AdamW")
class AdamW(Optimizer):
r"""Implements AdamW algorithm.
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
    This optimizer implements the Adam-weight-decay algorithm, which decouples the weight decay term
    from the gradient-based update (for more details please refer to
    `Adam-weight-decay <https://www.fast.ai/2018/07/02/adam-weight-decay/>`_).
    The equations for updating the parameters are:
.. math::
& V_t = \beta_1*V_{t-1} + (1-\beta_1)*grad
& S_t = \beta_2*S_{t-1} + (1-\beta_2)*{grad} \odot {grad}
& \hat{g} = learning\_rate*(\frac{{V_t}}{\sqrt{{S_t}}+\epsilon}+\lambda*param_{old})
& param_{new} = param_{old} - \hat{g}
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient, the :math:`\lambda` in the equations above (default: 0)
scale (float, optional): the scale factor of loss (default: 1.0)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
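    For example, a minimal usage sketch (``get_parameters()`` and ``compute_loss()`` are
    hypothetical placeholders for your model's parameters and loss computation):

    .. code-block:: python

        adamw = flow.optim.AdamW(get_parameters(), lr=1e-3, weight_decay=1e-2)

        loss = compute_loss()  # forward pass returning a scalar Tensor
        loss.backward()        # populate .grad on every parameter
        adamw.step()           # apply the AdamW update with decoupled weight decay
        adamw.zero_grad()      # clear gradients before the next iteration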
"""
def __init__(
self,
parameters: Union[Iterator[Parameter], List[Dict]],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
amsgrad: bool = False,
scale: float = 1.0,
):
super().__init__()
assert lr >= 0.0, f"Invalid learning rate: {lr}"
assert eps >= 0.0, f"Invalid epsilon value: {eps}"
assert (
betas[0] >= 0.0 and betas[0] < 1.0
), f"Invalid beta parameter at index 0: {betas[0]}"
assert (
betas[1] >= 0.0 and betas[1] < 1.0
), f"Invalid beta parameter at index 1: {betas[1]}"
assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}"
assert scale > 0.0, f"Invalid scale factor: {scale}"
        assert amsgrad is False, "AMSGrad is not supported yet!"
self._default_options["lr"] = lr
self._default_options["eps"] = eps
self._default_options["beta"] = betas
self._default_options["weight_decay"] = weight_decay
self._default_options["amsgrad"] = amsgrad
self._default_options["scale"] = scale
# Add parameters
if isinstance(parameters, GeneratorType):
self._param_groups.append(ParamGroup(parameters, self._default_options))
else: # List[Dict]
for param in parameters:
self._param_groups.append(ParamGroup(param, self._default_options))
for param_group in self._param_groups:
for param in param_group.parameters:
assert param.is_leaf, "parameters must be leaf tensor"
self._state[param] = dict()
self._state[param]["exp_avg"] = flow.tmp.zeros_like(param)
self._state[param]["exp_avg_sq"] = flow.tmp.zeros_like(param)
self._op = (
flow.builtin_op("adam_update")
.Input("model")
.Input("model_diff")
.Input("learning_rate")
.Input("m")
.Input("v")
.Attr("scale", self._default_options["scale"])
.Attr("l1", 0.0)
.Attr("l2", 0.0)
.Attr("beta1", self._default_options["beta"][0])
.Attr("beta2", self._default_options["beta"][1])
.Attr("epsilon", self._default_options["eps"])
.Attr("weight_decay", self._default_options["weight_decay"])
.Build()
)
def step(self, closure: Callable = None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
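        For example, a sketch of the closure form (``adamw`` is an instance of this
        optimizer and ``compute_loss`` is a hypothetical placeholder):

        .. code-block:: python

            def closure():
                adamw.zero_grad()
                loss = compute_loss()
                loss.backward()
                return loss

            loss = adamw.step(closure)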
"""
loss = None
if closure is not None:
loss = closure()
for param_group in self._param_groups:
lr_tensor = flow.Tensor([param_group.options["lr"]])
for param in param_group.parameters:
if param.grad is None:
continue
m_tensor = self._state[param]["exp_avg"]
v_tensor = self._state[param]["exp_avg_sq"]
self._op(param, param.grad, lr_tensor, m_tensor, v_tensor)
self._state["step"] = self._state["step"] + 1
return loss
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow as flow
from test_util import GenArgList
from oneflow.python.nn.parameter import Parameter
def compare_with_numpy_adamw(
test_case, x_shape, scale, learning_rate, train_iters, weight_decay
):
# generate random number sequences
random_grad_seq = []
for _ in range(train_iters):
random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))
init_value = np.random.uniform(size=x_shape).astype(np.float32)
def train_by_oneflow():
x = Parameter(flow.Tensor(init_value))
param_list = list()
param_list.append(x)
adam = flow.optim.AdamW(
[{"param": param_list}],
lr=learning_rate,
scale=scale,
weight_decay=weight_decay,
)
def train_one_iter(grad):
grad_tensor = flow.Tensor(grad, requires_grad=False)
            loss = flow.sum(x * grad_tensor)
loss.backward()
adam.step()
adam.zero_grad()
for i in range(train_iters):
train_one_iter(random_grad_seq[i])
return x
def train_by_numpy():
x = init_value
vt = np.zeros_like(x)
st = np.zeros_like(x)
beta1 = 0.9
beta2 = 0.999
def train_one_iter(grad):
grad = grad * scale
v = beta1 * vt + (1 - beta1) * grad
s = beta2 * st + (1 - beta2) * grad * grad
g = (
learning_rate / (np.sqrt(s) + 1e-8) * v
+ learning_rate * weight_decay * x
)
param = x - g
return param, v, s
for i in range(train_iters):
x, vt, st = train_one_iter(random_grad_seq[i])
return x
oneflow_res = train_by_oneflow().numpy()
numpy_res = train_by_numpy()
test_case.assertTrue(
np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4)
)
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestAdamW(flow.unittest.TestCase):
def test_adamw(test_case):
arg_dict = OrderedDict()
arg_dict["x_shape"] = [(10,)]
arg_dict["scale"] = [1.0, 0.9]
arg_dict["learning_rate"] = [1]
arg_dict["train_iters"] = [10]
arg_dict["weight_decay"] = [1e-3, 0.0]
for arg in GenArgList(arg_dict):
compare_with_numpy_adamw(test_case, *arg)
if __name__ == "__main__":
unittest.main()