diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index f65e3d250bf1d6e0368c0897faabf388f2d05cc8..e468f415236600e2af0ebc46fad88ea525f2e5e1 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -17,4 +17,5 @@ Optimizers
     CosineAnnealingLR,
     LambdaLR,
     StepLR,
-    MultiStepLR
+    MultiStepLR,
+    ReduceLROnPlateau
diff --git a/python/oneflow/nn/optimizer/reduce_lr_on_plateau.py b/python/oneflow/nn/optimizer/reduce_lr_on_plateau.py
new file mode 100644
index 0000000000000000000000000000000000000000..2686b2fb8167a112cfe7777a9f8ce7e9b9472f9b
--- /dev/null
+++ b/python/oneflow/nn/optimizer/reduce_lr_on_plateau.py
@@ -0,0 +1,222 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from math import inf
+from .optimizer import Optimizer
+
+
+class ReduceLROnPlateau(object):
+    """Reduce learning rate when a metric has stopped improving.
+    Models often benefit from reducing the learning rate by a factor
+    of 2-10 once learning stagnates. This scheduler reads a metrics
+    quantity and if no improvement is seen for a 'patience' number
+    of epochs, the learning rate is reduced.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        mode (str): One of `min`, `max`. In `min` mode, lr will
+            be reduced when the quantity monitored has stopped
+            decreasing; in `max` mode it will be reduced when the
+            quantity monitored has stopped increasing. Default: 'min'.
+        factor (float): Factor by which the learning rate will be
+            reduced. new_lr = lr * factor. Default: 0.1.
+        patience (int): Number of epochs with no improvement after
+            which learning rate will be reduced. For example, if
+            `patience = 2`, then we will ignore the first 2 epochs
+            with no improvement, and will only decrease the LR after the
+            3rd epoch if the loss still hasn't improved then.
+            Default: 10.
+        threshold (float): Threshold for measuring the new optimum,
+            to only focus on significant changes. Default: 1e-4.
+        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
+            dynamic_threshold = best * ( 1 + threshold ) in 'max'
+            mode or best * ( 1 - threshold ) in `min` mode.
+            In `abs` mode, dynamic_threshold = best + threshold in
+            `max` mode or best - threshold in `min` mode. Default: 'rel'.
+        cooldown (int): Number of epochs to wait before resuming
+            normal operation after lr has been reduced. Default: 0.
+        min_lr (float or list): A scalar or a list of scalars. A
+            lower bound on the learning rate of all param groups
+            or each group respectively. Default: 0.
+        eps (float): Minimal decay applied to lr. If the difference
+            between new and old lr is smaller than eps, the update is
+            ignored. Default: 1e-8.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+
+    For example:
+
+    .. code-block:: python
+
+        optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+        scheduler = flow.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
+        for epoch in range(10):
+            train(...)
+            val_loss = validate(...)
+            # Note that step should be called after validate()
+            scheduler.step(val_loss)
+    """
+
+    def __init__(
+        self,
+        optimizer,
+        mode="min",
+        factor=0.1,
+        patience=10,
+        threshold=1e-4,
+        threshold_mode="rel",
+        cooldown=0,
+        min_lr=0,
+        eps=1e-8,
+        verbose=False,
+    ):
+
+        if factor >= 1.0:
+            raise ValueError("Factor should be < 1.0.")
+        self.factor = factor
+
+        # Attach optimizer
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
+        self.optimizer = optimizer
+
+        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
+            if len(min_lr) != len(optimizer.param_groups):
+                raise ValueError(
+                    "expected {} min_lrs, got {}".format(
+                        len(optimizer.param_groups), len(min_lr)
+                    )
+                )
+            self.min_lrs = list(min_lr)
+        else:
+            self.min_lrs = [min_lr] * len(optimizer.param_groups)
+
+        self.patience = patience
+        self.verbose = verbose
+        self.cooldown = cooldown
+        self.cooldown_counter = 0
+        self.mode = mode
+        self.threshold = threshold
+        self.threshold_mode = threshold_mode
+        self.best = None
+        self.num_bad_steps = None
+        self.mode_worse = None  # the worse value for the chosen mode
+        self.eps = eps
+        self.last_step = 0
+        self._init_is_better(
+            mode=mode, threshold=threshold, threshold_mode=threshold_mode
+        )
+        self._reset()
+
+    def step(self, metrics):
+        """Step forward once.
+
+        Arguments:
+            metrics (float): the value of the metric being monitored (e.g. the validation loss).
+        """
+        # convert `metrics` to float, in case it's a zero-dim Tensor
+        current = float(metrics)
+        self.last_step = self.last_step + 1
+
+        if self.is_better(current, self.best):
+            self.best = current
+            self.num_bad_steps = 0
+        else:
+            self.num_bad_steps += 1
+
+        if self.in_cooldown:
+            self.cooldown_counter -= 1
+            self.num_bad_steps = 0  # ignore any bad epochs in cooldown
+
+        if self.num_bad_steps > self.patience:
+            self._reduce_lr(self.last_step)
+            self.cooldown_counter = self.cooldown
+            self.num_bad_steps = 0
+
+        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+
+    @property
+    def in_cooldown(self):
+        return self.cooldown_counter > 0
+
+    def is_better(self, a, best):
+        if self.mode == "min" and self.threshold_mode == "rel":
+            rel_epsilon = 1.0 - self.threshold
+            return a < best * rel_epsilon
+
+        elif self.mode == "min" and self.threshold_mode == "abs":
+            return a < best - self.threshold
+
+        elif self.mode == "max" and self.threshold_mode == "rel":
+            rel_epsilon = self.threshold + 1.0
+            return a > best * rel_epsilon
+
+        else:  # mode == 'max' and threshold_mode == 'abs':
+            return a > best + self.threshold
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer.
+        """
+        return {
+            key: value for key, value in self.__dict__.items() if key != "optimizer"
+        }
+
+    def load_state_dict(self, state_dict):
+        """Loads the scheduler's state.
+
+        Arguments:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+        self._init_is_better(
+            mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
+        )
+
+    def _reduce_lr(self, epoch):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            old_lr = float(param_group["lr"])
+            new_lr = max(old_lr * self.factor, self.min_lrs[i])
+            if old_lr - new_lr > self.eps:
+                param_group["lr"] = new_lr
+                if self.verbose:
+                    print(
+                        "Epoch {:5d}: reducing learning rate"
+                        " of group {} to {:.4e}.".format(epoch, i, new_lr)
+                    )
+
+    def _reset(self):
+        """Resets num_bad_steps counter and cooldown counter."""
+        self.best = self.mode_worse
+        self.cooldown_counter = 0
+        self.num_bad_steps = 0
+
+    def _init_is_better(self, mode, threshold, threshold_mode):
+        if mode not in {"min", "max"}:
+            raise ValueError("mode " + mode + " is unknown!")
+        if threshold_mode not in {"rel", "abs"}:
+            raise ValueError("threshold mode " + threshold_mode + " is unknown!")
+
+        if mode == "min":
+            self.mode_worse = inf
+        else:  # mode == 'max':
+            self.mode_worse = -inf
+
+        self.mode = mode
+        self.threshold = threshold
+        self.threshold_mode = threshold_mode
diff --git a/python/oneflow/optim/lr_scheduler.py b/python/oneflow/optim/lr_scheduler.py
index f0067c2803c90bd007b2f2a22dd2eda69c7edf22..2f36167e1cb43fe638beedc2d89754bd37326e2d 100644
--- a/python/oneflow/optim/lr_scheduler.py
+++ b/python/oneflow/optim/lr_scheduler.py
@@ -20,3 +20,4 @@ from oneflow.nn.optimizer.lr_scheduler import LrScheduler as _LRScheduler
 from oneflow.nn.optimizer.step_lr import StepLR
 from oneflow.nn.optimizer.multistep_lr import MultiStepLR
 from oneflow.nn.optimizer.warm_up_lr import WarmUpLR
+from oneflow.nn.optimizer.reduce_lr_on_plateau import ReduceLROnPlateau
diff --git a/python/oneflow/test/modules/test_lr_scheduler.py b/python/oneflow/test/modules/test_lr_scheduler.py
index caa65a5738f5d47ea16c0dcba707544a133f9d7b..848a87745206ad4b989126e445807e866aea646b 100644
--- a/python/oneflow/test/modules/test_lr_scheduler.py
+++ b/python/oneflow/test/modules/test_lr_scheduler.py
@@ -16,10 +16,60 @@ limitations under the License.
 """
 import math
 import unittest
+from collections import OrderedDict
+from test_util import GenArgDict
 
 import oneflow as flow
 import oneflow.unittest
 from oneflow.nn.parameter import Parameter
+import torch
+import random
+
+
+def compare_with_torch_reduce_lr(
+    test_case, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps,
+):
+    optimizer_flow = flow.optim.SGD(
+        [{"params": [Parameter(flow.Tensor([1.0]))]},],
+        lr=TestLrScheduler.base_lr,
+        momentum=0.9,
+    )
+
+    optimizer_torch = torch.optim.SGD(
+        [{"params": [torch.nn.Parameter(torch.Tensor([1.0]))]},],
+        lr=TestLrScheduler.base_lr,
+        momentum=0.9,
+    )
+
+    scheduler_flow = flow.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer_flow,
+        mode,
+        factor,
+        patience,
+        threshold,
+        threshold_mode,
+        cooldown,
+        min_lr,
+        eps,
+    )
+    scheduler_torch = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer_torch,
+        mode,
+        factor,
+        patience,
+        threshold,
+        threshold_mode,
+        cooldown,
+        min_lr,
+        eps,
+    )
+    val_loss = 0.1
+    for epoch in range(15):
+        val_loss += (random.random() - 0.5) / 10
+        scheduler_flow.step(val_loss)
+        scheduler_torch.step(val_loss)
+        for (lr1, lr2) in zip(scheduler_flow._last_lr, scheduler_torch._last_lr):
+            test_case.assertAlmostEqual(lr1, lr2, places=5)
 
 
 @flow.unittest.skip_unless_1n1d()
@@ -149,6 +199,19 @@ class TestLrScheduler(flow.unittest.TestCase):
             for (lr1, lr2) in zip(lambda_lr.get_last_lr(), new_lrs):
                 test_case.assertAlmostEqual(lr1, lr2, places=5)
 
+    def test_reduce_lr_on_plateau(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["mode"] = ["min", "max"]
+        arg_dict["factor"] = [0.1, 0.3]
+        arg_dict["patience"] = [2, 5]
+        arg_dict["threshold"] = [1e-3, 1e-5]
+        arg_dict["threshold_mode"] = ["rel", "abs"]
+        arg_dict["cooldown"] = [0, 1]
+        arg_dict["min_lr"] = [0, 1e-3]
+        arg_dict["eps"] = [1e-5, 1e-8]
+        for arg in GenArgDict(arg_dict):
+            compare_with_torch_reduce_lr(test_case, **arg)
+
 
 if __name__ == "__main__":
     unittest.main()
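
For reviewers, a minimal sketch (not part of the patch) of how the new scheduler behaves once this diff is applied: it feeds a metric that never improves and prints the learning rate after every epoch. The single dummy `Parameter`, the constant `val_loss`, and the chosen hyperparameters are illustrative only; everything else is the API added above.

```python
import oneflow as flow
from oneflow.nn.parameter import Parameter

# One dummy parameter is enough to exercise the scheduler.
param = Parameter(flow.Tensor([1.0]))
optimizer = flow.optim.SGD([{"params": [param]}], lr=0.1, momentum=0.9)
scheduler = flow.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2, cooldown=1, min_lr=1e-4
)

val_loss = 1.0  # a metric that never improves
for epoch in range(12):
    scheduler.step(val_loss)
    # After `patience` bad epochs the lr is multiplied by `factor`;
    # epochs spent in cooldown are not counted as bad, and the lr
    # never drops below `min_lr`.
    print(epoch, optimizer.param_groups[0]["lr"])
```

With these settings the printed lr should stay at 0.1 for the first few epochs, then step down to 0.01, 0.001, and finally the `min_lr` of 1e-4, which matches the `_reduce_lr` logic in the new file.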
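A second sketch, assuming the intended use of the new `state_dict()`/`load_state_dict()` pair is checkpoint resume: the scheduler's bookkeeping (`best`, `num_bad_steps`, `cooldown_counter`) survives the round trip, so a resumed run keeps counting toward the next lr reduction. Reusing the same `optimizer` object for the restored scheduler is only to keep the sketch short.

```python
import oneflow as flow
from oneflow.nn.parameter import Parameter

optimizer = flow.optim.SGD(
    [{"params": [Parameter(flow.Tensor([1.0]))]}], lr=0.1, momentum=0.9
)
scheduler = flow.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)

for val_loss in [0.5, 0.6, 0.7]:  # one improvement, then two bad epochs
    scheduler.step(val_loss)

# Everything except the wrapped optimizer is captured in a plain dict.
state = scheduler.state_dict()

# ... later, rebuild the scheduler and restore its counters.
resumed = flow.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)
resumed.load_state_dict(state)
assert resumed.best == scheduler.best
assert resumed.num_bad_steps == scheduler.num_bad_steps
```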