Unverified commit 4bbf4c71, authored by QiangX-man, committed by GitHub

add ReduceLROnPlateau op (#6564)

* add ReduceLROnPlateau op

* add unit test case

* Mod review comments

* mod review comments

* mod review comments

* mod review comments

* mod review comment

* mod test case

* adjust test case
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent fb2947ba
@@ -17,4 +17,5 @@ Optimizers
CosineAnnealingLR,
LambdaLR,
StepLR,
MultiStepLR
MultiStepLR,
ReduceLROnPlateau
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from math import inf
from .optimizer import Optimizer
class ReduceLROnPlateau(object):
"""Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate by a factor
of 2-10 once learning stagnates. This scheduler reads a metric
quantity and, if no improvement is seen for a 'patience' number
of epochs, reduces the learning rate.
Args:
optimizer (Optimizer): Wrapped optimizer.
mode (str): One of `min`, `max`. In `min` mode, lr will
be reduced when the quantity monitored has stopped
decreasing; in `max` mode it will be reduced when the
quantity monitored has stopped increasing. Default: 'min'.
factor (float): Factor by which the learning rate will be
reduced. new_lr = lr * factor. Default: 0.1.
patience (int): Number of epochs with no improvement after
which learning rate will be reduced. For example, if
`patience = 2`, then we will ignore the first 2 epochs
with no improvement, and will only decrease the LR after the
3rd epoch if the loss still hasn't improved then.
Default: 10.
threshold (float): Threshold for measuring the new optimum,
to only focus on significant changes. Default: 1e-4.
threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
dynamic_threshold = best * ( 1 + threshold ) in 'max'
mode or best * ( 1 - threshold ) in `min` mode.
In `abs` mode, dynamic_threshold = best + threshold in
`max` mode or best - threshold in `min` mode. Default: 'rel'.
cooldown (int): Number of epochs to wait before resuming
normal operation after lr has been reduced. Default: 0.
min_lr (float or list): A scalar or a list of scalars. A
lower bound on the learning rate of all param groups
or each group respectively. Default: 0.
eps (float): Minimal decay applied to lr. If the difference
between new and old lr is smaller than eps, the update is
ignored. Default: 1e-8.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
For example:
.. code-block:: python
optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = flow.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):
train(...)
val_loss = validate(...)
# Note that step should be called after validate()
scheduler.step(val_loss)
"""
def __init__(
self,
optimizer,
mode="min",
factor=0.1,
patience=10,
threshold=1e-4,
threshold_mode="rel",
cooldown=0,
min_lr=0,
eps=1e-8,
verbose=False,
):
if factor >= 1.0:
raise ValueError("Factor should be < 1.0.")
self.factor = factor
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
self.optimizer = optimizer
if isinstance(min_lr, list) or isinstance(min_lr, tuple):
if len(min_lr) != len(optimizer.param_groups):
raise ValueError(
"expected {} min_lrs, got {}".format(
len(optimizer.param_groups), len(min_lr)
)
)
self.min_lrs = list(min_lr)
else:
self.min_lrs = [min_lr] * len(optimizer.param_groups)
self.patience = patience
self.verbose = verbose
self.cooldown = cooldown
self.cooldown_counter = 0
self.mode = mode
self.threshold = threshold
self.threshold_mode = threshold_mode
self.best = None
self.num_bad_steps = None
self.mode_worse = None # the worse value for the chosen mode
self.eps = eps
self.last_step = 0
self._init_is_better(
mode=mode, threshold=threshold, threshold_mode=threshold_mode
)
self._reset()
def step(self, metrics):
"""Step forward once.
Arguments:
metrics (float): the metric value (e.g. validation loss) used to judge whether training is improving.
"""
# convert `metrics` to float, in case it's a zero-dim Tensor
current = float(metrics)
self.last_step = self.last_step + 1
if self.is_better(current, self.best):
self.best = current
self.num_bad_steps = 0
else:
self.num_bad_steps += 1
if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_steps = 0 # ignore any bad epochs in cooldown
if self.num_bad_steps > self.patience:
self._reduce_lr(self.last_step)
self.cooldown_counter = self.cooldown
self.num_bad_steps = 0
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
@property
def in_cooldown(self):
return self.cooldown_counter > 0
def is_better(self, a, best):
if self.mode == "min" and self.threshold_mode == "rel":
rel_epsilon = 1.0 - self.threshold
return a < best * rel_epsilon
elif self.mode == "min" and self.threshold_mode == "abs":
return a < best - self.threshold
elif self.mode == "max" and self.threshold_mode == "rel":
rel_epsilon = self.threshold + 1.0
return a > best * rel_epsilon
else: # mode == 'max' and epsilon_mode == 'abs':
return a > best + self.threshold
def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the optimizer.
"""
return {
key: value for key, value in self.__dict__.items() if key != "optimizer"
}
def load_state_dict(self, state_dict):
"""Loads the schedulers state.
Arguments:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
self._init_is_better(
mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
)
def _reduce_lr(self, epoch):
for i, param_group in enumerate(self.optimizer.param_groups):
old_lr = float(param_group["lr"])
new_lr = max(old_lr * self.factor, self.min_lrs[i])
if old_lr - new_lr > self.eps:
param_group["lr"] = new_lr
if self.verbose:
print(
"Epoch {:5d}: reducing learning rate"
" of group {} to {:.4e}.".format(epoch, i, new_lr)
)
def _reset(self):
"""Resets num_bad_steps counter and cooldown counter."""
self.best = self.mode_worse
self.cooldown_counter = 0
self.num_bad_steps = 0
def _init_is_better(self, mode, threshold, threshold_mode):
if mode not in {"min", "max"}:
raise ValueError("mode " + mode + " is unknown!")
if threshold_mode not in {"rel", "abs"}:
raise ValueError("threshold mode " + threshold_mode + " is unknown!")
if mode == "min":
self.mode_worse = inf
else: # mode == 'max':
self.mode_worse = -inf
self.mode = mode
self.threshold = threshold
self.threshold_mode = threshold_mode
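As a quick illustration of the behaviour documented above, the following sketch drives the scheduler with made-up validation losses; with `patience=2` and `factor=0.1`, the learning rate should drop from 0.1 to 0.01 once three consecutive epochs show no improvement (the loss values and the dummy parameter are hypothetical, only `flow.optim.lr_scheduler.ReduceLROnPlateau` itself comes from this PR):

.. code-block:: python

    import oneflow as flow
    from oneflow.nn.parameter import Parameter

    # A single dummy parameter is enough; only the scheduler's bookkeeping matters here.
    param = Parameter(flow.Tensor([1.0]))
    optimizer = flow.optim.SGD([{"params": [param]}], lr=0.1, momentum=0.9)
    scheduler = flow.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=2
    )

    # Improvement stops after the second value, so num_bad_steps exceeds patience
    # at the fifth call and _reduce_lr() multiplies every group's lr by factor.
    for val_loss in [1.0, 0.8, 0.85, 0.9, 0.95]:
        scheduler.step(val_loss)
        print([group["lr"] for group in optimizer.param_groups])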
@@ -20,3 +20,4 @@ from oneflow.nn.optimizer.lr_scheduler import LrScheduler as _LRScheduler
from oneflow.nn.optimizer.step_lr import StepLR
from oneflow.nn.optimizer.multistep_lr import MultiStepLR
from oneflow.nn.optimizer.warm_up_lr import WarmUpLR
from oneflow.nn.optimizer.reduce_lr_on_plateau import ReduceLROnPlateau
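With this re-export, the scheduler becomes reachable as `flow.optim.lr_scheduler.ReduceLROnPlateau`. Because `state_dict()` returns a plain dict that excludes the wrapped optimizer, checkpointing the scheduler looks roughly like the sketch below (the restart flow and the final `assert` are illustrative assumptions, not part of this PR):

.. code-block:: python

    import oneflow as flow
    from oneflow.nn.parameter import Parameter

    optimizer = flow.optim.SGD([{"params": [Parameter(flow.Tensor([1.0]))]}], lr=0.1)
    scheduler = flow.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=5)

    scheduler.step(0.5)                  # best becomes 0.5
    state = scheduler.state_dict()       # everything except the optimizer itself

    # On restart: rebuild the scheduler around the (restored) optimizer, then reload.
    resumed = flow.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    resumed.load_state_dict(state)
    assert resumed.best == 0.5 and resumed.patience == 5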
@@ -16,10 +16,60 @@ limitations under the License.
import math
import unittest
from collections import OrderedDict
from test_util import GenArgDict
import oneflow as flow
import oneflow.unittest
from oneflow.nn.parameter import Parameter
import torch
import random
def compare_with_torch_reduce_lr(
test_case, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps,
):
optimizer_flow = flow.optim.SGD(
[{"params": [Parameter(flow.Tensor([1.0]))]},],
lr=TestLrScheduler.base_lr,
momentum=0.9,
)
optimizer_torch = torch.optim.SGD(
[{"params": [torch.nn.Parameter(torch.Tensor([1.0]))]},],
lr=TestLrScheduler.base_lr,
momentum=0.9,
)
scheduler_flow = flow.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_flow,
mode,
factor,
patience,
threshold,
threshold_mode,
cooldown,
min_lr,
eps,
)
scheduler_torch = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_torch,
mode,
factor,
patience,
threshold,
threshold_mode,
cooldown,
min_lr,
eps,
)
val_loss = 0.1
for epoch in range(15):
val_loss += (random.random() - 0.5) / 10
scheduler_flow.step(val_loss)
scheduler_torch.step(val_loss)
for (lr1, lr2) in zip(scheduler_flow._last_lr, scheduler_torch._last_lr):
test_case.assertAlmostEqual(lr1, lr2, places=5)
@flow.unittest.skip_unless_1n1d()
@@ -149,6 +199,19 @@ class TestLrScheduler(flow.unittest.TestCase):
for (lr1, lr2) in zip(lambda_lr.get_last_lr(), new_lrs):
test_case.assertAlmostEqual(lr1, lr2, places=5)
def test_reduce_lr_on_plateau(test_case):
arg_dict = OrderedDict()
arg_dict["mode"] = ["min", "max"]
arg_dict["factor"] = [0.1, 0.3]
arg_dict["patience"] = [2, 5]
arg_dict["threshold"] = [1e-3, 1e-5]
arg_dict["threshold_mode"] = ["rel", "abs"]
arg_dict["cooldown"] = [0, 1]
arg_dict["min_lr"] = [0, 1e-3]
arg_dict["eps"] = [1e-5, 1e-8]
for arg in GenArgDict(arg_dict):
compare_with_torch_reduce_lr(test_case, **arg)
if __name__ == "__main__":
unittest.main()
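For reference, `GenArgDict` is imported from OneFlow's `test_util` helpers; assuming it simply yields every combination of the listed argument values as a keyword dict (an assumption this sketch is built on, not something this PR defines), a rough stand-in could look like:

.. code-block:: python

    from itertools import product

    def gen_arg_dict(arg_dict):
        # Yield one kwargs dict per combination, e.g.
        # {"mode": "min", "factor": 0.1, ...}, {"mode": "min", "factor": 0.3, ...}, ...
        keys = list(arg_dict.keys())
        for values in product(*arg_dict.values()):
            yield dict(zip(keys, values))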